midah committed · Commit 4fac556 · 1 Parent(s): e904fd3

Apply clean grayscale design, remove all emojis


Remove gradient backgrounds, purple theme, and emojis throughout UI

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. backend/README.md +0 -24
  2. backend/api/dependencies.py +23 -0
  3. backend/api/main.py +764 -427
  4. backend/api/routes/__init__.py +6 -0
  5. backend/api/routes/clusters.py +102 -0
  6. backend/api/routes/models.py +247 -0
  7. backend/api/routes/stats.py +37 -0
  8. backend/config/requirements.txt +1 -0
  9. backend/core/__init__.py +2 -0
  10. backend/core/config.py +23 -0
  11. backend/core/exceptions.py +18 -0
  12. backend/models/__init__.py +2 -0
  13. backend/models/schemas.py +22 -0
  14. backend/scripts/export_binary.py +263 -0
  15. backend/services/model_tracker.py +83 -24
  16. backend/services/model_tracker_improved.py +95 -30
  17. backend/utils/data_loader.py +1 -4
  18. backend/utils/embeddings.py +1 -1
  19. backend/utils/family_tree.py +66 -0
  20. backend/utils/graph_embeddings.py +177 -0
  21. backend/utils/network_analysis.py +163 -20
  22. frontend/.npmrc +2 -0
  23. frontend/package-lock.json +2 -1
  24. frontend/package.json +2 -1
  25. frontend/public/index.html +1 -1
  26. frontend/src/App.css +85 -202
  27. frontend/src/App.tsx +49 -118
  28. frontend/src/components/PaperPlots.css +0 -92
  29. frontend/src/components/PaperPlots.tsx +0 -755
  30. frontend/src/components/ScatterPlot.tsx +0 -7
  31. frontend/src/components/controls/ClusterFilter.css +122 -0
  32. frontend/src/components/controls/ClusterFilter.tsx +142 -0
  33. frontend/src/components/controls/NodeDensitySlider.css +31 -0
  34. frontend/src/components/controls/NodeDensitySlider.tsx +39 -0
  35. frontend/src/components/controls/RandomModelButton.tsx +32 -0
  36. frontend/src/components/controls/RenderingStyleSelector.css +37 -0
  37. frontend/src/components/controls/RenderingStyleSelector.tsx +43 -0
  38. frontend/src/components/controls/ThemeToggle.tsx +22 -0
  39. frontend/src/components/controls/VisualizationModeButtons.css +65 -0
  40. frontend/src/components/controls/VisualizationModeButtons.tsx +46 -0
  41. frontend/src/components/controls/ZoomSlider.tsx +43 -0
  42. frontend/src/components/layout/SearchBar.css +181 -0
  43. frontend/src/components/layout/SearchBar.tsx +201 -0
  44. frontend/src/components/{FileTree.css → modals/FileTree.css} +171 -3
  45. frontend/src/components/{FileTree.tsx → modals/FileTree.tsx} +314 -26
  46. frontend/src/components/{ModelModal.css → modals/ModelModal.css} +43 -14
  47. frontend/src/components/{ModelModal.tsx → modals/ModelModal.tsx} +17 -9
  48. frontend/src/components/{ColorLegend.css → ui/ColorLegend.css} +0 -0
  49. frontend/src/components/{ColorLegend.tsx → ui/ColorLegend.tsx} +1 -1
  50. frontend/src/components/{ErrorBoundary.tsx → ui/ErrorBoundary.tsx} +0 -0
backend/README.md DELETED
@@ -1,24 +0,0 @@
- # Backend API
-
- FastAPI backend for serving model data to the React frontend.
-
- ## Structure
-
- `api/` - API routes and main application
- `services/` - External service integrations (arXiv, model tracking, scheduling)
- `utils/` - Utility modules (data loading, embeddings, dimensionality reduction, clustering, network analysis)
- `config/` - Configuration files (requirements.txt, etc.)
- `cache/` - Cached data (embeddings, reduced dimensions)
-
- ## Running
-
- ```bash
- cd backend
- uvicorn api.main:app --reload --host 0.0.0.0 --port 8000
- ```
-
- ## Environment Variables
-
- `SAMPLE_SIZE` - Limit number of models to load (for development). Set to 0 or leave unset to load all models.
-
backend/api/dependencies.py ADDED
@@ -0,0 +1,23 @@
+ """Shared dependencies for API routes."""
+ import pandas as pd
+ import numpy as np
+ from typing import Optional, Dict
+ from utils.data_loader import ModelDataLoader
+ from utils.embeddings import ModelEmbedder
+ from utils.dimensionality_reduction import DimensionReducer
+ from utils.graph_embeddings import GraphEmbedder
+
+ # Global state (initialized in startup) - these are module-level variables
+ # that will be updated by main.py during startup
+ data_loader = ModelDataLoader()
+ embedder: Optional[ModelEmbedder] = None
+ graph_embedder: Optional[GraphEmbedder] = None
+ reducer: Optional[DimensionReducer] = None
+ df: Optional[pd.DataFrame] = None
+ embeddings: Optional[np.ndarray] = None
+ graph_embeddings_dict: Optional[Dict[str, np.ndarray]] = None
+ combined_embeddings: Optional[np.ndarray] = None
+ reduced_embeddings: Optional[np.ndarray] = None
+ reduced_embeddings_graph: Optional[np.ndarray] = None
+
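The new module holds shared state that `main.py` fills in during startup. A minimal sketch (hypothetical route, not part of this commit) of how a router would consume it — reading through the module object so it sees the post-startup values rather than a stale `None` captured at import time:

```python
# Hypothetical example route; the route path and response shape are
# illustrative, not part of this commit.
from fastapi import APIRouter, HTTPException

import api.dependencies as deps  # read via the module, not `from ... import df`

router = APIRouter()

@router.get("/api/example/count")
async def example_count():
    # deps.df is None until startup_event() has populated it
    if deps.df is None:
        raise HTTPException(status_code=503, detail="Data not loaded")
    return {"total_models": len(deps.df.index)}
```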
backend/api/main.py CHANGED
@@ -1,202 +1,216 @@
- """
- FastAPI backend for serving model data to React/Visx frontend.
- """
  import sys
  import os
- backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
- if backend_dir not in sys.path:
-     sys.path.insert(0, backend_dir)

  from fastapi import FastAPI, HTTPException, Query, BackgroundTasks, Request
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.middleware.gzip import GZipMiddleware
  from fastapi.responses import FileResponse, JSONResponse
  from fastapi.exceptions import RequestValidationError
  from starlette.exceptions import HTTPException as StarletteHTTPException
- from typing import Optional, List, Dict
- import pandas as pd
- import numpy as np
  from pydantic import BaseModel
  from umap import UMAP
- import tempfile
- import traceback
- import httpx

  from utils.data_loader import ModelDataLoader
  from utils.embeddings import ModelEmbedder
  from utils.dimensionality_reduction import DimensionReducer
  from utils.network_analysis import ModelNetworkBuilder
  from services.model_tracker import get_tracker
- from services.model_tracker_improved import get_improved_tracker
  from services.arxiv_api import extract_arxiv_ids, fetch_arxiv_papers

- app = FastAPI(title="HF Model Ecosystem API")

  app.add_middleware(GZipMiddleware, minimum_size=1000)

  @app.exception_handler(Exception)
  async def global_exception_handler(request: Request, exc: Exception):
-     """Global exception handler that ensures CORS headers are included even on errors."""
-     import traceback
-     error_detail = str(exc)
-     traceback_str = traceback.format_exc()
-     import sys
-     sys.stderr.write(f"Unhandled exception: {error_detail}\n{traceback_str}\n")
      return JSONResponse(
          status_code=500,
-         content={"detail": error_detail, "error": "Internal server error"},
-         headers={
-             "Access-Control-Allow-Origin": "*",
-             "Access-Control-Allow-Methods": "*",
-             "Access-Control-Allow-Headers": "*",
-         }
      )

  @app.exception_handler(StarletteHTTPException)
  async def http_exception_handler(request: Request, exc: StarletteHTTPException):
-     """HTTP exception handler with CORS headers."""
      return JSONResponse(
          status_code=exc.status_code,
          content={"detail": exc.detail},
-         headers={
-             "Access-Control-Allow-Origin": "*",
-             "Access-Control-Allow-Methods": "*",
-             "Access-Control-Allow-Headers": "*",
-         }
      )

  @app.exception_handler(RequestValidationError)
  async def validation_exception_handler(request: Request, exc: RequestValidationError):
-     """Validation exception handler with CORS headers."""
      return JSONResponse(
          status_code=422,
          content={"detail": exc.errors()},
-         headers={
-             "Access-Control-Allow-Origin": "*",
-             "Access-Control-Allow-Methods": "*",
-             "Access-Control-Allow-Headers": "*",
-         }
      )

- # CORS middleware for React frontend
- # Update allow_origins with your Netlify URL in production
- # Note: Add your specific Netlify URL after deployment
- FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:3000")
- # Allow all origins for development (restrict in production)
- ALLOW_ALL_ORIGINS = os.getenv("ALLOW_ALL_ORIGINS", "true").lower() == "true"
- if ALLOW_ALL_ORIGINS:
      app.add_middleware(
          CORSMiddleware,
-         allow_origins=["*"],  # Allow all origins in development
-         allow_credentials=False,  # Must be False when allow_origins is ["*"]
          allow_methods=["*"],
          allow_headers=["*"],
      )
  else:
      app.add_middleware(
          CORSMiddleware,
-         allow_origins=[
-             "http://localhost:3000",  # Local development
-             FRONTEND_URL,  # Production frontend URL
-             # Add your Netlify URL here after deployment, e.g.:
-             # "https://your-app-name.netlify.app",
-         ],
          allow_credentials=True,
          allow_methods=["*"],
          allow_headers=["*"],
      )

- data_loader = ModelDataLoader()
- embedder: Optional[ModelEmbedder] = None
- reducer: Optional[DimensionReducer] = None
- df: Optional[pd.DataFrame] = None
- embeddings: Optional[np.ndarray] = None
- reduced_embeddings: Optional[np.ndarray] = None
- cluster_labels: Optional[np.ndarray] = None  # Cached cluster assignments
-
-
- class FilterParams(BaseModel):
-     min_downloads: int = 0
-     min_likes: int = 0
-     search_query: Optional[str] = None
-     libraries: Optional[List[str]] = None
-     pipeline_tags: Optional[List[str]] = None
-
-
- class ModelPoint(BaseModel):
-     model_id: str
-     x: float
-     y: float
-     z: float  # 3D coordinate
-     library_name: Optional[str]
-     pipeline_tag: Optional[str]
-     downloads: int
-     likes: int
-     trending_score: Optional[float]
-     tags: Optional[str]
-     parent_model: Optional[str] = None
-     licenses: Optional[str] = None
-     family_depth: Optional[int] = None  # Generation depth in family tree (0 = root)
-     cluster_id: Optional[int] = None  # Cluster assignment for visualization

  @app.on_event("startup")
  async def startup_event():
-     """Initialize data and models on startup with caching."""
-     global df, embedder, reducer, embeddings, reduced_embeddings

-     import os
      backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
      root_dir = os.path.dirname(backend_dir)
      cache_dir = os.path.join(root_dir, "cache")
      os.makedirs(cache_dir, exist_ok=True)

      embeddings_cache = os.path.join(cache_dir, "embeddings.pkl")
      reduced_cache_umap = os.path.join(cache_dir, "reduced_umap_3d.pkl")
      reducer_cache_umap = os.path.join(cache_dir, "reducer_umap_3d.pkl")

-     sample_size_env = os.getenv("SAMPLE_SIZE")
-     if sample_size_env is None:
-         sample_size = None
      else:
-         sample_size = int(sample_size_env)
-         if sample_size == 0:
-             sample_size = None
-     df = data_loader.load_data(sample_size=sample_size)
-     df = data_loader.preprocess_for_embedding(df)
-
-     if 'model_id' in df.columns:
-         df.set_index('model_id', drop=False, inplace=True)
      for col in ['downloads', 'likes']:
-         if col in df.columns:
-             df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0).astype(int)

-     embedder = ModelEmbedder()

      if os.path.exists(embeddings_cache):
          try:
-             embeddings = embedder.load_embeddings(embeddings_cache)
          except Exception as e:
-             embeddings = None
-
-     if embeddings is None:
-         texts = df['combined_text'].tolist()
-         embeddings = embedder.generate_embeddings(texts, batch_size=128)
-         embedder.save_embeddings(embeddings, embeddings_cache)

-     reducer = DimensionReducer(method="umap", n_components=3)

      if os.path.exists(reduced_cache_umap) and os.path.exists(reducer_cache_umap):
          try:
-             import pickle
              with open(reduced_cache_umap, 'rb') as f:
-                 reduced_embeddings = pickle.load(f)
-             reducer.load_reducer(reducer_cache_umap)
-         except Exception as e:
-             reduced_embeddings = None
-
-     if reduced_embeddings is None:
-         reducer.reducer = UMAP(
              n_components=3,
              n_neighbors=30,
              min_dist=0.3,
@@ -206,61 +220,57 @@ async def startup_event():
              low_memory=True,
              spread=1.5
          )
-         reduced_embeddings = reducer.fit_transform(embeddings)
-         import pickle
          with open(reduced_cache_umap, 'wb') as f:
-             pickle.dump(reduced_embeddings, f)
-         reducer.save_reducer(reducer_cache_umap)
-
-
- def calculate_family_depths(df: pd.DataFrame) -> Dict[str, int]:
-     """
-     Calculate family tree depth for each model.
-     Returns a dictionary mapping model_id to depth (0 = root, 1 = first generation, etc.)
-     """
-     depths = {}
-     visited = set()

-     def get_depth(model_id: str) -> int:
-         if model_id in depths:
-             return depths[model_id]
-         if model_id in visited:
-             # Circular reference, treat as root
-             depths[model_id] = 0
-             return 0

-         visited.add(model_id)
-
-         if model_id not in df.index:
-             depths[model_id] = 0
-             return 0
-
-         parent_id = df.loc[model_id].get('parent_model')
-         if parent_id and pd.notna(parent_id) and str(parent_id) != 'nan' and str(parent_id) != '':
-             parent_id_str = str(parent_id)
-             if parent_id_str in df.index:
-                 depth = get_depth(parent_id_str) + 1
-             else:
-                 depth = 0  # Parent not in dataset, treat as root
-         else:
-             depth = 0  # No parent, this is a root

-         depths[model_id] = depth
-         return depth
-
-     for model_id in df.index:
-         if model_id not in depths:
-             visited = set()  # Reset for each tree
-             get_depth(model_id)

-     return depths


  def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np.ndarray:
-     """
-     Compute clusters using KMeans on reduced embeddings.
-     Returns cluster labels for each point.
-     """
      from sklearn.cluster import KMeans

      n_samples = len(reduced_embeddings)
@@ -268,8 +278,7 @@ def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np
          n_clusters = max(1, n_samples // 10)

      kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
-     cluster_labels = kmeans.fit_predict(reduced_embeddings)
-     return cluster_labels


  @app.get("/")
@@ -284,24 +293,16 @@ async def get_models(
      search_query: Optional[str] = Query(None),
      color_by: str = Query("library_name"),
      size_by: str = Query("downloads"),
-     max_points: Optional[int] = Query(None),  # Optional limit (None = all points)
-     projection_method: str = Query("umap"),  # umap or tsne
-     base_models_only: bool = Query(False)  # Only show root models (no parent)
  ):
-     """
-     Get filtered models with 3D coordinates for visualization.
-     Supports multiple projection methods: UMAP or t-SNE.
-     If base_models_only=True, only returns root models (models without a parent_model).
-
-     Returns a JSON object with:
-     - models: List of ModelPoint objects
-     - filtered_count: Number of models matching filters (before max_points sampling)
-     - returned_count: Number of models actually returned (after max_points sampling)
-     """
-     global df, embedder, reducer, embeddings, reduced_embeddings

-     if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      # Filter data
      filtered_df = data_loader.filter_data(
@@ -321,7 +322,12 @@ async def get_models(
              (filtered_df['parent_model'].astype(str) == 'nan')
          ]

-     # Store the filtered count BEFORE sampling
      filtered_count = len(filtered_df)

      if len(filtered_df) == 0:
@@ -332,42 +338,53 @@ async def get_models(
          }

      if max_points is not None and len(filtered_df) > max_points:
-         # Use stratified sampling to preserve distribution of important attributes
-         # Sample proportionally from different libraries/pipelines for better representation
          if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
-             # Stratified sampling by library
-             filtered_df = filtered_df.groupby('library_name', group_keys=False).apply(
-                 lambda x: x.sample(min(len(x), max(1, int(max_points * len(x) / len(filtered_df)))), random_state=42)
-             ).reset_index(drop=True)
-             # If still too many, random sample the rest
              if len(filtered_df) > max_points:
-                 filtered_df = filtered_df.sample(n=max_points, random_state=42)
          else:
-             filtered_df = filtered_df.sample(n=max_points, random_state=42)
-
-     if embeddings is None:
-         raise HTTPException(status_code=503, detail="Embeddings not loaded")

-     if reduced_embeddings is None or (reducer and reducer.method != projection_method.lower()):
-         import os
          backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
          root_dir = os.path.dirname(backend_dir)
          cache_dir = os.path.join(root_dir, "cache")
-         reduced_cache = os.path.join(cache_dir, f"reduced_{projection_method.lower()}_3d.pkl")
-         reducer_cache = os.path.join(cache_dir, f"reducer_{projection_method.lower()}_3d.pkl")

          if os.path.exists(reduced_cache) and os.path.exists(reducer_cache):
              try:
-                 import pickle
                  with open(reduced_cache, 'rb') as f:
-                     reduced_embeddings = pickle.load(f)
                  if reducer is None or reducer.method != projection_method.lower():
                      reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
                      reducer.load_reducer(reducer_cache)
-             except Exception as e:
-                 reduced_embeddings = None

-         if reduced_embeddings is None:
              if reducer is None or reducer.method != projection_method.lower():
                  reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
              if projection_method.lower() == "umap":
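The stratified sampling deleted in this hunk (presumably relocated into the new `api/routes` modules) keeps each library's share of returned points roughly proportional to its share of the filtered set. A self-contained sketch of the same pattern on synthetic data — column names mirror the diff, everything else is illustrative:

```python
# Minimal sketch of the proportional (stratified) sampling pattern above.
import pandas as pd

df = pd.DataFrame({
    "model_id": [f"m{i}" for i in range(100)],
    "library_name": ["transformers"] * 80 + ["diffusers"] * 20,
})
max_points = 10

# Each library keeps a share proportional to its size (at least 1 row).
sampled = df.groupby("library_name", group_keys=False).apply(
    lambda g: g.sample(
        min(len(g), max(1, int(max_points * len(g) / len(df)))),
        random_state=42,
    )
)
# Integer rounding can leave more than max_points; trim with a final sample.
if len(sampled) > max_points:
    sampled = sampled.sample(n=max_points, random_state=42)
print(sampled["library_name"].value_counts())  # roughly 8:2
```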
@@ -381,52 +398,91 @@ async def get_models(
          low_memory=True,
          spread=1.5
      )
-     reduced_embeddings = reducer.fit_transform(embeddings)
-     import pickle
      with open(reduced_cache, 'wb') as f:
-         pickle.dump(reduced_embeddings, f)
      reducer.save_reducer(reducer_cache)

-     # Get coordinates for filtered data - optimized vectorized approach
-     # Map filtered dataframe indices to original dataframe integer positions
-     # Since df is indexed by model_id, we need to get the integer positions
      if df.index.name == 'model_id' or 'model_id' in df.index.names:
-         # Get integer positions of filtered rows in original dataframe
-         # Use vectorized lookup for better performance
-         filtered_indices = np.array([df.index.get_loc(idx) for idx in filtered_df.index], dtype=np.int32)
      else:
-         # If using integer index, use directly
-         filtered_indices = filtered_df.index.values.astype(np.int32)
-
-     # Use advanced indexing for faster access
-     filtered_reduced = reduced_embeddings[filtered_indices]

      family_depths = calculate_family_depths(df)

-     global cluster_labels
-     if cluster_labels is None or len(cluster_labels) != len(reduced_embeddings):
-         cluster_labels = compute_clusters(reduced_embeddings, n_clusters=min(50, len(reduced_embeddings) // 100))

-     filtered_clusters = cluster_labels[filtered_indices]

-     # Build response with optimized vectorized operations
-     # Pre-extract arrays for faster access
      model_ids = filtered_df['model_id'].astype(str).values
-     library_names = filtered_df['library_name'].values
-     pipeline_tags = filtered_df['pipeline_tag'].values
-     downloads_arr = filtered_df['downloads'].fillna(0).astype(int).values
-     likes_arr = filtered_df['likes'].fillna(0).astype(int).values
-     trending_scores = filtered_df.get('trendingScore', pd.Series()).values
-     tags_arr = filtered_df.get('tags', pd.Series()).values
-     parent_models = filtered_df.get('parent_model', pd.Series()).values
-     licenses_arr = filtered_df.get('licenses', pd.Series()).values
-
-     # Vectorized coordinate extraction
      x_coords = filtered_reduced[:, 0].astype(float)
      y_coords = filtered_reduced[:, 1].astype(float)
      z_coords = filtered_reduced[:, 2].astype(float) if filtered_reduced.shape[1] > 2 else np.zeros(len(filtered_reduced), dtype=float)
-
-     # Build models list with optimized operations
      models = [
          ModelPoint(
              model_id=model_ids[idx],
@@ -442,28 +498,42 @@ async def get_models(
              parent_model=parent_models[idx] if idx < len(parent_models) and pd.notna(parent_models[idx]) else None,
              licenses=licenses_arr[idx] if idx < len(licenses_arr) and pd.notna(licenses_arr[idx]) else None,
              family_depth=family_depths.get(model_ids[idx], None),
-             cluster_id=int(filtered_clusters[idx]) if idx < len(filtered_clusters) else None
          )
          for idx in range(len(filtered_df))
      ]

-     return models


  @app.get("/api/stats")
  async def get_stats():
      """Get dataset statistics."""
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

-     # Use len(df.index) to handle both regular and indexed DataFrames correctly
      total_models = len(df.index) if hasattr(df, 'index') else len(df)

      return {
          "total_models": total_models,
          "unique_libraries": int(df['library_name'].nunique()) if 'library_name' in df.columns else 0,
          "unique_pipelines": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0,
          "unique_task_types": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0,  # Alias for clarity
          "avg_downloads": float(df['downloads'].mean()) if 'downloads' in df.columns else 0,
          "avg_likes": float(df['likes'].mean()) if 'likes' in df.columns else 0
      }
@@ -473,7 +543,7 @@ async def get_stats():
  async def get_model_details(model_id: str):
      """Get detailed information about a specific model."""
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      model = df[df.get('model_id', '') == model_id]
      if len(model) == 0:
@@ -481,11 +551,9 @@ async def get_model_details(model_id: str):
      model = model.iloc[0]

-     # Extract arXiv IDs from tags
      tags_str = str(model.get('tags', '')) if pd.notna(model.get('tags')) else ''
      arxiv_ids = extract_arxiv_ids(tags_str)

-     # Fetch arXiv papers if any IDs found
      papers = []
      if arxiv_ids:
          papers = await fetch_arxiv_papers(arxiv_ids[:5])  # Limit to 5 papers
@@ -505,6 +573,8 @@ async def get_model_details(model_id: str):
      }


  @app.get("/api/family/stats")
  async def get_family_stats():
      """
@@ -512,9 +582,8 @@ async def get_family_stats():
      Returns family size distribution, depth statistics, model card length by depth, etc.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

-     # Calculate family sizes
      family_sizes = {}
      root_models = set()
@@ -528,14 +597,13 @@ async def get_family_stats():
              family_sizes[model_id] = 0
          else:
              parent_id_str = str(parent_id)
-             # Find root of this family
              root = parent_id_str
              visited = set()
              while root in df.index and pd.notna(df.loc[root].get('parent_model')):
                  parent = df.loc[root].get('parent_model')
                  if pd.isna(parent) or str(parent) == 'nan' or str(parent) == '':
                      break
-                 if str(parent) in visited:  # Circular reference
                      break
                  visited.add(root)
                  root = str(parent)
@@ -544,18 +612,15 @@ async def get_family_stats():
              family_sizes[root] = 0
          family_sizes[root] += 1

-     # Count family sizes
      size_distribution = {}
      for root, size in family_sizes.items():
          size_distribution[size] = size_distribution.get(size, 0) + 1

-     # Calculate depth statistics
      depths = calculate_family_depths(df)
      depth_counts = {}
      for depth in depths.values():
          depth_counts[depth] = depth_counts.get(depth, 0) + 1

-     # Calculate model card length by depth
      model_card_lengths_by_depth = {}
      if 'modelCard' in df.columns:
          for idx, row in df.iterrows():
@@ -568,7 +633,6 @@ async def get_family_stats():
                  model_card_lengths_by_depth[depth] = []
              model_card_lengths_by_depth[depth].append(card_length)

-     # Calculate statistics for each depth
      model_card_stats = {}
      for depth, lengths in model_card_lengths_by_depth.items():
          if lengths:
@@ -593,99 +657,218 @@ async def get_family_stats():
      }


  @app.get("/api/family/{model_id}")
- async def get_family_tree(model_id: str, max_depth: int = Query(5, ge=1, le=10)):
      """
      Get family tree for a model (ancestors and descendants).
      Returns the model, its parent chain, and all children.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")
-
-     # Find the model
-     model_row = df[df.get('model_id', '') == model_id]
-     if len(model_row) == 0:
-         raise HTTPException(status_code=404, detail="Model not found")
-
-     family_models = []
-     visited = set()

-     # Get coordinates for family members
      if reduced_embeddings is None:
          raise HTTPException(status_code=503, detail="Embeddings not ready")

-     # Optimize: create parent_model index for faster lookups
-     if 'parent_model' not in df.index.names and 'parent_model' in df.columns:
-         # Create a reverse index for faster parent lookups
-         parent_index = df[df['parent_model'].notna()].set_index('parent_model', drop=False, append=True)

-     def get_ancestors(current_id: str, depth: int):
-         """Recursively get parent chain - optimized with index lookup."""
-         if depth <= 0 or current_id in visited:
              return
          visited.add(current_id)

-         # Use index lookup if available, otherwise fallback to query
-         if 'model_id' in df.index.names or df.index.name == 'model_id':
-             try:
-                 model = df.loc[[current_id]]
-             except KeyError:
-                 return
-         else:
-             model = df[df.get('model_id', '') == current_id]
-             if len(model) == 0:
-                 return
-             model = model.iloc[[0]]
-
-         parent_id = model.iloc[0].get('parent_model')
-
-         if parent_id and pd.notna(parent_id) and str(parent_id) != 'nan':
-             get_ancestors(str(parent_id), depth - 1)

-     def get_descendants(current_id: str, depth: int):
-         """Recursively get all children - optimized with index lookup."""
-         if depth <= 0 or current_id in visited:
              return
          visited.add(current_id)

-         # Use optimized parent lookup
-         if 'parent_model' in df.columns:
-             children = df[df['parent_model'] == current_id]
-             # Use vectorized iteration
-             child_ids = children['model_id'].dropna().astype(str).unique()
-             for child_id in child_ids:
-                 if child_id not in visited:
-                     get_descendants(child_id, depth - 1)
-
-     # Get ancestors (parents)
-     get_ancestors(model_id, max_depth)
-
-     # Get descendants (children)
-     visited = set()  # Reset for descendants
-     get_descendants(model_id, max_depth)
-
-     # Add the root model
-     visited.add(model_id)
-
-     # Get all family members with coordinates - optimized
-     if 'model_id' in df.index.names or df.index.name == 'model_id':
-         # Use index lookup if available
          try:
              family_df = df.loc[list(visited)]
          except KeyError:
-             # Fallback to isin if some IDs not in index
-             family_df = df[df.get('model_id', '').isin(visited)]
      else:
          family_df = df[df.get('model_id', '').isin(visited)]

-     family_indices = family_df.index.values  # Use values instead of tolist() for speed
      family_reduced = reduced_embeddings[family_indices]

-     # Build family tree structure - optimized with vectorized operations
      family_map = {}
      for idx, (i, row) in enumerate(family_df.iterrows()):
-         model_id_val = str(row.get('model_id', 'Unknown'))
-         parent_id = row.get('parent_model') if pd.notna(row.get('parent_model')) else None

          family_map[model_id_val] = {
              "model_id": model_id_val,
@@ -696,12 +879,12 @@ async def get_family_tree(model_id: str, max_depth: int = Query(5, ge=1, le=10))
              "pipeline_tag": str(row.get('pipeline_tag')) if pd.notna(row.get('pipeline_tag')) else None,
              "downloads": int(row.get('downloads', 0)) if pd.notna(row.get('downloads')) else 0,
              "likes": int(row.get('likes', 0)) if pd.notna(row.get('likes')) else 0,
-             "parent_model": str(parent_id) if parent_id else None,
              "licenses": str(row.get('licenses')) if pd.notna(row.get('licenses')) else None,
              "children": []
          }

-     # Build tree structure
      root_models = []
      for model_id_val, model_data in family_map.items():
          parent_id = model_data["parent_model"]
@@ -711,7 +894,7 @@ async def get_family_tree(model_id: str, max_depth: int = Query(5, ge=1, le=10))
          root_models.append(model_id_val)

      return {
-         "root_model": model_id,
          "family": list(family_map.values()),
          "family_map": family_map,
          "root_models": root_models
@@ -720,7 +903,9 @@ async def get_family_tree(model_id: str, max_depth: int = Query(5, ge=1, le=10))

  @app.get("/api/search")
  async def search_models(
-     query: str = Query(..., min_length=1),
      graph_aware: bool = Query(False),
      include_neighbors: bool = Query(True)
  ):
@@ -729,47 +914,79 @@ async def search_models(
      Enhanced with graph-aware search option that includes network relationships.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      if graph_aware:
-         # Use graph-aware search
          try:
              network_builder = ModelNetworkBuilder(df)
-             # Build network for top models (for performance)
              top_models = network_builder.get_top_models_by_field(n=1000)
              model_ids = [mid for mid, _ in top_models]
              graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined')

              results = network_builder.search_graph_aware(
-                 query=query,
                  graph=graph,
-                 max_results=20,
                  include_neighbors=include_neighbors
              )

-             return {"results": results, "search_type": "graph_aware"}
-         except Exception as e:
-             pass

-     query_lower = query.lower()
-     matches = df[
-         df.get('model_id', '').astype(str).str.lower().str.contains(query_lower, na=False)
-     ].head(20)  # Limit to 20 results

      results = []
      for _, row in matches.iterrows():
          results.append({
-             "model_id": row.get('model_id'),
-             "title": row.get('model_id', '').split('/')[-1] if '/' in str(row.get('model_id', '')) else str(row.get('model_id', '')),
-             "library_name": row.get('library_name'),
-             "pipeline_tag": row.get('pipeline_tag'),
              "downloads": int(row.get('downloads', 0)),
              "likes": int(row.get('likes', 0)),
              "parent_model": row.get('parent_model') if pd.notna(row.get('parent_model')) else None,
              "match_type": "direct"
          })

-     return {"results": results, "search_type": "basic"}


  @app.get("/api/similar/{model_id}")
@@ -778,12 +995,12 @@ async def get_similar_models(model_id: str, k: int = Query(10, ge=1, le=50)):
      Get k-nearest neighbors of a model based on embedding similarity.
      Returns similar models with distance scores.
      """
-     global df, embedder, embeddings, reduced_embeddings
-
-     if df is None or embeddings is None:
          raise HTTPException(status_code=503, detail="Data not loaded")

-     # Find the model - optimized with index lookup
      if 'model_id' in df.index.names or df.index.name == 'model_id':
          try:
              model_row = df.loc[[model_id]]
@@ -797,16 +1014,11 @@ async def get_similar_models(model_id: str, k: int = Query(10, ge=1, le=50)):
      model_idx = model_row.index[0]
      model_embedding = embeddings[model_idx]

-     # Calculate cosine similarity to all other models - optimized
      from sklearn.metrics.pairwise import cosine_similarity
-     # Use vectorized operations for better performance
      model_embedding_2d = model_embedding.reshape(1, -1)
      similarities = cosine_similarity(model_embedding_2d, embeddings)[0]

-     # Get top k similar models (excluding itself) - use argpartition for speed
-     # argpartition is faster than full sort for top-k
      top_k_indices = np.argpartition(similarities, -k-1)[-k-1:-1]
-     # Sort only the top k (much faster than sorting all)
      top_k_indices = top_k_indices[np.argsort(similarities[top_k_indices])][::-1]

      similar_models = []
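The `argpartition` trick above avoids a full O(n log n) sort when only the top k neighbours are needed. A standalone sketch on random vectors; note that the `[-k-1:-1]` slice mirrors the diff's self-exclusion logic, which relies on the self-match landing in the dropped slot (argpartition does not strictly guarantee ordering within the partition):

```python
# Sketch of the argpartition top-k pattern, on synthetic embeddings.
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(42)
embeddings = rng.normal(size=(1000, 64))
query = embeddings[0].reshape(1, -1)
similarities = cosine_similarity(query, embeddings)[0]

k = 10
# O(n): move the k+1 largest similarities to the tail (unsorted among themselves),
# then drop the last slot to exclude the query's own similarity of 1.0.
top_k = np.argpartition(similarities, -k - 1)[-k - 1:-1]
# Sort only those k indices, descending by similarity.
top_k = top_k[np.argsort(similarities[top_k])][::-1]
print(top_k, similarities[top_k])
```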
@@ -817,7 +1029,7 @@ async def get_similar_models(model_id: str, k: int = Query(10, ge=1, le=50)):
          similar_models.append({
              "model_id": row.get('model_id', 'Unknown'),
              "similarity": float(similarities[idx]),
-             "distance": float(1 - similarities[idx]),  # Convert similarity to distance
              "library_name": row.get('library_name'),
              "pipeline_tag": row.get('pipeline_tag'),
              "downloads": int(row.get('downloads', 0)),
@@ -843,11 +1055,12 @@ async def get_models_by_semantic_similarity(
      Returns models with their similarity scores and coordinates.
      Useful for exploring the embedding space around a specific model.
      """
-     global df, embedder, embeddings, reduced_embeddings
-
-     if df is None or embeddings is None:
          raise HTTPException(status_code=503, detail="Data not loaded")

      # Find the query model
      if 'model_id' in df.index.names or df.index.name == 'model_id':
          try:
@@ -863,7 +1076,6 @@
      query_embedding = embeddings[model_idx]

-     # Filter by downloads/likes first for performance
      filtered_df = data_loader.filter_data(
          df=df,
          min_downloads=min_downloads,
@@ -873,32 +1085,26 @@
          pipeline_tags=None
      )

-     # Get indices of filtered models
      if df.index.name == 'model_id' or 'model_id' in df.index.names:
          filtered_indices = [df.index.get_loc(idx) for idx in filtered_df.index]
          filtered_indices = np.array(filtered_indices, dtype=int)
      else:
          filtered_indices = filtered_df.index.values.astype(int)

-     # Calculate similarities only for filtered models
      filtered_embeddings = embeddings[filtered_indices]
      from sklearn.metrics.pairwise import cosine_similarity
      query_embedding_2d = query_embedding.reshape(1, -1)
      similarities = cosine_similarity(query_embedding_2d, filtered_embeddings)[0]

-     # Get top k similar models
      top_k_local_indices = np.argpartition(similarities, -k)[-k:]
      top_k_local_indices = top_k_local_indices[np.argsort(similarities[top_k_local_indices])][::-1]

-     # Get reduced embeddings for visualization
      if reduced_embeddings is None:
          raise HTTPException(status_code=503, detail="Reduced embeddings not ready")

-     # Map back to original indices
      top_k_original_indices = filtered_indices[top_k_local_indices]
      top_k_reduced = reduced_embeddings[top_k_original_indices]

-     # Build response
      similar_models = []
      for i, orig_idx in enumerate(top_k_original_indices):
          row = df.iloc[orig_idx]
@@ -935,11 +1141,12 @@ async def get_distance(
      """
      Calculate distance/similarity between two models.
      """
-     global df, embedder, embeddings
-
-     if df is None or embeddings is None:
          raise HTTPException(status_code=503, detail="Data not loaded")

      # Find both models - optimized with index lookup
      if 'model_id' in df.index.names or df.index.name == 'model_id':
          try:
@@ -976,7 +1183,7 @@ async def export_models(model_ids: List[str]):
      Export selected models as JSON with full metadata.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      # Optimized export with index lookup
      if 'model_id' in df.index.names or df.index.name == 'model_id':
@@ -991,7 +1198,6 @@ async def export_models(model_ids: List[str]):
      if len(exported) == 0:
          return {"models": []}

-     # Use list comprehension for faster building
      models = [
          {
              "model_id": str(row.get('model_id', '')),
@@ -1029,12 +1235,10 @@ async def get_cooccurrence_network(
      Returns network graph data suitable for visualization.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      try:
          network_builder = ModelNetworkBuilder(df)
-
-         # Get top models by field
          top_models = network_builder.get_top_models_by_field(
              library=library,
              pipeline_tag=pipeline_tag,
@@ -1051,14 +1255,11 @@
          }

          model_ids = [mid for mid, _ in top_models]
-
-         # Build co-occurrence network
          graph = network_builder.build_cooccurrence_network(
              model_ids=model_ids,
              cooccurrence_method=cooccurrence_method
          )

-         # Convert to JSON-serializable format
          nodes = []
          for node_id, attrs in graph.nodes(data=True):
              nodes.append({
@@ -1086,45 +1287,70 @@ async def get_cooccurrence_network(
          "links": links,
          "statistics": stats
      }
-
-     except Exception as e:
          raise HTTPException(status_code=500, detail=f"Error building network: {str(e)}")


  @app.get("/api/network/family/{model_id}")
  async def get_family_network(
      model_id: str,
-     max_depth: int = Query(5, ge=1, le=10)
  ):
      """
      Build family tree network for a model (directed graph).
-     Returns network graph data showing parent-child relationships.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      try:
          network_builder = ModelNetworkBuilder(df)
          graph = network_builder.build_family_tree_network(
              root_model_id=model_id,
-             max_depth=max_depth
          )

-         # Convert to JSON-serializable format
          nodes = []
          for node_id, attrs in graph.nodes(data=True):
              nodes.append({
                  "id": node_id,
                  "title": attrs.get('title', node_id),
-                 "freq": attrs.get('freq', 0)
              })

          links = []
-         for source, target in graph.edges():
-             links.append({
                  "source": source,
-                 "target": target
-             })

          stats = network_builder.get_network_statistics(graph)

@@ -1134,8 +1360,8 @@
          "statistics": stats,
          "root_model": model_id
      }
-
-     except Exception as e:
          raise HTTPException(status_code=500, detail=f"Error building family network: {str(e)}")

@@ -1150,11 +1376,10 @@ async def get_model_neighbors(
      Similar to graph database queries for finding connected nodes.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      try:
          network_builder = ModelNetworkBuilder(df)
-         # Build network for top models (for performance)
          top_models = network_builder.get_top_models_by_field(n=1000)
          model_ids = [mid for mid, _ in top_models]
          graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined')
@@ -1171,8 +1396,8 @@
          "neighbors": neighbors,
          "count": len(neighbors)
      }
-
-     except Exception as e:
          raise HTTPException(status_code=500, detail=f"Error finding neighbors: {str(e)}")

@@ -1187,7 +1412,7 @@
      Similar to graph database path queries.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      try:
          network_builder = ModelNetworkBuilder(df)
@@ -1235,7 +1460,7 @@
      Similar to graph database queries for co-assignment patterns.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      try:
          network_builder = ModelNetworkBuilder(df)
@@ -1272,7 +1497,7 @@
      Similar to graph database relationship queries.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      try:
          network_builder = ModelNetworkBuilder(df)
@@ -1297,32 +1522,57 @@ async def get_model_relationships(
  async def get_current_model_count(
      use_cache: bool = Query(True),
      force_refresh: bool = Query(False),
-     use_dataset_snapshot: bool = Query(False)
  ):
      """
      Get the current number of models on Hugging Face Hub.
-     Fetches live data from the Hub API or uses dataset snapshot (faster but may be outdated).

      Query Parameters:
          use_cache: Use cached results if available (default: True)
          force_refresh: Force refresh even if cache is valid (default: False)
-         use_dataset_snapshot: Use dataset snapshot instead of API (faster, default: False)
      """
      try:
          if use_dataset_snapshot:
-             # Use improved tracker with dataset snapshot (like ai-ecosystem repo)
-             tracker = get_improved_tracker()
-             count_data = tracker.get_count_from_dataset_snapshot()
              if count_data is None:
-                 # Fallback to API if dataset unavailable
-                 count_data = tracker.get_current_model_count(use_cache=use_cache, force_refresh=force_refresh)
          else:
-             # Use improved tracker with API (has caching)
-             tracker = get_improved_tracker()
-             count_data = tracker.get_current_model_count(use_cache=use_cache, force_refresh=force_refresh)

          return count_data
      except Exception as e:
          raise HTTPException(status_code=500, detail=f"Error fetching model count: {str(e)}")

@@ -1343,7 +1593,7 @@ async def get_historical_model_counts(
      try:
          from datetime import datetime

-         tracker = get_improved_tracker()

          start = None
          end = None
@@ -1373,7 +1623,7 @@
  async def get_latest_model_count():
      """Get the most recently recorded model count from database."""
      try:
-         tracker = get_improved_tracker()
          latest = tracker.get_latest_count()
          if latest is None:
              raise HTTPException(status_code=404, detail="No model counts recorded yet")
@@ -1397,16 +1647,14 @@ async def record_model_count(
          use_dataset_snapshot: Use dataset snapshot instead of API (faster, default: False)
      """
      try:
-         tracker = get_improved_tracker()

-         # Fetch and record in background to avoid blocking
          def record():
              if use_dataset_snapshot:
                  count_data = tracker.get_count_from_dataset_snapshot()
                  if count_data:
                      tracker.record_count(count_data, source="dataset_snapshot")
                  else:
-                     # Fallback to API
                      count_data = tracker.get_current_model_count(use_cache=False)
                      tracker.record_count(count_data, source="api")
              else:
@@ -1433,7 +1681,7 @@ async def get_growth_stats(days: int = Query(7, ge=1, le=365)):
          days: Number of days to analyze
      """
      try:
-         tracker = get_improved_tracker()
          stats = tracker.get_growth_stats(days)
          return stats
      except Exception as e:
@@ -1455,12 +1703,11 @@ async def export_network_graphml(
      Similar to Open Syllabus graph export functionality.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      try:
          network_builder = ModelNetworkBuilder(df)

-         # Get top models by field
          top_models = network_builder.get_top_models_by_field(
              library=library,
              pipeline_tag=pipeline_tag,
@@ -1473,29 +1720,24 @@ async def export_network_graphml(
          raise HTTPException(status_code=404, detail="No models found matching criteria")

      model_ids = [mid for mid, _ in top_models]
-
-     # Build co-occurrence network
      graph = network_builder.build_cooccurrence_network(
          model_ids=model_ids,
          cooccurrence_method=cooccurrence_method
      )

-     # Create temporary file
      with tempfile.NamedTemporaryFile(mode='w', suffix='.graphml', delete=False) as tmp_file:
          tmp_path = tmp_file.name
          network_builder.export_graphml(graph, tmp_path)

-     # Schedule cleanup after response is sent
      background_tasks.add_task(os.unlink, tmp_path)

-     # Return file for download
      return FileResponse(
          tmp_path,
          media_type='application/xml',
          filename=f'network_{cooccurrence_method}_{n}_models.graphml'
      )
-
-     except Exception as e:
      raise HTTPException(status_code=500, detail=f"Error exporting network: {str(e)}")
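The export endpoint's file lifecycle is worth spelling out: the temp file must outlive the `with` block so `FileResponse` can stream it, and cleanup is deferred until after the response is sent. A minimal self-contained sketch of the same pattern (route name and payload are illustrative, not from this commit):

```python
# Sketch of the temp-file + BackgroundTasks cleanup pattern used above.
import os
import tempfile

from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import FileResponse

app = FastAPI()

@app.get("/export-demo")
async def export_demo(background_tasks: BackgroundTasks):
    # delete=False keeps the file alive after the `with` block so
    # FileResponse can stream it; we remove it ourselves afterwards.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".graphml", delete=False) as tmp:
        tmp.write("<graphml/>")
        tmp_path = tmp.name
    # Runs only after the response body has been sent.
    background_tasks.add_task(os.unlink, tmp_path)
    return FileResponse(tmp_path, media_type="application/xml", filename="demo.graphml")
```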
@@ -1506,7 +1748,7 @@ async def get_model_papers(model_id: str):
      Extracts arXiv IDs from model tags and fetches paper information.
      """
      if df is None:
-         raise HTTPException(status_code=503, detail="Data not loaded")

      model = df[df.get('model_id', '') == model_id]
      if len(model) == 0:
@@ -1535,36 +1777,131 @@ async def get_model_papers(model_id: str):
      }


  @app.get("/api/model/{model_id}/files")
  async def get_model_files(model_id: str, branch: str = Query("main")):
      """
      Get file tree for a model from Hugging Face.
      Proxies the request to avoid CORS issues.
      """
      try:
-         # Try main branch first, then master
-         branches_to_try = [branch, "main", "master"] if branch not in ["main", "master"] else [branch, "main" if branch == "master" else "master"]
-
-         async with httpx.AsyncClient(timeout=10.0) as client:
              for branch_name in branches_to_try:
                  try:
                      url = f"https://huggingface.co/api/models/{model_id}/tree/{branch_name}"
                      response = await client.get(url)
                      if response.status_code == 200:
-                         return response.json()
-                 except Exception:
                      continue

-         raise HTTPException(status_code=404, detail="File tree not found for this model")
      except httpx.TimeoutException:
-         raise HTTPException(status_code=504, detail="Request to Hugging Face timed out")
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error fetching file tree: {str(e)}")
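The deleted one-liner packs the branch-fallback policy rather densely; written out, it behaves like this (equivalent logic, expanded for readability — a hypothetical helper, not from this commit):

```python
# Unpacked version of the branches_to_try conditional removed above.
def branches_to_try(branch: str) -> list:
    if branch not in ("main", "master"):
        # Custom branch first, then the two common defaults.
        return [branch, "main", "master"]
    # Asked for main or master: try it, then the other one.
    return [branch, "master" if branch == "main" else "main"]

assert branches_to_try("dev") == ["dev", "main", "master"]
assert branches_to_try("main") == ["main", "master"]
assert branches_to_try("master") == ["master", "main"]
```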
  if __name__ == "__main__":
      import uvicorn
-     # Use PORT environment variable for cloud platforms (Railway, Render, Heroku)
      port = int(os.getenv("PORT", 8000))
      uvicorn.run(app, host="0.0.0.0", port=port)

  import sys
  import os
+ import pickle
+ import tempfile
+ import logging
+ from typing import Optional, List, Dict
+ from datetime import datetime, timedelta

+ import pandas as pd
+ import numpy as np
+ import httpx
  from fastapi import FastAPI, HTTPException, Query, BackgroundTasks, Request
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.middleware.gzip import GZipMiddleware
  from fastapi.responses import FileResponse, JSONResponse
  from fastapi.exceptions import RequestValidationError
  from starlette.exceptions import HTTPException as StarletteHTTPException
  from pydantic import BaseModel
  from umap import UMAP

  from utils.data_loader import ModelDataLoader
  from utils.embeddings import ModelEmbedder
  from utils.dimensionality_reduction import DimensionReducer
  from utils.network_analysis import ModelNetworkBuilder
+ from utils.graph_embeddings import GraphEmbedder
  from services.model_tracker import get_tracker
  from services.arxiv_api import extract_arxiv_ids, fetch_arxiv_papers
+ from core.config import settings
+ from core.exceptions import DataNotLoadedError, EmbeddingsNotReadyError
+ from models.schemas import ModelPoint
+ from utils.family_tree import calculate_family_depths
+ import api.dependencies as deps
+ from api.routes import models, stats, clusters
+
+ # Create aliases for backward compatibility with existing routes
+ # Note: These are set at module load time and may be None initially
+ # Functions should access via deps.* to get current values
+ data_loader = deps.data_loader
+
+ backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ if backend_dir not in sys.path:
+     sys.path.insert(0, backend_dir)

+ logger = logging.getLogger(__name__)
+
+ app = FastAPI(title="HF Model Ecosystem API", version="2.0.0")

  app.add_middleware(GZipMiddleware, minimum_size=1000)

+ CORS_HEADERS = {
+     "Access-Control-Allow-Origin": "*",
+     "Access-Control-Allow-Methods": "*",
+     "Access-Control-Allow-Headers": "*",
+ }
+
  @app.exception_handler(Exception)
  async def global_exception_handler(request: Request, exc: Exception):
+     logger.exception("Unhandled exception", exc_info=exc)
      return JSONResponse(
          status_code=500,
+         content={"detail": "Internal server error"},
+         headers=CORS_HEADERS,
      )

  @app.exception_handler(StarletteHTTPException)
  async def http_exception_handler(request: Request, exc: StarletteHTTPException):
      return JSONResponse(
          status_code=exc.status_code,
          content={"detail": exc.detail},
+         headers=CORS_HEADERS,
      )

  @app.exception_handler(RequestValidationError)
  async def validation_exception_handler(request: Request, exc: RequestValidationError):
      return JSONResponse(
          status_code=422,
          content={"detail": exc.errors()},
+         headers=CORS_HEADERS,
      )

+ if settings.ALLOW_ALL_ORIGINS:
      app.add_middleware(
          CORSMiddleware,
+         allow_origins=["*"],
+         allow_credentials=False,
          allow_methods=["*"],
          allow_headers=["*"],
      )
  else:
      app.add_middleware(
          CORSMiddleware,
+         allow_origins=["http://localhost:3000", settings.FRONTEND_URL],
          allow_credentials=True,
          allow_methods=["*"],
          allow_headers=["*"],
      )

+ # Include routers
+ app.include_router(models.router)
+ app.include_router(stats.router)
+ app.include_router(clusters.router)

104
  @app.on_event("startup")
105
  async def startup_event():
106
+ # All variables are accessed via deps module, no need for global declarations
 
107
 
 
108
  backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
109
  root_dir = os.path.dirname(backend_dir)
110
  cache_dir = os.path.join(root_dir, "cache")
111
  os.makedirs(cache_dir, exist_ok=True)
112
 
113
  embeddings_cache = os.path.join(cache_dir, "embeddings.pkl")
114
+ graph_embeddings_cache = os.path.join(cache_dir, "graph_embeddings.pkl")
115
+ combined_embeddings_cache = os.path.join(cache_dir, "combined_embeddings.pkl")
116
  reduced_cache_umap = os.path.join(cache_dir, "reduced_umap_3d.pkl")
117
+ reduced_cache_umap_graph = os.path.join(cache_dir, "reduced_umap_3d_graph.pkl")
118
  reducer_cache_umap = os.path.join(cache_dir, "reducer_umap_3d.pkl")
119
+ reducer_cache_umap_graph = os.path.join(cache_dir, "reducer_umap_3d_graph.pkl")
120
 
121
+ sample_size = settings.get_sample_size()
122
+ if sample_size:
123
+ logger.info(f"Loading limited dataset: {sample_size} models (SAMPLE_SIZE={sample_size})")
124
  else:
+     logger.info("No SAMPLE_SIZE set, loading full dataset")
+
+ deps.df = deps.data_loader.load_data(sample_size=sample_size)
+ deps.df = deps.data_loader.preprocess_for_embedding(deps.df)
+
+ if 'model_id' in deps.df.columns:
+     deps.df.set_index('model_id', drop=False, inplace=True)

  for col in ['downloads', 'likes']:
+     if col in deps.df.columns:
+         deps.df[col] = pd.to_numeric(deps.df[col], errors='coerce').fillna(0).astype(int)

+ deps.embedder = ModelEmbedder()

+ # Load or generate text embeddings
  if os.path.exists(embeddings_cache):
      try:
+         deps.embeddings = deps.embedder.load_embeddings(embeddings_cache)
+     except (IOError, pickle.UnpicklingError, EOFError) as e:
+         logger.warning(f"Failed to load cached embeddings: {e}")
+         deps.embeddings = None
+
+ if deps.embeddings is None:
+     texts = deps.df['combined_text'].tolist()
+     deps.embeddings = deps.embedder.generate_embeddings(texts, batch_size=128)
+     deps.embedder.save_embeddings(deps.embeddings, embeddings_cache)
+
+ # Initialize graph embedder and generate graph embeddings (optional, lazy-loaded)
+ if settings.USE_GRAPH_EMBEDDINGS:
+     try:
+         deps.graph_embedder = GraphEmbedder()
+         logger.info("Building family graph for graph embeddings...")
+         graph = deps.graph_embedder.build_family_graph(deps.df)
+
+         if os.path.exists(graph_embeddings_cache):
+             try:
+                 deps.graph_embeddings_dict = deps.graph_embedder.load_embeddings(graph_embeddings_cache)
+                 logger.info(f"Loaded cached graph embeddings for {len(deps.graph_embeddings_dict)} models")
+             except (IOError, pickle.UnpicklingError, EOFError) as e:
+                 logger.warning(f"Failed to load cached graph embeddings: {e}")
+                 deps.graph_embeddings_dict = None
+
+         if deps.graph_embeddings_dict is None or len(deps.graph_embeddings_dict) == 0:
+             logger.info("Generating graph embeddings (this may take a while)...")
+             deps.graph_embeddings_dict = deps.graph_embedder.generate_graph_embeddings(graph, workers=4)
+             if deps.graph_embeddings_dict:
+                 deps.graph_embedder.save_embeddings(deps.graph_embeddings_dict, graph_embeddings_cache)
+                 logger.info(f"Generated graph embeddings for {len(deps.graph_embeddings_dict)} models")
+
+         # Combine text and graph embeddings
+         if deps.graph_embeddings_dict and len(deps.graph_embeddings_dict) > 0:
+             model_ids = deps.df['model_id'].astype(str).tolist()
+             if os.path.exists(combined_embeddings_cache):
+                 try:
+                     with open(combined_embeddings_cache, 'rb') as f:
+                         deps.combined_embeddings = pickle.load(f)
+                     logger.info("Loaded cached combined embeddings")
+                 except (IOError, pickle.UnpicklingError, EOFError) as e:
+                     logger.warning(f"Failed to load cached combined embeddings: {e}")
+                     deps.combined_embeddings = None
+
+             if deps.combined_embeddings is None:
+                 logger.info("Combining text and graph embeddings...")
+                 deps.combined_embeddings = deps.graph_embedder.combine_embeddings(
+                     deps.embeddings, deps.graph_embeddings_dict, model_ids,
+                     text_weight=0.7, graph_weight=0.3
+                 )
+                 with open(combined_embeddings_cache, 'wb') as f:
+                     pickle.dump(deps.combined_embeddings, f)
+                 logger.info("Combined embeddings saved")
      except Exception as e:
+         logger.warning(f"Graph embeddings not available: {e}. Continuing with text-only embeddings.")
+         deps.graph_embedder = None
+         deps.graph_embeddings_dict = None
+         deps.combined_embeddings = None
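
`combine_embeddings` lives in `backend/utils/graph_embeddings.py`, whose body this view does not show. A minimal sketch of the weighted combination the call above implies; the concatenation strategy and the zero-vector fallback are assumptions, not the confirmed implementation:

```python
import numpy as np

def combine_embeddings(text_emb, graph_emb_dict, model_ids,
                       text_weight=0.7, graph_weight=0.3):
    """Sketch: concatenate weighted, L2-normalized text and graph vectors.

    Models absent from graph_emb_dict fall back to a zero graph vector,
    so the result stays aligned row-for-row with model_ids.
    """
    graph_dim = len(next(iter(graph_emb_dict.values())))
    rows = []
    for i, mid in enumerate(model_ids):
        t = text_emb[i] / (np.linalg.norm(text_emb[i]) + 1e-12)
        g = np.asarray(graph_emb_dict.get(mid, np.zeros(graph_dim)), dtype=float)
        g = g / (np.linalg.norm(g) + 1e-12)
        rows.append(np.concatenate([text_weight * t, graph_weight * g]))
    return np.vstack(rows)
```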

+ # Initialize reducer for text embeddings
+ deps.reducer = DimensionReducer(method="umap", n_components=3)

  if os.path.exists(reduced_cache_umap) and os.path.exists(reducer_cache_umap):
      try:
          with open(reduced_cache_umap, 'rb') as f:
+             deps.reduced_embeddings = pickle.load(f)
+         deps.reducer.load_reducer(reducer_cache_umap)
+     except (IOError, pickle.UnpicklingError, EOFError) as e:
+         logger.warning(f"Failed to load cached reduced embeddings: {e}")
+         deps.reduced_embeddings = None
+
+ if deps.reduced_embeddings is None:
+     deps.reducer.reducer = UMAP(
          n_components=3,
          n_neighbors=30,
          min_dist=0.3,
  ...
          low_memory=True,
          spread=1.5
      )
+     deps.reduced_embeddings = deps.reducer.fit_transform(deps.embeddings)
      with open(reduced_cache_umap, 'wb') as f:
+         pickle.dump(deps.reduced_embeddings, f)
+     deps.reducer.save_reducer(reducer_cache_umap)

+ # Initialize reducer for graph-aware embeddings if available
+ if deps.combined_embeddings is not None:
+     reducer_graph = DimensionReducer(method="umap", n_components=3)
+
+     if os.path.exists(reduced_cache_umap_graph) and os.path.exists(reducer_cache_umap_graph):
+         try:
+             with open(reduced_cache_umap_graph, 'rb') as f:
+                 deps.reduced_embeddings_graph = pickle.load(f)
+             reducer_graph.load_reducer(reducer_cache_umap_graph)
+         except (IOError, pickle.UnpicklingError, EOFError) as e:
+             logger.warning(f"Failed to load cached graph-aware reduced embeddings: {e}")
+             deps.reduced_embeddings_graph = None
+
+     if deps.reduced_embeddings_graph is None:
+         reducer_graph.reducer = UMAP(
+             n_components=3,
+             n_neighbors=30,
+             min_dist=0.3,
+             metric='cosine',
+             random_state=42,
+             n_jobs=-1,
+             low_memory=True,
+             spread=1.5
+         )
+         deps.reduced_embeddings_graph = reducer_graph.fit_transform(deps.combined_embeddings)
+         with open(reduced_cache_umap_graph, 'wb') as f:
+             pickle.dump(deps.reduced_embeddings_graph, f)
+         reducer_graph.save_reducer(reducer_cache_umap_graph)
+         logger.info("Graph-aware embeddings reduced and cached")

+ # Update module-level aliases
+ df = deps.df
+ embedder = deps.embedder
+ graph_embedder = deps.graph_embedder
+ reducer = deps.reducer
+ embeddings = deps.embeddings
+ graph_embeddings_dict = deps.graph_embeddings_dict
+ combined_embeddings = deps.combined_embeddings
+ reduced_embeddings = deps.reduced_embeddings
+ reduced_embeddings_graph = deps.reduced_embeddings_graph
+
+
+ from utils.family_tree import calculate_family_depths

  def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np.ndarray:
      from sklearn.cluster import KMeans

      n_samples = len(reduced_embeddings)
      if n_samples < n_clusters:
          n_clusters = max(1, n_samples // 10)

      kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
+     return kmeans.fit_predict(reduced_embeddings)


  @app.get("/")
  ...
      search_query: Optional[str] = Query(None),
      color_by: str = Query("library_name"),
      size_by: str = Query("downloads"),
+     max_points: Optional[int] = Query(None),
+     projection_method: str = Query("umap"),
+     base_models_only: bool = Query(False),
+     max_hierarchy_depth: Optional[int] = Query(None, ge=0, description="Filter to models at or below this hierarchy depth."),
+     use_graph_embeddings: bool = Query(False, description="Use graph-aware embeddings that respect family tree structure")
  ):
+     if deps.df is None:
+         raise DataNotLoadedError()
+
+     df = deps.df

      # Filter data
      filtered_df = data_loader.filter_data(
  ...
              (filtered_df['parent_model'].astype(str) == 'nan')
          ]

+     if max_hierarchy_depth is not None:
+         family_depths = calculate_family_depths(df)
+         filtered_df = filtered_df[
+             filtered_df['model_id'].astype(str).map(lambda x: family_depths.get(x, 0) <= max_hierarchy_depth)
+         ]
+
      filtered_count = len(filtered_df)

      if len(filtered_df) == 0:
  ...
          }

      if max_points is not None and len(filtered_df) > max_points:
          if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
+             # Sample proportionally by library, preserving all columns
+             sampled_dfs = []
+             for lib_name, group in filtered_df.groupby('library_name', group_keys=False):
+                 n_samples = max(1, int(max_points * len(group) / len(filtered_df)))
+                 sampled_dfs.append(group.sample(min(len(group), n_samples), random_state=42))
+             filtered_df = pd.concat(sampled_dfs, ignore_index=True)
              if len(filtered_df) > max_points:
+                 filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
+             else:
+                 filtered_df = filtered_df.reset_index(drop=True)
          else:
+             filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)

+     # Determine which embeddings to use
+     if use_graph_embeddings and combined_embeddings is not None:
+         current_embeddings = combined_embeddings
+         current_reduced = reduced_embeddings_graph
+         embedding_type = "graph-aware"
+     else:
+         if embeddings is None:
+             raise EmbeddingsNotReadyError()
+         current_embeddings = embeddings
+         current_reduced = reduced_embeddings
+         embedding_type = "text-only"
+
+     # Handle reduced embeddings loading/generation
+     if current_reduced is None or (reducer and reducer.method != projection_method.lower()):
          backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
          root_dir = os.path.dirname(backend_dir)
          cache_dir = os.path.join(root_dir, "cache")
+         cache_suffix = "_graph" if use_graph_embeddings and combined_embeddings is not None else ""
+         reduced_cache = os.path.join(cache_dir, f"reduced_{projection_method.lower()}_3d{cache_suffix}.pkl")
+         reducer_cache = os.path.join(cache_dir, f"reducer_{projection_method.lower()}_3d{cache_suffix}.pkl")

          if os.path.exists(reduced_cache) and os.path.exists(reducer_cache):
              try:
                  with open(reduced_cache, 'rb') as f:
+                     current_reduced = pickle.load(f)
                  if reducer is None or reducer.method != projection_method.lower():
                      reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
                      reducer.load_reducer(reducer_cache)
+             except (IOError, pickle.UnpicklingError, EOFError) as e:
+                 logger.warning(f"Failed to load cached reduced embeddings: {e}")
+                 current_reduced = None

+         if current_reduced is None:
              if reducer is None or reducer.method != projection_method.lower():
                  reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
              if projection_method.lower() == "umap":
  ...
                      low_memory=True,
                      spread=1.5
                  )
+             current_reduced = reducer.fit_transform(current_embeddings)
              with open(reduced_cache, 'wb') as f:
+                 pickle.dump(current_reduced, f)
              reducer.save_reducer(reducer_cache)
+
+         # Update global variable
+         if use_graph_embeddings and deps.combined_embeddings is not None:
+             deps.reduced_embeddings_graph = current_reduced
+         else:
+             deps.reduced_embeddings = current_reduced
+
+     # Get indices for filtered data
+     # Use the model_id column to map between filtered_df and the original df;
+     # this is safer than using index positions, which can change after filtering
+     filtered_model_ids = filtered_df['model_id'].astype(str).values

+     # Map model_ids to positions in the original df
      if df.index.name == 'model_id' or 'model_id' in df.index.names:
+         # When df is indexed by model_id, use get_loc directly
+         filtered_indices = []
+         for model_id in filtered_model_ids:
+             try:
+                 pos = df.index.get_loc(model_id)
+                 # Handle both single position and array of positions
+                 if isinstance(pos, (int, np.integer)):
+                     filtered_indices.append(int(pos))
+                 elif isinstance(pos, (slice, np.ndarray)):
+                     # If multiple matches, take the first
+                     if isinstance(pos, slice):
+                         filtered_indices.append(int(pos.start))
+                     else:
+                         filtered_indices.append(int(pos[0]))
+             except (KeyError, TypeError):
+                 continue
+         filtered_indices = np.array(filtered_indices, dtype=np.int32)
      else:
+         # When df is not indexed by model_id, find positions by matching the model_id column
+         df_model_ids = df['model_id'].astype(str).values
+         model_id_to_pos = {mid: pos for pos, mid in enumerate(df_model_ids)}
+         filtered_indices = np.array([
+             model_id_to_pos[mid] for mid in filtered_model_ids
+             if mid in model_id_to_pos
+         ], dtype=np.int32)
+
+     if len(filtered_indices) == 0:
+         return {
+             "models": [],
+             "embedding_type": embedding_type,
+             "filtered_count": filtered_count,
+             "returned_count": 0
+         }

+     filtered_reduced = current_reduced[filtered_indices]
      family_depths = calculate_family_depths(df)

+     # The shared cluster-label cache lives in api.routes.models so the
+     # /api/clusters router sees the same labels. Import it under an alias:
+     # the name `models` is assigned below for the response list, which would
+     # otherwise make `models.cluster_labels` raise UnboundLocalError here.
+     from api.routes import models as models_route
+     clustering_embeddings = current_reduced
+     # Compute clusters if not already computed or if the size changed;
+     # max(1, ...) guards against requesting zero clusters on small datasets
+     if models_route.cluster_labels is None or len(models_route.cluster_labels) != len(clustering_embeddings):
+         models_route.cluster_labels = compute_clusters(clustering_embeddings, n_clusters=min(50, max(1, len(clustering_embeddings) // 100)))

+     # Handle the case where cluster_labels might not match the filtered data yet
+     if models_route.cluster_labels is not None and len(models_route.cluster_labels) > 0:
+         if len(filtered_indices) <= len(models_route.cluster_labels):
+             filtered_clusters = models_route.cluster_labels[filtered_indices]
+         else:
+             # Fallback: use the first cluster for all if indices don't match
+             filtered_clusters = np.zeros(len(filtered_indices), dtype=int)
+     else:
+         filtered_clusters = np.zeros(len(filtered_indices), dtype=int)

      model_ids = filtered_df['model_id'].astype(str).values
+     library_names = filtered_df.get('library_name', pd.Series([None] * len(filtered_df))).values
+     pipeline_tags = filtered_df.get('pipeline_tag', pd.Series([None] * len(filtered_df))).values
+     downloads_arr = filtered_df.get('downloads', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values
+     likes_arr = filtered_df.get('likes', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values
+     trending_scores = filtered_df.get('trendingScore', pd.Series([None] * len(filtered_df))).values
+     tags_arr = filtered_df.get('tags', pd.Series([None] * len(filtered_df))).values
+     parent_models = filtered_df.get('parent_model', pd.Series([None] * len(filtered_df))).values
+     licenses_arr = filtered_df.get('licenses', pd.Series([None] * len(filtered_df))).values
+     created_at_arr = filtered_df.get('createdAt', pd.Series([None] * len(filtered_df))).values
+
      x_coords = filtered_reduced[:, 0].astype(float)
      y_coords = filtered_reduced[:, 1].astype(float)
      z_coords = filtered_reduced[:, 2].astype(float) if filtered_reduced.shape[1] > 2 else np.zeros(len(filtered_reduced), dtype=float)

      models = [
          ModelPoint(
              model_id=model_ids[idx],
  ...
              parent_model=parent_models[idx] if idx < len(parent_models) and pd.notna(parent_models[idx]) else None,
              licenses=licenses_arr[idx] if idx < len(licenses_arr) and pd.notna(licenses_arr[idx]) else None,
              family_depth=family_depths.get(model_ids[idx], None),
+             cluster_id=int(filtered_clusters[idx]) if idx < len(filtered_clusters) else None,
+             created_at=str(created_at_arr[idx]) if idx < len(created_at_arr) and pd.notna(created_at_arr[idx]) else None
          )
          for idx in range(len(filtered_df))
      ]

+     # Return models with metadata about the embedding type
+     return {
+         "models": models,
+         "embedding_type": embedding_type,
+         "filtered_count": filtered_count,
+         "returned_count": len(models)
+     }

  @app.get("/api/stats")
517
  async def get_stats():
518
  """Get dataset statistics."""
519
  if df is None:
520
+ raise DataNotLoadedError()
521
 
 
522
  total_models = len(df.index) if hasattr(df, 'index') else len(df)
523
 
524
+ # Get unique licenses with counts
525
+ licenses = {}
526
+ if 'license' in df.columns:
527
+ license_counts = df['license'].value_counts().to_dict()
528
+ licenses = {str(k): int(v) for k, v in license_counts.items() if pd.notna(k) and str(k) != 'nan'}
529
+
530
  return {
531
  "total_models": total_models,
532
  "unique_libraries": int(df['library_name'].nunique()) if 'library_name' in df.columns else 0,
533
  "unique_pipelines": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0,
534
  "unique_task_types": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0, # Alias for clarity
535
+ "unique_licenses": len(licenses),
536
+ "licenses": licenses, # License name -> count mapping
537
  "avg_downloads": float(df['downloads'].mean()) if 'downloads' in df.columns else 0,
538
  "avg_likes": float(df['likes'].mean()) if 'likes' in df.columns else 0
539
  }
 
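For reference, the payload this endpoint now returns looks like the following; the numbers are invented for illustration, only the field names come from the code above:

```python
# Illustrative /api/stats response shape:
{
    "total_models": 500000,
    "unique_libraries": 120,
    "unique_pipelines": 40,
    "unique_task_types": 40,
    "unique_licenses": 35,
    "licenses": {"apache-2.0": 210000, "mit": 90000},
    "avg_downloads": 1542.7,
    "avg_likes": 3.2,
}
```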
  ...
  async def get_model_details(model_id: str):
      """Get detailed information about a specific model."""
      if df is None:
+         raise DataNotLoadedError()

      model = df[df.get('model_id', '') == model_id]
      if len(model) == 0:
  ...

      model = model.iloc[0]

      tags_str = str(model.get('tags', '')) if pd.notna(model.get('tags')) else ''
      arxiv_ids = extract_arxiv_ids(tags_str)

      papers = []
      if arxiv_ids:
          papers = await fetch_arxiv_papers(arxiv_ids[:5])  # Limit to 5 papers
  ...
      }


+ # Clusters endpoint is handled by routes/clusters.py router
+
  @app.get("/api/family/stats")
579
  async def get_family_stats():
580
  """
 
582
  Returns family size distribution, depth statistics, model card length by depth, etc.
583
  """
584
  if df is None:
585
+ raise DataNotLoadedError()
586
 
 
587
  family_sizes = {}
588
  root_models = set()
589
 
 
597
  family_sizes[model_id] = 0
598
  else:
599
  parent_id_str = str(parent_id)
 
600
  root = parent_id_str
601
  visited = set()
602
  while root in df.index and pd.notna(df.loc[root].get('parent_model')):
603
  parent = df.loc[root].get('parent_model')
604
  if pd.isna(parent) or str(parent) == 'nan' or str(parent) == '':
605
  break
606
+ if str(parent) in visited:
607
  break
608
  visited.add(root)
609
  root = str(parent)
 
612
  family_sizes[root] = 0
613
  family_sizes[root] += 1
614
 
 
615
  size_distribution = {}
616
  for root, size in family_sizes.items():
617
  size_distribution[size] = size_distribution.get(size, 0) + 1
618
 
 
619
  depths = calculate_family_depths(df)
620
  depth_counts = {}
621
  for depth in depths.values():
622
  depth_counts[depth] = depth_counts.get(depth, 0) + 1
623
 
 
624
  model_card_lengths_by_depth = {}
625
  if 'modelCard' in df.columns:
626
  for idx, row in df.iterrows():
 
633
  model_card_lengths_by_depth[depth] = []
634
  model_card_lengths_by_depth[depth].append(card_length)
635
 
 
636
  model_card_stats = {}
637
  for depth, lengths in model_card_lengths_by_depth.items():
638
  if lengths:
 
657
  }
658
 
659
 
660
+ @app.get("/api/family/path/{model_id}")
661
+ async def get_family_path(
662
+ model_id: str,
663
+ target_id: Optional[str] = Query(None, description="Target model ID. If None, returns path to root.")
664
+ ):
665
+ """
666
+ Get path from model to root or to target model.
667
+ Returns list of model IDs representing the path.
668
+ """
669
+ if df is None:
670
+ raise DataNotLoadedError()
671
+
672
+ model_id_str = str(model_id)
673
+
674
+ if df.index.name == 'model_id':
675
+ if model_id_str not in df.index:
676
+ raise HTTPException(status_code=404, detail="Model not found")
677
+ else:
678
+ model_rows = df[df.get('model_id', '') == model_id_str]
679
+ if len(model_rows) == 0:
680
+ raise HTTPException(status_code=404, detail="Model not found")
681
+
682
+ path = [model_id_str]
683
+ visited = set([model_id_str])
684
+ current = model_id_str
685
+
686
+ if target_id:
687
+ target_str = str(target_id)
688
+ if df.index.name == 'model_id':
689
+ if target_str not in df.index:
690
+ raise HTTPException(status_code=404, detail="Target model not found")
691
+
692
+ while current != target_str and current not in visited:
693
+ try:
694
+ if df.index.name == 'model_id':
695
+ row = df.loc[current]
696
+ else:
697
+ rows = df[df.get('model_id', '') == current]
698
+ if len(rows) == 0:
699
+ break
700
+ row = rows.iloc[0]
701
+
702
+ parent_id = row.get('parent_model')
703
+ if parent_id and pd.notna(parent_id):
704
+ parent_str = str(parent_id)
705
+ if parent_str == target_str:
706
+ path.append(parent_str)
707
+ break
708
+ if parent_str not in visited:
709
+ path.append(parent_str)
710
+ visited.add(parent_str)
711
+ current = parent_str
712
+ else:
713
+ break
714
+ else:
715
+ break
716
+ except (KeyError, IndexError):
717
+ break
718
+ else:
719
+ while True:
720
+ try:
721
+ if df.index.name == 'model_id':
722
+ row = df.loc[current]
723
+ else:
724
+ rows = df[df.get('model_id', '') == current]
725
+ if len(rows) == 0:
726
+ break
727
+ row = rows.iloc[0]
728
+
729
+ parent_id = row.get('parent_model')
730
+ if parent_id and pd.notna(parent_id):
731
+ parent_str = str(parent_id)
732
+ if parent_str not in visited:
733
+ path.append(parent_str)
734
+ visited.add(parent_str)
735
+ current = parent_str
736
+ else:
737
+ break
738
+ else:
739
+ break
740
+ except (KeyError, IndexError):
741
+ break
742
+
743
+ return {
744
+ "path": path,
745
+ "source": model_id_str,
746
+ "target": target_id if target_id else "root",
747
+ "path_length": len(path) - 1
748
+ }
749
+
750
+
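A quick client sketch for the new path endpoint. The base URL and model IDs are placeholders; note that Hub model IDs contain `/`, so they must be URL-encoded here unless the route is declared as `{model_id:path}`:

```python
import httpx
from urllib.parse import quote

BASE = "http://localhost:8000"  # default port from __main__ below
model_id = quote("org/fine-tuned-model", safe="")  # hypothetical ID

with httpx.Client(base_url=BASE) as client:
    # Walk up to the family root
    root_path = client.get(f"/api/family/path/{model_id}").json()
    print(root_path["path"], root_path["path_length"])

    # Or stop at a specific ancestor
    partial = client.get(f"/api/family/path/{model_id}",
                         params={"target_id": "org/base-model"}).json()
```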
  @app.get("/api/family/{model_id}")
+ async def get_family_tree(
+     model_id: str,
+     max_depth: Optional[int] = Query(None, ge=1, le=100, description="Maximum depth to traverse. If None, traverses the entire tree without limit."),
+     max_depth_filter: Optional[int] = Query(None, ge=0, description="Filter results to models at or below this hierarchy depth.")
+ ):
      """
      Get family tree for a model (ancestors and descendants).
      Returns the model, its parent chain, and all children.
+
+     If max_depth is None, traverses the entire family tree without depth limits.
      """
      if df is None:
+         raise DataNotLoadedError()

      if reduced_embeddings is None:
          raise HTTPException(status_code=503, detail="Embeddings not ready")

+     model_id_str = str(model_id)

+     if df.index.name == 'model_id':
+         if model_id_str not in df.index:
+             raise HTTPException(status_code=404, detail="Model not found")
+         model_lookup = df.loc
+     else:
+         model_rows = df[df.get('model_id', '') == model_id_str]
+         if len(model_rows) == 0:
+             raise HTTPException(status_code=404, detail="Model not found")
+         model_lookup = lambda x: df[df.get('model_id', '') == x]
+
+     from utils.network_analysis import _get_all_parents, _parse_parent_list
+
+     children_index: Dict[str, List[str]] = {}
+     parent_columns = ['parent_model', 'finetune_parent', 'quantized_parent', 'adapter_parent', 'merge_parent']
+
+     for idx, row in df.iterrows():
+         model_id_from_row = str(row.get('model_id', idx))
+         all_parents = _get_all_parents(row)
+
+         for rel_type, parent_list in all_parents.items():
+             for parent_str in parent_list:
+                 if parent_str not in children_index:
+                     children_index[parent_str] = []
+                 children_index[parent_str].append(model_id_from_row)
+
+     visited = set()
+
+     def get_ancestors(current_id: str, depth: Optional[int]):
+         if current_id in visited:
+             return
+         if depth is not None and depth <= 0:
              return
          visited.add(current_id)

+         try:
+             if df.index.name == 'model_id':
+                 row = df.loc[current_id]
+             else:
+                 rows = model_lookup(current_id)
+                 if len(rows) == 0:
+                     return
+                 row = rows.iloc[0]
+
+             all_parents = _get_all_parents(row)
+             for rel_type, parent_list in all_parents.items():
+                 for parent_str in parent_list:
+                     if parent_str != 'nan' and parent_str != '':
+                         next_depth = depth - 1 if depth is not None else None
+                         get_ancestors(parent_str, next_depth)
+         except (KeyError, IndexError):
+             return

+     def get_descendants(current_id: str, depth: Optional[int]):
+         if current_id in visited:
+             return
+         if depth is not None and depth <= 0:
              return
          visited.add(current_id)

+         children = children_index.get(current_id, [])
+         for child_id in children:
+             if child_id not in visited:
+                 next_depth = depth - 1 if depth is not None else None
+                 get_descendants(child_id, next_depth)
+
+     get_ancestors(model_id_str, max_depth)
+     # `visited` now holds the ancestor set, including the root itself. Reset
+     # it for the descendant walk (get_descendants would otherwise return
+     # immediately), but keep a copy so ancestors are not dropped from the family.
+     ancestors = set(visited)
+     visited = set()
+     get_descendants(model_id_str, max_depth)
+     visited |= ancestors
+     visited.add(model_id_str)
+
+     if df.index.name == 'model_id':
          try:
              family_df = df.loc[list(visited)]
          except KeyError:
+             missing = [v for v in visited if v not in df.index]
+             if missing:
+                 logger.warning(f"Some family members not found in index: {missing}")
+             family_df = df.loc[[v for v in visited if v in df.index]]
      else:
          family_df = df[df.get('model_id', '').isin(visited)]

+     if len(family_df) == 0:
+         raise HTTPException(status_code=404, detail="Family tree data not available")
+
+     if df.index.name == 'model_id':
+         # The index holds model_id strings, so convert the family rows to
+         # integer positions before indexing the embeddings array
+         family_indices = df.index.get_indexer(family_df.index)
+         family_indices = family_indices[family_indices >= 0]
+     else:
+         family_indices = family_df.index.values
+     if len(family_indices) > len(reduced_embeddings):
+         raise HTTPException(status_code=503, detail="Embedding indices mismatch")
+
      family_reduced = reduced_embeddings[family_indices]

      family_map = {}
+     depths = calculate_family_depths(df)  # computed once here rather than once per row
      for idx, (i, row) in enumerate(family_df.iterrows()):
+         model_id_val = str(row.get('model_id', i))
+         parent_id = row.get('parent_model')
+         parent_id_str = str(parent_id) if parent_id and pd.notna(parent_id) else None
+         model_depth = depths.get(model_id_val, 0)
+
+         if max_depth_filter is not None and model_depth > max_depth_filter:
+             continue

          family_map[model_id_val] = {
              "model_id": model_id_val,
  ...
              "pipeline_tag": str(row.get('pipeline_tag')) if pd.notna(row.get('pipeline_tag')) else None,
              "downloads": int(row.get('downloads', 0)) if pd.notna(row.get('downloads')) else 0,
              "likes": int(row.get('likes', 0)) if pd.notna(row.get('likes')) else 0,
+             "parent_model": parent_id_str,
              "licenses": str(row.get('licenses')) if pd.notna(row.get('licenses')) else None,
+             "family_depth": model_depth,
              "children": []
          }

      root_models = []
      for model_id_val, model_data in family_map.items():
          parent_id = model_data["parent_model"]
  ...
              root_models.append(model_id_val)

      return {
+         "root_model": model_id_str,
          "family": list(family_map.values()),
          "family_map": family_map,
          "root_models": root_models
      }

  @app.get("/api/search")
905
  async def search_models(
906
+ q: str = Query(..., min_length=1, alias="query"),
907
+ query: str = Query(None, min_length=1),
908
+ limit: int = Query(20, ge=1, le=100),
909
  graph_aware: bool = Query(False),
910
  include_neighbors: bool = Query(True)
911
  ):
 
914
  Enhanced with graph-aware search option that includes network relationships.
915
  """
916
  if df is None:
917
+ raise DataNotLoadedError()
918
+
919
+ # Support both 'q' and 'query' parameters
920
+ search_query = query or q
921
 
922
  if graph_aware:
 
923
  try:
924
  network_builder = ModelNetworkBuilder(df)
 
925
  top_models = network_builder.get_top_models_by_field(n=1000)
926
  model_ids = [mid for mid, _ in top_models]
927
  graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined')
928
 
929
  results = network_builder.search_graph_aware(
930
+ query=search_query,
931
  graph=graph,
932
+ max_results=limit,
933
  include_neighbors=include_neighbors
934
  )
935
 
936
+ return {"results": results, "search_type": "graph_aware", "query": search_query}
937
+ except (ValueError, KeyError, AttributeError) as e:
938
+ logger.warning(f"Graph-aware search failed, falling back to basic search: {e}")
939
+
940
+ query_lower = search_query.lower()
941
+
942
+ # Enhanced search: search model_id, org, tags, library, pipeline
943
+ model_id_col = df.get('model_id', '').astype(str).str.lower()
944
+ library_col = df.get('library_name', '').astype(str).str.lower()
945
+ pipeline_col = df.get('pipeline_tag', '').astype(str).str.lower()
946
+ tags_col = df.get('tags', '').astype(str).str.lower()
947
+ license_col = df.get('license', '').astype(str).str.lower()
948
+
949
+ # Extract org from model_id
950
+ org_col = model_id_col.str.split('/').str[0]
951
+
952
+ # Multi-field search
953
+ mask = (
954
+ model_id_col.str.contains(query_lower, na=False) |
955
+ org_col.str.contains(query_lower, na=False) |
956
+ library_col.str.contains(query_lower, na=False) |
957
+ pipeline_col.str.contains(query_lower, na=False) |
958
+ tags_col.str.contains(query_lower, na=False) |
959
+ license_col.str.contains(query_lower, na=False)
960
+ )
961
 
962
+ matches = df[mask].head(limit)
 
 
 
963
 
964
  results = []
965
  for _, row in matches.iterrows():
966
+ model_id = str(row.get('model_id', ''))
967
+ org = model_id.split('/')[0] if '/' in model_id else ''
968
+
969
+ # Get coordinates if available
970
+ x = float(row.get('x', 0.0)) if 'x' in row else None
971
+ y = float(row.get('y', 0.0)) if 'y' in row else None
972
+ z = float(row.get('z', 0.0)) if 'z' in row else None
973
+
974
  results.append({
975
+ "model_id": model_id,
976
+ "x": x,
977
+ "y": y,
978
+ "z": z,
979
+ "org": org,
980
+ "library": row.get('library_name'),
981
+ "pipeline": row.get('pipeline_tag'),
982
+ "license": row.get('license') if pd.notna(row.get('license')) else None,
983
  "downloads": int(row.get('downloads', 0)),
984
  "likes": int(row.get('likes', 0)),
985
  "parent_model": row.get('parent_model') if pd.notna(row.get('parent_model')) else None,
986
  "match_type": "direct"
987
  })
988
 
989
+ return {"results": results, "search_type": "basic", "query": search_query}

  @app.get("/api/similar/{model_id}")
  ...
      Get k-nearest neighbors of a model based on embedding similarity.
      Returns similar models with distance scores.
      """
+     if deps.df is None or deps.embeddings is None:
          raise HTTPException(status_code=503, detail="Data not loaded")

+     df = deps.df
+     embeddings = deps.embeddings
+
      if 'model_id' in df.index.names or df.index.name == 'model_id':
          try:
              model_row = df.loc[[model_id]]
  ...
      model_idx = model_row.index[0]
      model_embedding = embeddings[model_idx]

      from sklearn.metrics.pairwise import cosine_similarity
      model_embedding_2d = model_embedding.reshape(1, -1)
      similarities = cosine_similarity(model_embedding_2d, embeddings)[0]

      top_k_indices = np.argpartition(similarities, -k-1)[-k-1:-1]
      top_k_indices = top_k_indices[np.argsort(similarities[top_k_indices])][::-1]

      similar_models = []
  ...
          similar_models.append({
              "model_id": row.get('model_id', 'Unknown'),
              "similarity": float(similarities[idx]),
+             "distance": float(1 - similarities[idx]),
              "library_name": row.get('library_name'),
              "pipeline_tag": row.get('pipeline_tag'),
              "downloads": int(row.get('downloads', 0)),
  ...
      Returns models with their similarity scores and coordinates.
      Useful for exploring the embedding space around a specific model.
      """
+     if deps.df is None or deps.embeddings is None:
          raise HTTPException(status_code=503, detail="Data not loaded")

+     df = deps.df
+     embeddings = deps.embeddings
+
      # Find the query model
      if 'model_id' in df.index.names or df.index.name == 'model_id':
          try:
  ...

      query_embedding = embeddings[model_idx]

      filtered_df = data_loader.filter_data(
          df=df,
          min_downloads=min_downloads,
  ...
          pipeline_tags=None
      )

      if df.index.name == 'model_id' or 'model_id' in df.index.names:
          filtered_indices = [df.index.get_loc(idx) for idx in filtered_df.index]
          filtered_indices = np.array(filtered_indices, dtype=int)
      else:
          filtered_indices = filtered_df.index.values.astype(int)

      filtered_embeddings = embeddings[filtered_indices]
      from sklearn.metrics.pairwise import cosine_similarity
      query_embedding_2d = query_embedding.reshape(1, -1)
      similarities = cosine_similarity(query_embedding_2d, filtered_embeddings)[0]

      top_k_local_indices = np.argpartition(similarities, -k)[-k:]
      top_k_local_indices = top_k_local_indices[np.argsort(similarities[top_k_local_indices])][::-1]

      if reduced_embeddings is None:
          raise HTTPException(status_code=503, detail="Reduced embeddings not ready")

      top_k_original_indices = filtered_indices[top_k_local_indices]
      top_k_reduced = reduced_embeddings[top_k_original_indices]

      similar_models = []
      for i, orig_idx in enumerate(top_k_original_indices):
          row = df.iloc[orig_idx]

1141
  """
1142
  Calculate distance/similarity between two models.
1143
  """
1144
+ if deps.df is None or deps.embeddings is None:
 
 
1145
  raise HTTPException(status_code=503, detail="Data not loaded")
1146
 
1147
+ df = deps.df
1148
+ embeddings = deps.embeddings
1149
+
1150
  # Find both models - optimized with index lookup
1151
  if 'model_id' in df.index.names or df.index.name == 'model_id':
1152
  try:
 
1183
  Export selected models as JSON with full metadata.
1184
  """
1185
  if df is None:
1186
+ raise DataNotLoadedError()
1187
 
1188
  # Optimized export with index lookup
1189
  if 'model_id' in df.index.names or df.index.name == 'model_id':
 
1198
  if len(exported) == 0:
1199
  return {"models": []}
1200
 
 
1201
  models = [
1202
  {
1203
  "model_id": str(row.get('model_id', '')),
 
1235
  Returns network graph data suitable for visualization.
1236
  """
1237
  if df is None:
1238
+ raise DataNotLoadedError()
1239
 
1240
  try:
1241
  network_builder = ModelNetworkBuilder(df)
 
 
1242
  top_models = network_builder.get_top_models_by_field(
1243
  library=library,
1244
  pipeline_tag=pipeline_tag,
 
1255
  }
1256
 
1257
  model_ids = [mid for mid, _ in top_models]
 
 
1258
  graph = network_builder.build_cooccurrence_network(
1259
  model_ids=model_ids,
1260
  cooccurrence_method=cooccurrence_method
1261
  )
1262
 
 
1263
  nodes = []
1264
  for node_id, attrs in graph.nodes(data=True):
1265
  nodes.append({
 
1287
  "links": links,
1288
  "statistics": stats
1289
  }
1290
+ except (ValueError, KeyError, AttributeError) as e:
1291
+ logger.error(f"Error building network: {e}", exc_info=True)
1292
  raise HTTPException(status_code=500, detail=f"Error building network: {str(e)}")
1293
 
1294
 
1295
  @app.get("/api/network/family/{model_id}")
1296
  async def get_family_network(
1297
  model_id: str,
1298
+ max_depth: Optional[int] = Query(None, ge=1, le=100, description="Maximum depth to traverse. If None, traverses entire tree without limit."),
1299
+ edge_types: Optional[str] = Query(None, description="Comma-separated list of edge types to include (finetune,quantized,adapter,merge,parent). If None, includes all types."),
1300
+ include_edge_attributes: bool = Query(True, description="Whether to include edge attributes (change in likes, downloads, etc.)")
1301
  ):
1302
  """
1303
  Build family tree network for a model (directed graph).
1304
+ Returns network graph data showing parent-child relationships with multiple relationship types.
1305
+ Supports filtering by edge type (finetune, quantized, adapter, merge, parent).
1306
  """
1307
  if df is None:
1308
+ raise DataNotLoadedError()
1309
 
1310
  try:
1311
+ filter_types = None
1312
+ if edge_types:
1313
+ filter_types = [t.strip() for t in edge_types.split(',') if t.strip()]
1314
+
1315
  network_builder = ModelNetworkBuilder(df)
1316
  graph = network_builder.build_family_tree_network(
1317
  root_model_id=model_id,
1318
+ max_depth=max_depth,
1319
+ include_edge_attributes=include_edge_attributes,
1320
+ filter_edge_types=filter_types
1321
  )
1322
 
 
1323
  nodes = []
1324
  for node_id, attrs in graph.nodes(data=True):
1325
  nodes.append({
1326
  "id": node_id,
1327
  "title": attrs.get('title', node_id),
1328
+ "freq": attrs.get('freq', 0),
1329
+ "likes": attrs.get('likes', 0),
1330
+ "downloads": attrs.get('downloads', 0),
1331
+ "library": attrs.get('library', ''),
1332
+ "pipeline": attrs.get('pipeline', '')
1333
  })
1334
 
1335
  links = []
1336
+ for source, target, edge_attrs in graph.edges(data=True):
1337
+ link_data = {
1338
  "source": source,
1339
+ "target": target,
1340
+ "edge_type": edge_attrs.get('edge_type'),
1341
+ "edge_types": edge_attrs.get('edge_types', [])
1342
+ }
1343
+
1344
+ if include_edge_attributes:
1345
+ link_data.update({
1346
+ "change_in_likes": edge_attrs.get('change_in_likes'),
1347
+ "percentage_change_in_likes": edge_attrs.get('percentage_change_in_likes'),
1348
+ "change_in_downloads": edge_attrs.get('change_in_downloads'),
1349
+ "percentage_change_in_downloads": edge_attrs.get('percentage_change_in_downloads'),
1350
+ "change_in_createdAt_days": edge_attrs.get('change_in_createdAt_days')
1351
+ })
1352
+
1353
+ links.append(link_data)
1354
 
1355
  stats = network_builder.get_network_statistics(graph)
1356
 
 
1360
  "statistics": stats,
1361
  "root_model": model_id
1362
  }
1363
+ except (ValueError, KeyError, AttributeError) as e:
1364
+ logger.error(f"Error building family network: {e}", exc_info=True)
1365
  raise HTTPException(status_code=500, detail=f"Error building family network: {str(e)}")
1366
 
1367
 
 
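Example request against this endpoint, filtering to fine-tune and quantization edges only; the model ID is a placeholder and is URL-encoded for the same slash-in-path reason noted earlier:

```python
import httpx
from urllib.parse import quote

model_id = quote("org/base-model", safe="")  # hypothetical ID
resp = httpx.get(
    f"http://localhost:8000/api/network/family/{model_id}",
    params={"edge_types": "finetune,quantized", "max_depth": 3},
)
graph = resp.json()
print(len(graph["nodes"]), "nodes,", len(graph["links"]), "links")
```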
  ...
      Similar to graph database queries for finding connected nodes.
      """
      if df is None:
+         raise DataNotLoadedError()

      try:
          network_builder = ModelNetworkBuilder(df)
          top_models = network_builder.get_top_models_by_field(n=1000)
          model_ids = [mid for mid, _ in top_models]
          graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined')
  ...
              "neighbors": neighbors,
              "count": len(neighbors)
          }
+     except (ValueError, KeyError, AttributeError) as e:
+         logger.error(f"Error finding neighbors: {e}", exc_info=True)
          raise HTTPException(status_code=500, detail=f"Error finding neighbors: {str(e)}")

  ...
      Similar to graph database path queries.
      """
      if df is None:
+         raise DataNotLoadedError()

      try:
          network_builder = ModelNetworkBuilder(df)
  ...
      Similar to graph database queries for co-assignment patterns.
      """
      if df is None:
+         raise DataNotLoadedError()

      try:
          network_builder = ModelNetworkBuilder(df)
  ...
      Similar to graph database relationship queries.
      """
      if df is None:
+         raise DataNotLoadedError()

      try:
          network_builder = ModelNetworkBuilder(df)
  ...
  async def get_current_model_count(
      use_cache: bool = Query(True),
      force_refresh: bool = Query(False),
+     use_dataset_snapshot: bool = Query(False),
+     use_models_page: bool = Query(True)
  ):
      """
      Get the current number of models on Hugging Face Hub.
+     Uses multiple strategies: models page scraping (fastest), dataset snapshot, or API.

      Query Parameters:
          use_cache: Use cached results if available (default: True)
          force_refresh: Force refresh even if cache is valid (default: False)
+         use_dataset_snapshot: Use dataset snapshot for breakdowns (default: False)
+         use_models_page: Try to get count from HF models page first (default: True)
      """
      try:
+         tracker = get_tracker()
+
          if use_dataset_snapshot:
+             count_data = tracker.get_count_from_models_page()
              if count_data is None:
+                 count_data = tracker.get_current_model_count(use_models_page=False)
+             else:
+                 try:
+                     from utils.data_loader import ModelDataLoader
+                     data_loader = ModelDataLoader()
+                     df = data_loader.load_data(sample_size=10000)
+                     library_counts = {}
+                     pipeline_counts = {}
+
+                     for _, row in df.iterrows():
+                         if pd.notna(row.get('library_name')):
+                             lib = str(row.get('library_name'))
+                             library_counts[lib] = library_counts.get(lib, 0) + 1
+                         if pd.notna(row.get('pipeline_tag')):
+                             pipeline = str(row.get('pipeline_tag'))
+                             pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
+
+                     if len(df) > 0 and count_data["total_models"] > len(df):
+                         scale_factor = count_data["total_models"] / len(df)
+                         library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
+                         pipeline_counts = {k: int(v * scale_factor) for k, v in pipeline_counts.items()}
+
+                     count_data["models_by_library"] = library_counts
+                     count_data["models_by_pipeline"] = pipeline_counts
+                 except Exception as e:
+                     logger.warning(f"Could not get breakdowns from dataset: {e}")
          else:
+             count_data = tracker.get_current_model_count(use_models_page=use_models_page)

          return count_data
      except Exception as e:
+         logger.error(f"Error fetching model count: {e}", exc_info=True)
          raise HTTPException(status_code=500, detail=f"Error fetching model count: {str(e)}")

  ...
      try:
          from datetime import datetime

+         tracker = get_tracker()

          start = None
          end = None
  ...
  async def get_latest_model_count():
      """Get the most recently recorded model count from database."""
      try:
+         tracker = get_tracker()
          latest = tracker.get_latest_count()
          if latest is None:
              raise HTTPException(status_code=404, detail="No model counts recorded yet")
  ...
          use_dataset_snapshot: Use dataset snapshot instead of API (faster, default: False)
      """
      try:
+         tracker = get_tracker()

          def record():
              if use_dataset_snapshot:
                  count_data = tracker.get_count_from_dataset_snapshot()
                  if count_data:
                      tracker.record_count(count_data, source="dataset_snapshot")
                  else:
                      count_data = tracker.get_current_model_count(use_cache=False)
                      tracker.record_count(count_data, source="api")
              else:
  ...
          days: Number of days to analyze
      """
      try:
+         tracker = get_tracker()
          stats = tracker.get_growth_stats(days)
          return stats
      except Exception as e:
  ...
      Similar to Open Syllabus graph export functionality.
      """
      if df is None:
+         raise DataNotLoadedError()

      try:
          network_builder = ModelNetworkBuilder(df)

          top_models = network_builder.get_top_models_by_field(
              library=library,
              pipeline_tag=pipeline_tag,
  ...
              raise HTTPException(status_code=404, detail="No models found matching criteria")

          model_ids = [mid for mid, _ in top_models]
          graph = network_builder.build_cooccurrence_network(
              model_ids=model_ids,
              cooccurrence_method=cooccurrence_method
          )

          with tempfile.NamedTemporaryFile(mode='w', suffix='.graphml', delete=False) as tmp_file:
              tmp_path = tmp_file.name
          network_builder.export_graphml(graph, tmp_path)

          background_tasks.add_task(os.unlink, tmp_path)

          return FileResponse(
              tmp_path,
              media_type='application/xml',
              filename=f'network_{cooccurrence_method}_{n}_models.graphml'
          )
+     except (ValueError, KeyError, AttributeError, IOError) as e:
+         logger.error(f"Error exporting network: {e}", exc_info=True)
          raise HTTPException(status_code=500, detail=f"Error exporting network: {str(e)}")

  ...
      Extracts arXiv IDs from model tags and fetches paper information.
      """
      if df is None:
+         raise DataNotLoadedError()

      model = df[df.get('model_id', '') == model_id]
      if len(model) == 0:
  ...
      }


+ @app.get("/api/models/minimal.bin")
+ async def get_minimal_binary():
+     """
+     Serve the binary minimal dataset file.
+     This is optimized for fast client-side loading.
+     """
+     backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+     root_dir = os.path.dirname(backend_dir)
+     binary_path = os.path.join(root_dir, "cache", "binary", "embeddings.bin")
+
+     if not os.path.exists(binary_path):
+         raise HTTPException(status_code=404, detail="Binary dataset not found. Run export_binary.py first.")
+
+     return FileResponse(
+         binary_path,
+         media_type="application/octet-stream",
+         headers={
+             "Content-Disposition": "attachment; filename=embeddings.bin",
+             "Cache-Control": "public, max-age=3600"
+         }
+     )
+
+
+ @app.get("/api/models/model_ids.json")
+ async def get_model_ids_json():
+     """Serve the model IDs JSON file."""
+     backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+     root_dir = os.path.dirname(backend_dir)
+     json_path = os.path.join(root_dir, "cache", "binary", "model_ids.json")
+
+     if not os.path.exists(json_path):
+         raise HTTPException(status_code=404, detail="Model IDs file not found.")
+
+     return FileResponse(
+         json_path,
+         media_type="application/json",
+         headers={"Cache-Control": "public, max-age=3600"}
+     )
+
+
+ @app.get("/api/models/metadata.json")
+ async def get_metadata_json():
+     """Serve the metadata JSON file with lookup tables."""
+     backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+     root_dir = os.path.dirname(backend_dir)
+     json_path = os.path.join(root_dir, "cache", "binary", "metadata.json")
+
+     if not os.path.exists(json_path):
+         raise HTTPException(status_code=404, detail="Metadata file not found.")
+
+     return FileResponse(
+         json_path,
+         media_type="application/json",
+         headers={"Cache-Control": "public, max-age=3600"}
+     )
+
+
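The binary layout of `embeddings.bin` is defined by `backend/scripts/export_binary.py`, which this view truncates; the client sketch below assumes a flat little-endian float32 coordinate array purely for illustration:

```python
import httpx
import numpy as np

base = "http://localhost:8000/api/models"
ids = httpx.get(f"{base}/model_ids.json").json()
meta = httpx.get(f"{base}/metadata.json").json()
raw = httpx.get(f"{base}/minimal.bin").content

# Assumed layout: one float32 row of coordinates per model id
coords = np.frombuffer(raw, dtype=np.float32).reshape(len(ids), -1)
print(coords.shape, "coordinate rows for", len(ids), "models")
```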
  @app.get("/api/model/{model_id}/files")
  async def get_model_files(model_id: str, branch: str = Query("main")):
      """
      Get file tree for a model from Hugging Face.
      Proxies the request to avoid CORS issues.
+     Returns a flat list of files with path and size information.
      """
+     if not model_id or not model_id.strip():
+         raise HTTPException(status_code=400, detail="Invalid model ID")
+
+     # Try the requested branch first, then fall back to the other common
+     # default branch names (equivalent to the original nested conditional,
+     # just easier to read)
+     branches_to_try = [branch] + [b for b in ("main", "master") if b != branch]
+
      try:
+         async with httpx.AsyncClient(timeout=15.0) as client:
              for branch_name in branches_to_try:
                  try:
                      url = f"https://huggingface.co/api/models/{model_id}/tree/{branch_name}"
                      response = await client.get(url)
+
                      if response.status_code == 200:
+                         data = response.json()
+                         # Ensure we return an array
+                         if isinstance(data, list):
+                             return data
+                         elif isinstance(data, dict) and 'tree' in data:
+                             return data['tree']
+                         else:
+                             return []
+
+                     elif response.status_code == 404:
+                         # Try next branch
+                         continue
+                     else:
+                         logger.warning(f"Unexpected status {response.status_code} for {url}")
+                         continue
+
+                 except httpx.HTTPStatusError as e:
+                     if e.response.status_code == 404:
+                         continue  # Try next branch
+                     logger.warning(f"HTTP error for branch {branch_name}: {e}")
+                     continue
+                 except httpx.HTTPError as e:
+                     logger.warning(f"HTTP error for branch {branch_name}: {e}")
                      continue

+         # All branches failed
+         raise HTTPException(
+             status_code=404,
+             detail=f"File tree not found for model '{model_id}'. The model may not exist or may not have any files."
+         )
+
      except httpx.TimeoutException:
+         raise HTTPException(
+             status_code=504,
+             detail="Request to Hugging Face timed out. Please try again later."
+         )
+     except HTTPException:
+         raise  # Re-raise HTTP exceptions
      except Exception as e:
+         logger.error(f"Error fetching file tree: {e}", exc_info=True)
+         raise HTTPException(
+             status_code=500,
+             detail=f"Error fetching file tree: {str(e)}"
+         )


  if __name__ == "__main__":
      import uvicorn
      port = int(os.getenv("PORT", 8000))
      uvicorn.run(app, host="0.0.0.0", port=port)

backend/api/routes/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """
+ API route modules.
+ """
+ from . import models, stats, clusters
+
+ __all__ = ['models', 'stats', 'clusters']
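The registration of these routers in `main.py` is not visible in this view; the usual FastAPI wiring for modules like these would be:

```python
from api.routes import models, stats, clusters

app.include_router(models.router)
app.include_router(stats.router)
app.include_router(clusters.router)
```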
backend/api/routes/clusters.py ADDED
@@ -0,0 +1,102 @@
+ """
+ API routes for cluster endpoints.
+ """
+ from fastapi import APIRouter
+ import numpy as np
+ import pandas as pd
+ from core.exceptions import DataNotLoadedError
+ import api.dependencies as deps
+
+ router = APIRouter(prefix="/api", tags=["clusters"])
+
+
+ @router.get("/clusters")
+ async def get_clusters():
+     """Get all clusters with metadata and hierarchical labels."""
+     if deps.df is None:
+         raise DataNotLoadedError()
+
+     # Import cluster_labels from models route
+     from api.routes.models import cluster_labels
+
+     # If clusters haven't been computed yet, return empty list instead of error
+     # This allows the frontend to work while data is still loading
+     if cluster_labels is None:
+         return {"clusters": []}
+
+     df = deps.df
+
+     # Generate hierarchical labels for clusters
+     clusters = []
+     unique_clusters = np.unique(cluster_labels)
+
+     for cluster_id in unique_clusters:
+         cluster_mask = cluster_labels == cluster_id
+         cluster_models = df[cluster_mask]
+
+         if len(cluster_models) == 0:
+             continue
+
+         # Generate hierarchical label
+         library_counts = cluster_models['library_name'].value_counts()
+         pipeline_counts = cluster_models['pipeline_tag'].value_counts()
+
+         # Determine primary domain/library
+         if len(library_counts) > 0:
+             primary_lib = library_counts.index[0]
+             if primary_lib and pd.notna(primary_lib):
+                 if 'transformers' in str(primary_lib).lower():
+                     domain = "NLP"
+                 elif 'diffusers' in str(primary_lib).lower():
+                     domain = "Multimodal"
+                 elif 'timm' in str(primary_lib).lower():
+                     domain = "Computer Vision"
+                 else:
+                     domain = str(primary_lib).replace('_', ' ').title()
+             else:
+                 domain = "Other"
+         else:
+             domain = "Other"
+
+         # Determine subdomain from pipeline
+         if len(pipeline_counts) > 0:
+             primary_pipeline = pipeline_counts.index[0]
+             if primary_pipeline and pd.notna(primary_pipeline):
+                 subdomain = str(primary_pipeline).replace('-', ' ').replace('_', ' ').title()
+             else:
+                 subdomain = "General"
+         else:
+             subdomain = "General"
+
+         # Determine characteristics
+         characteristics = []
+         model_ids_lower = cluster_models['model_id'].astype(str).str.lower()
+         if model_ids_lower.str.contains('gpt', na=False).any():
+             characteristics.append("GPT-based")
+         if cluster_models['parent_model'].notna().any():
+             characteristics.append("Fine-tuned")
+         if not characteristics:
+             characteristics.append("Base Models")
+
+         char_str = "; ".join(characteristics)
+         label = f"{domain} — {subdomain} ({char_str})"
+
+         # Generate color (use consistent colors based on cluster_id)
+         colors = [
+             "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
+             "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
+         ]
+         color = colors[cluster_id % len(colors)]
+
+         clusters.append({
+             "cluster_id": int(cluster_id),
+             "cluster_label": label,
+             "count": int(len(cluster_models)),
+             "color": color
+         })
+
+     # Sort by count descending
+     clusters.sort(key=lambda x: x["count"], reverse=True)
+
+     return {"clusters": clusters}
+
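The response this router produces has the following shape; the values are invented for illustration, and the label format combines domain, subdomain, and characteristics as built above:

```python
# Illustrative /api/clusters payload:
{
    "clusters": [
        {
            "cluster_id": 7,
            "cluster_label": "NLP — Text Generation (GPT-based; Fine-tuned)",
            "count": 1243,
            "color": "#7f7f7f",  # colors[7 % 10]
        },
        # ... sorted by count, descending
    ]
}
```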
backend/api/routes/models.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API routes for model data endpoints.
3
+ """
4
+ from typing import Optional
5
+ from fastapi import APIRouter, Query, HTTPException
6
+ import numpy as np
7
+ import pandas as pd
8
+ import pickle
9
+ import os
10
+ import logging
11
+
12
+ from umap import UMAP
13
+ from models.schemas import ModelPoint
14
+ from utils.family_tree import calculate_family_depths
15
+ from utils.dimensionality_reduction import DimensionReducer
16
+ from core.exceptions import DataNotLoadedError, EmbeddingsNotReadyError
17
+ import api.dependencies as deps
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ router = APIRouter(prefix="/api", tags=["models"])
22
+
23
+ # Global cluster labels cache (shared across routes)
24
+ cluster_labels = None
25
+
26
+
27
+ def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np.ndarray:
28
+ from sklearn.cluster import KMeans
29
+
30
+ n_samples = len(reduced_embeddings)
31
+ if n_samples < n_clusters:
32
+ n_clusters = max(1, n_samples // 10)
33
+
34
+ kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
35
+ return kmeans.fit_predict(reduced_embeddings)
36
+
37
+
38
+ @router.get("/models")
39
+ async def get_models(
40
+ min_downloads: int = Query(0),
41
+ min_likes: int = Query(0),
42
+ search_query: Optional[str] = Query(None),
43
+ color_by: str = Query("library_name"),
44
+ size_by: str = Query("downloads"),
45
+ max_points: Optional[int] = Query(None),
46
+ projection_method: str = Query("umap"),
47
+ base_models_only: bool = Query(False),
48
+ max_hierarchy_depth: Optional[int] = Query(None, ge=0, description="Filter to models at or below this hierarchy depth."),
49
+ use_graph_embeddings: bool = Query(False, description="Use graph-aware embeddings that respect family tree structure")
50
+ ):
51
+ if deps.df is None:
52
+ raise DataNotLoadedError()
53
+
54
+ df = deps.df
55
+ data_loader = deps.data_loader
56
+
57
+ # Filter data
58
+ filtered_df = data_loader.filter_data(
59
+ df=df,
60
+ min_downloads=min_downloads,
61
+ min_likes=min_likes,
62
+ search_query=search_query,
63
+ libraries=None,
64
+ pipeline_tags=None
65
+ )
66
+
67
+ if base_models_only:
68
+ if 'parent_model' in filtered_df.columns:
69
+ filtered_df = filtered_df[
70
+ filtered_df['parent_model'].isna() |
71
+ (filtered_df['parent_model'].astype(str).str.strip() == '') |
72
+ (filtered_df['parent_model'].astype(str) == 'nan')
73
+ ]
74
+
75
+ if max_hierarchy_depth is not None:
76
+ family_depths = calculate_family_depths(df)
77
+ filtered_df = filtered_df[
78
+ filtered_df['model_id'].astype(str).map(lambda x: family_depths.get(x, 0) <= max_hierarchy_depth)
79
+ ]
80
+
81
+ filtered_count = len(filtered_df)
82
+
83
+ if len(filtered_df) == 0:
84
+ return {
85
+ "models": [],
86
+ "filtered_count": 0,
87
+ "returned_count": 0
88
+ }
89
+
90
+ if max_points is not None and len(filtered_df) > max_points:
91
+ if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
92
+ sampled_dfs = []
93
+ for lib_name, group in filtered_df.groupby('library_name', group_keys=False):
94
+ n_samples = max(1, int(max_points * len(group) / len(filtered_df)))
95
+ sampled_dfs.append(group.sample(min(len(group), n_samples), random_state=42))
96
+ filtered_df = pd.concat(sampled_dfs, ignore_index=True)
97
+ if len(filtered_df) > max_points:
98
+ filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
99
+ else:
100
+ filtered_df = filtered_df.reset_index(drop=True)
101
+ else:
102
+ filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
103
+
104
+ # Determine which embeddings to use
105
+ if use_graph_embeddings and deps.combined_embeddings is not None:
106
+ current_embeddings = deps.combined_embeddings
107
+ current_reduced = deps.reduced_embeddings_graph
108
+ embedding_type = "graph-aware"
109
+ else:
110
+ if deps.embeddings is None:
111
+ raise EmbeddingsNotReadyError()
112
+ current_embeddings = deps.embeddings
113
+ current_reduced = deps.reduced_embeddings
114
+ embedding_type = "text-only"
115
+
116
+ # Handle reduced embeddings loading/generation
117
+ reducer = deps.reducer
118
+ if current_reduced is None or (reducer and reducer.method != projection_method.lower()):
119
+ backend_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
120
+ root_dir = os.path.dirname(backend_dir)
121
+ cache_dir = os.path.join(root_dir, "cache")
122
+ cache_suffix = "_graph" if use_graph_embeddings and deps.combined_embeddings is not None else ""
123
+ reduced_cache = os.path.join(cache_dir, f"reduced_{projection_method.lower()}_3d{cache_suffix}.pkl")
124
+ reducer_cache = os.path.join(cache_dir, f"reducer_{projection_method.lower()}_3d{cache_suffix}.pkl")
125
+
126
+ if os.path.exists(reduced_cache) and os.path.exists(reducer_cache):
127
+ try:
128
+ with open(reduced_cache, 'rb') as f:
129
+ current_reduced = pickle.load(f)
130
+ if reducer is None or reducer.method != projection_method.lower():
131
+ reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
132
+ reducer.load_reducer(reducer_cache)
133
+ except (IOError, pickle.UnpicklingError, EOFError) as e:
134
+ logger.warning(f"Failed to load cached reduced embeddings: {e}")
135
+ current_reduced = None
136
+
137
+ if current_reduced is None:
138
+ if reducer is None or reducer.method != projection_method.lower():
139
+ reducer = DimensionReducer(method=projection_method.lower(), n_components=3)
140
+ if projection_method.lower() == "umap":
141
+ reducer.reducer = UMAP(
142
+ n_components=3,
143
+ n_neighbors=30,
144
+ min_dist=0.3,
145
+ metric='cosine',
146
+ random_state=42,
147
+ n_jobs=-1,
148
+ low_memory=True,
149
+ spread=1.5
150
+ )
151
+ current_reduced = reducer.fit_transform(current_embeddings)
152
+ with open(reduced_cache, 'wb') as f:
153
+ pickle.dump(current_reduced, f)
154
+ reducer.save_reducer(reducer_cache)
155
+
156
+ # Update global variable
157
+ if use_graph_embeddings and deps.combined_embeddings is not None:
158
+ deps.reduced_embeddings_graph = current_reduced
159
+ else:
160
+ deps.reduced_embeddings = current_reduced
161
+
162
+ # Get indices for filtered data
163
+ filtered_model_ids = filtered_df['model_id'].astype(str).values
164
+
165
+ if df.index.name == 'model_id' or 'model_id' in df.index.names:
166
+ filtered_indices = []
167
+ for model_id in filtered_model_ids:
168
+ try:
169
+ pos = df.index.get_loc(model_id)
170
+ if isinstance(pos, (int, np.integer)):
171
+ filtered_indices.append(int(pos))
172
+ elif isinstance(pos, (slice, np.ndarray)):
173
+ if isinstance(pos, slice):
174
+ filtered_indices.append(int(pos.start))
175
+ else:
176
+ filtered_indices.append(int(pos[0]))
177
+ except (KeyError, TypeError):
178
+ continue
179
+ filtered_indices = np.array(filtered_indices, dtype=np.int32)
180
+ else:
181
+ df_model_ids = df['model_id'].astype(str).values
182
+ model_id_to_pos = {mid: pos for pos, mid in enumerate(df_model_ids)}
183
+ filtered_indices = np.array([
184
+ model_id_to_pos[mid] for mid in filtered_model_ids
185
+ if mid in model_id_to_pos
186
+ ], dtype=np.int32)
187
+
188
+ if len(filtered_indices) == 0:
189
+ return {
190
+ "models": [],
191
+ "embedding_type": embedding_type,
192
+ "filtered_count": filtered_count,
193
+ "returned_count": 0
194
+ }
195
+
196
+ filtered_reduced = current_reduced[filtered_indices]
197
+ family_depths = calculate_family_depths(df)
198
+
199
+ global cluster_labels
200
+ clustering_embeddings = current_reduced
201
+ if cluster_labels is None or len(cluster_labels) != len(clustering_embeddings):
202
+ cluster_labels = compute_clusters(clustering_embeddings, n_clusters=min(50, len(clustering_embeddings) // 100))
203
+
204
+ filtered_clusters = cluster_labels[filtered_indices]
205
+
206
+ model_ids = filtered_df['model_id'].astype(str).values
207
+ library_names = filtered_df.get('library_name', pd.Series([None] * len(filtered_df))).values
208
+ pipeline_tags = filtered_df.get('pipeline_tag', pd.Series([None] * len(filtered_df))).values
209
+ downloads_arr = filtered_df.get('downloads', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values
210
+ likes_arr = filtered_df.get('likes', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values
211
+ trending_scores = filtered_df.get('trendingScore', pd.Series([None] * len(filtered_df))).values
212
+ tags_arr = filtered_df.get('tags', pd.Series([None] * len(filtered_df))).values
213
+ parent_models = filtered_df.get('parent_model', pd.Series([None] * len(filtered_df))).values
214
+ licenses_arr = filtered_df.get('licenses', pd.Series([None] * len(filtered_df))).values
215
+ created_at_arr = filtered_df.get('createdAt', pd.Series([None] * len(filtered_df))).values
216
+
217
+ x_coords = filtered_reduced[:, 0].astype(float)
218
+ y_coords = filtered_reduced[:, 1].astype(float)
219
+ z_coords = filtered_reduced[:, 2].astype(float) if filtered_reduced.shape[1] > 2 else np.zeros(len(filtered_reduced), dtype=float)
220
+ models = [
221
+ ModelPoint(
222
+ model_id=model_ids[idx],
223
+ x=float(x_coords[idx]),
224
+ y=float(y_coords[idx]),
225
+ z=float(z_coords[idx]),
226
+ library_name=library_names[idx] if pd.notna(library_names[idx]) else None,
227
+ pipeline_tag=pipeline_tags[idx] if pd.notna(pipeline_tags[idx]) else None,
228
+ downloads=int(downloads_arr[idx]),
229
+ likes=int(likes_arr[idx]),
230
+ trending_score=float(trending_scores[idx]) if idx < len(trending_scores) and pd.notna(trending_scores[idx]) else None,
231
+ tags=tags_arr[idx] if idx < len(tags_arr) and pd.notna(tags_arr[idx]) else None,
232
+ parent_model=parent_models[idx] if idx < len(parent_models) and pd.notna(parent_models[idx]) else None,
233
+ licenses=licenses_arr[idx] if idx < len(licenses_arr) and pd.notna(licenses_arr[idx]) else None,
234
+ family_depth=family_depths.get(model_ids[idx], None),
235
+ cluster_id=int(filtered_clusters[idx]) if idx < len(filtered_clusters) else None,
236
+ created_at=str(created_at_arr[idx]) if idx < len(created_at_arr) and pd.notna(created_at_arr[idx]) else None
237
+ )
238
+ for idx in range(len(filtered_df))
239
+ ]
240
+
241
+ return {
242
+ "models": models,
243
+ "embedding_type": embedding_type,
244
+ "filtered_count": filtered_count,
245
+ "returned_count": len(models)
246
+ }
247
+
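The positional lookup above is the load-bearing step in this route: rows sampled out of `filtered_df` must be mapped back to row positions in the full `df` before they can index `current_reduced`. A minimal, self-contained illustration of that mapping (toy data, not the real dataset):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"model_id": ["a", "b", "c", "d"]})
filtered = df.sample(n=2, random_state=42)

# Same dict-based mapping the route uses when model_id is a plain column
model_id_to_pos = {mid: pos for pos, mid in enumerate(df["model_id"].astype(str))}
filtered_indices = np.array(
    [model_id_to_pos[mid] for mid in filtered["model_id"].astype(str) if mid in model_id_to_pos],
    dtype=np.int32,
)

reduced = np.arange(12).reshape(4, 3)  # stand-in for the 3D embedding matrix
print(reduced[filtered_indices])       # rows aligned with the sampled models
```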
backend/api/routes/stats.py ADDED
@@ -0,0 +1,37 @@
1
+ """
2
+ API routes for statistics endpoints.
3
+ """
4
+ from fastapi import APIRouter
5
+ from core.exceptions import DataNotLoadedError
6
+ import api.dependencies as deps
7
+
8
+ router = APIRouter(prefix="/api", tags=["stats"])
9
+
10
+
11
+ @router.get("/stats")
12
+ async def get_stats():
13
+ """Get dataset statistics."""
14
+ if deps.df is None:
15
+ raise DataNotLoadedError()
16
+
17
+ df = deps.df
18
+ total_models = len(df.index) if hasattr(df, 'index') else len(df)
19
+
20
+ # Get unique licenses with counts
21
+ licenses = {}
22
+ if 'license' in df.columns:
23
+ import pandas as pd
24
+ license_counts = df['license'].value_counts().to_dict()
25
+ licenses = {str(k): int(v) for k, v in license_counts.items() if pd.notna(k) and str(k) != 'nan'}
26
+
27
+ return {
28
+ "total_models": total_models,
29
+ "unique_libraries": int(df['library_name'].nunique()) if 'library_name' in df.columns else 0,
30
+ "unique_pipelines": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0,
31
+ "unique_task_types": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0,
32
+ "unique_licenses": len(licenses),
33
+ "licenses": licenses,
34
+ "avg_downloads": float(df['downloads'].mean()) if 'downloads' in df.columns else 0,
35
+ "avg_likes": float(df['likes'].mean()) if 'likes' in df.columns else 0
36
+ }
37
+
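The license breakdown above reduces to a `value_counts` pass with NaN keys dropped. A standalone sketch of the same aggregation on a toy frame:

```python
import pandas as pd

df = pd.DataFrame({
    "model_id": ["a/m1", "b/m2", "c/m3", "d/m4"],
    "license": ["apache-2.0", "mit", "apache-2.0", None],
})

# Mirrors the /api/stats logic: count values, keep only non-NaN keys
license_counts = df["license"].value_counts().to_dict()
licenses = {str(k): int(v) for k, v in license_counts.items() if pd.notna(k)}
print(licenses)  # {'apache-2.0': 2, 'mit': 1}
```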
backend/config/requirements.txt CHANGED
@@ -11,5 +11,6 @@ huggingface-hub>=0.17.0
11
  schedule>=1.2.0
12
  tqdm>=4.66.0
13
  networkx>=3.0
 
14
  httpx>=0.24.0
15
 
 
11
  schedule>=1.2.0
12
  tqdm>=4.66.0
13
  networkx>=3.0
14
+ node2vec>=0.4.6
15
  httpx>=0.24.0
16
 
backend/core/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ """Core configuration and utilities."""
2
+
backend/core/config.py ADDED
@@ -0,0 +1,23 @@
1
+ """Configuration management."""
2
+ import os
3
+ from typing import Optional
4
+
5
+ class Settings:
6
+ """Application settings."""
7
+ FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
8
+ ALLOW_ALL_ORIGINS: bool = os.getenv("ALLOW_ALL_ORIGINS", "True").lower() in ("true", "1", "yes")
9
+ SAMPLE_SIZE: Optional[int] = None
10
+ USE_GRAPH_EMBEDDINGS: bool = os.getenv("USE_GRAPH_EMBEDDINGS", "false").lower() == "true"
11
+ PORT: int = int(os.getenv("PORT", 8000))
12
+
13
+ @classmethod
14
+ def get_sample_size(cls) -> Optional[int]:
15
+ """Get sample size from environment."""
16
+ sample_size_env = os.getenv("SAMPLE_SIZE")
17
+ if sample_size_env:
18
+ sample_size_val = int(sample_size_env)
19
+ return sample_size_val if sample_size_val > 0 else None
20
+ return None
21
+
22
+ settings = Settings()
23
+
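Since the `Settings` attributes are read from the environment at import time, ordering matters when overriding them. A quick sketch (hypothetical values; assumes it is run from the `backend` directory so `core.config` resolves):

```python
import os

# Set before importing core.config -- the class attributes are evaluated on import
os.environ["SAMPLE_SIZE"] = "5000"
os.environ["USE_GRAPH_EMBEDDINGS"] = "true"

from core.config import Settings, settings

print(Settings.get_sample_size())     # 5000
print(settings.USE_GRAPH_EMBEDDINGS)  # True
```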
backend/core/exceptions.py ADDED
@@ -0,0 +1,18 @@
1
+ """Custom exceptions."""
2
+ from fastapi import HTTPException
3
+
4
+ class ModelNotFoundError(HTTPException):
5
+ """Model not found exception."""
6
+ def __init__(self, model_id: str):
7
+ super().__init__(status_code=404, detail=f"Model not found: {model_id}")
8
+
9
+ class DataNotLoadedError(HTTPException):
10
+ """Data not loaded exception."""
11
+ def __init__(self):
12
+ super().__init__(status_code=503, detail="Data not loaded")
13
+
14
+ class EmbeddingsNotReadyError(HTTPException):
15
+ """Embeddings not ready exception."""
16
+ def __init__(self):
17
+ super().__init__(status_code=503, detail="Embeddings not ready")
18
+
backend/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ """Data models and schemas."""
2
+
backend/models/schemas.py ADDED
@@ -0,0 +1,22 @@
1
+ """Pydantic models for API."""
2
+ from pydantic import BaseModel
3
+ from typing import Optional
4
+
5
+ class ModelPoint(BaseModel):
6
+ """Model point in 3D space."""
7
+ model_id: str
8
+ x: float
9
+ y: float
10
+ z: float
11
+ library_name: Optional[str]
12
+ pipeline_tag: Optional[str]
13
+ downloads: int
14
+ likes: int
15
+ trending_score: Optional[float]
16
+ tags: Optional[str]
17
+ parent_model: Optional[str] = None
18
+ licenses: Optional[str] = None
19
+ family_depth: Optional[int] = None
20
+ cluster_id: Optional[int] = None
21
+ created_at: Optional[str] = None # ISO format date string
22
+
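For reference, a `ModelPoint` serializes directly to the JSON shape the frontend consumes. The field values below are made up; `.model_dump()` assumes pydantic v2 (use `.dict()` on v1):

```python
from models.schemas import ModelPoint

point = ModelPoint(
    model_id="org/example-model",  # hypothetical ID
    x=0.12, y=-1.4, z=3.3,
    library_name="transformers",
    pipeline_tag="text-generation",
    downloads=1200,
    likes=15,
    trending_score=None,
    tags="llm text-generation",
)
print(point.model_dump())
```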
backend/scripts/export_binary.py ADDED
@@ -0,0 +1,263 @@
1
+ """
2
+ Export minimal dataset to binary format for fast client-side loading.
3
+ This creates a compact binary representation optimized for WebGL rendering.
4
+ """
5
+ import struct
6
+ import json
7
+ import numpy as np
8
+ import pandas as pd
9
+ from pathlib import Path
10
+ import sys
11
+ import os
12
+
13
+ # Add parent directory to path
14
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15
+
16
+ from utils.data_loader import ModelDataLoader
17
+ from utils.dimensionality_reduction import DimensionReducer
18
+ from utils.embeddings import ModelEmbedder
19
+
20
+
21
+ def calculate_family_depths(df: pd.DataFrame) -> dict:
22
+ """Calculate depth of each model in its family tree."""
23
+ depths = {}
24
+
25
+ def get_depth(model_id: str, visited: set = None) -> int:
26
+ if visited is None:
27
+ visited = set()
28
+ if model_id in visited:
29
+ return 0 # Cycle detected
30
+ visited.add(model_id)
31
+
32
+ if model_id in depths:
33
+ return depths[model_id]
34
+
35
+ parent_col = df.get('parent_model', pd.Series([None] * len(df), index=df.index))
36
+ model_row = df[df['model_id'] == model_id]
37
+
38
+ if model_row.empty:
39
+ depths[model_id] = 0
40
+ return 0
41
+
42
+ parent = model_row.iloc[0].get('parent_model')
43
+ if pd.isna(parent) or parent == '' or str(parent) == 'nan':
44
+ depths[model_id] = 0
45
+ return 0
46
+
47
+ parent_depth = get_depth(str(parent), visited.copy())
48
+ depth = parent_depth + 1
49
+ depths[model_id] = depth
50
+ return depth
51
+
52
+ for model_id in df['model_id'].unique():
53
+ if model_id not in depths:
54
+ get_depth(str(model_id))
55
+
56
+ return depths
57
+
58
+
59
+ def export_binary_dataset(df: pd.DataFrame, reduced_embeddings: np.ndarray, output_dir: Path):
60
+ """
61
+ Export minimal dataset to binary format for fast client-side loading.
62
+
63
+ Binary format:
64
+ - Header (64 bytes): magic, version, counts, lookup table sizes
65
+ - Domain lookup table (32 bytes per domain)
66
+ - License lookup table (32 bytes per license)
67
+ - Family lookup table (32 bytes per family)
68
+ - Model records (17 bytes each, little-endian): x, y, z, domain_id, license_id, family_id, flags
69
+ """
70
+ output_dir.mkdir(parents=True, exist_ok=True)
71
+
72
+ print(f"Exporting {len(df)} models to binary format...")
73
+
74
+ # Ensure we have coordinates
75
+ if 'x' not in df.columns or 'y' not in df.columns:
76
+ if reduced_embeddings is None or len(reduced_embeddings) != len(df):
77
+ raise ValueError("Need reduced embeddings to generate coordinates")
78
+
79
+ df['x'] = reduced_embeddings[:, 0] if reduced_embeddings.shape[1] > 0 else 0.0
80
+ df['y'] = reduced_embeddings[:, 1] if reduced_embeddings.shape[1] > 1 else 0.0
81
+ df['z'] = reduced_embeddings[:, 2] if reduced_embeddings.shape[1] > 2 else 0.0
82
+
83
+ # Create lookup tables
84
+ # Domain = library_name
85
+ domains = sorted(df['library_name'].dropna().astype(str).unique())
86
+ domains = [d for d in domains if d and d != 'nan'][:255] # Limit to 255
87
+
88
+ # License
89
+ licenses = sorted(df['license'].dropna().astype(str).unique())
90
+ licenses = [l for l in licenses if l and l != 'nan'][:255] # Limit to 255
91
+
92
+ # Family ID mapping (use parent_model to create family groups)
93
+ family_depths = calculate_family_depths(df)
94
+
95
+ # Create family mapping: group models by root parent
96
+ def get_root_parent(model_id: str) -> str:
97
+ visited = set()
98
+ current = str(model_id)
99
+ while current not in visited:
100
+ visited.add(current)
101
+ model_row = df[df['model_id'] == current]
102
+ if model_row.empty:
103
+ return current
104
+ parent = model_row.iloc[0].get('parent_model')
105
+ if pd.isna(parent) or parent == '' or str(parent) == 'nan':
106
+ return current
107
+ current = str(parent)
108
+ return current
109
+
110
+ root_parents = {}
111
+ family_counter = 0
112
+ for model_id in df['model_id'].unique():
113
+ root = get_root_parent(str(model_id))
114
+ if root not in root_parents:
115
+ root_parents[root] = family_counter
116
+ family_counter += 1
117
+
118
+ # Map each model to its family
119
+ model_to_family = {}
120
+ for model_id in df['model_id'].unique():
121
+ root = get_root_parent(str(model_id))
122
+ model_to_family[str(model_id)] = root_parents.get(root, 65535)
123
+
124
+ # Limit families to 65535 (u16 max)
125
+ if len(root_parents) > 65535:
126
+ # Use hash-based family IDs
127
+ import hashlib
128
+ for model_id in df['model_id'].unique():
129
+ root = get_root_parent(str(model_id))
130
+ family_hash = int(hashlib.md5(root.encode()).hexdigest()[:4], 16) % 65535
131
+ model_to_family[str(model_id)] = family_hash
132
+
133
+ # Prepare model records
134
+ records = []
135
+ model_ids = []
136
+
137
+ for idx, row in df.iterrows():
138
+ model_id = str(row['model_id'])
139
+ model_ids.append(model_id)
140
+
141
+ # Get coordinates
142
+ x = float(row.get('x', 0.0))
143
+ y = float(row.get('y', 0.0))
144
+ z = float(row.get('z', 0.0))
145
+
146
+ # Encode domain (library_name)
147
+ domain_str = str(row.get('library_name', ''))
148
+ domain_id = domains.index(domain_str) if domain_str in domains else 255
149
+
150
+ # Encode license
151
+ license_str = str(row.get('license', ''))
152
+ license_id = licenses.index(license_str) if license_str in licenses else 255
153
+
154
+ # Encode family
155
+ family_id = model_to_family.get(model_id, 65535)
156
+
157
+ # Encode flags
158
+ flags = 0
159
+ parent = row.get('parent_model')
160
+ if pd.isna(parent) or parent == '' or str(parent) == 'nan':
161
+ flags |= 0x01 # is_base_model
162
+
163
+ # Check if has children (simple check - could be improved)
164
+ children = df[df['parent_model'] == model_id]
165
+ if len(children) > 0:
166
+ flags |= 0x04 # has_children
167
+ elif not pd.isna(parent) and parent != '' and str(parent) != 'nan':
168
+ flags |= 0x02 # has_parent
169
+
170
+ # Pack record (little-endian, 17 bytes): f32 x, f32 y, f32 z, u8 domain, u8 license, u16 family, u8 flags
171
+ records.append(struct.pack('<fffBBHB', x, y, z, domain_id, license_id, family_id, flags))
172
+
173
+ num_models = len(records)
174
+
175
+ # Write binary file
176
+ with open(output_dir / 'embeddings.bin', 'wb') as f:
177
+ # Header (64 bytes)
178
+ header = struct.pack('<5sBIIIIBBH38s',  # little-endian; four counts, 64 bytes total
179
+ b'HFVIZ', # magic (5 bytes)
180
+ 1, # version (1 byte)
181
+ num_models, # num_models (4 bytes)
182
+ len(domains), # num_domains (4 bytes)
183
+ len(licenses), # num_licenses (4 bytes)
184
+ len(set(model_to_family.values())), # num_families (4 bytes)
185
+ 0, # reserved (1 byte)
186
+ 0, # reserved (1 byte)
187
+ 0, # reserved (2 bytes)
188
+ b'\x00' * 38 # padding (38 bytes; header totals 64)
189
+ )
190
+ f.write(header)
191
+
192
+ # Domain lookup table (32 bytes per domain, null-terminated)
193
+ for domain in domains:
194
+ domain_bytes = domain.encode('utf-8')[:31]
195
+ f.write(domain_bytes.ljust(32, b'\x00'))
196
+
197
+ # License lookup table (32 bytes per license)
198
+ for license in licenses:
199
+ license_bytes = license.encode('utf-8')[:31]
200
+ f.write(license_bytes.ljust(32, b'\x00'))
201
+
202
+ # Model records
203
+ f.write(b''.join(records))
204
+
205
+ # Write model IDs JSON (separate file for string table)
206
+ with open(output_dir / 'model_ids.json', 'w') as f:
207
+ json.dump(model_ids, f)
208
+
209
+ # Write metadata JSON
210
+ metadata = {
211
+ 'domains': domains,
212
+ 'licenses': licenses,
213
+ 'num_models': num_models,
214
+ 'num_families': len(set(model_to_family.values())),
215
+ 'version': 1
216
+ }
217
+ with open(output_dir / 'metadata.json', 'w') as f:
218
+ json.dump(metadata, f, indent=2)
219
+
220
+ binary_size = (output_dir / 'embeddings.bin').stat().st_size
221
+ json_size = (output_dir / 'model_ids.json').stat().st_size
222
+
223
+ print(f"βœ“ Exported {num_models} models")
224
+ print(f"βœ“ Binary size: {binary_size / 1024 / 1024:.2f} MB")
225
+ print(f"βœ“ Model IDs JSON: {json_size / 1024 / 1024:.2f} MB")
226
+ print(f"βœ“ Total: {(binary_size + json_size) / 1024 / 1024:.2f} MB")
227
+ print(f"βœ“ Domains: {len(domains)}")
228
+ print(f"βœ“ Licenses: {len(licenses)}")
229
+ print(f"βœ“ Families: {len(set(model_to_family.values()))}")
230
+
231
+
232
+ if __name__ == '__main__':
233
+ import argparse
234
+
235
+ parser = argparse.ArgumentParser(description='Export dataset to binary format')
236
+ parser.add_argument('--output', type=str, default='backend/cache/binary', help='Output directory')
237
+ parser.add_argument('--sample-size', type=int, default=None, help='Sample size (for testing)')
238
+ args = parser.parse_args()
239
+
240
+ output_dir = Path(args.output)
241
+
242
+ # Load data
243
+ print("Loading dataset...")
244
+ data_loader = ModelDataLoader()
245
+ df = data_loader.load_data(sample_size=args.sample_size)
246
+ df = data_loader.preprocess_for_embedding(df)
247
+
248
+ # Generate embeddings and reduce dimensions if needed
249
+ if 'x' not in df.columns or 'y' not in df.columns:
250
+ print("Generating embeddings...")
251
+ embedder = ModelEmbedder()
252
+ embeddings = embedder.generate_embeddings(df['combined_text'].tolist())
253
+
254
+ print("Reducing dimensions...")
255
+ reducer = DimensionReducer()
256
+ reduced_embeddings = reducer.reduce_dimensions(embeddings, n_components=3, method='umap')
257
+ else:
258
+ reduced_embeddings = None
259
+
260
+ # Export
261
+ export_binary_dataset(df, reduced_embeddings, output_dir)
262
+ print("Done!")
263
+
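Reading the file back is the mirror image of the writer. A minimal reader sketch, assuming the corrected layout used above (64-byte little-endian header, 32-byte string-table entries, 17-byte records):

```python
import struct
from pathlib import Path

def read_binary_dataset(path: Path):
    """Parse an embeddings.bin written by export_binary_dataset."""
    with open(path, "rb") as f:
        # Header: magic, version, four counts, reserved fields, padding
        (magic, version, n_models, n_domains, n_licenses,
         n_families, _r1, _r2, _r3, _pad) = struct.unpack("<5sBIIIIBBH38s", f.read(64))
        assert magic == b"HFVIZ" and version == 1

        # Lookup tables: 32 bytes per entry, null-padded UTF-8
        domains = [f.read(32).rstrip(b"\x00").decode("utf-8") for _ in range(n_domains)]
        licenses = [f.read(32).rstrip(b"\x00").decode("utf-8") for _ in range(n_licenses)]

        # Records: f32 x/y/z, u8 domain, u8 license, u16 family, u8 flags
        records = [struct.unpack("<fffBBHB", f.read(17)) for _ in range(n_models)]

    return domains, licenses, records
```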
backend/services/model_tracker.py CHANGED
@@ -5,11 +5,16 @@ Tracks the number of models over time and provides historical data.
5
  import os
6
  import json
7
  import sqlite3
 
 
8
  from datetime import datetime, timedelta
9
  from typing import Dict, List, Optional, Tuple
10
  from huggingface_hub import HfApi
11
  import pandas as pd
12
  from pathlib import Path
 
 
 
13
 
14
 
15
  class ModelCountTracker:
@@ -34,7 +39,6 @@ class ModelCountTracker:
34
  conn = sqlite3.connect(self.db_path)
35
  cursor = conn.cursor()
36
 
37
- # Create table for model counts
38
  cursor.execute("""
39
  CREATE TABLE IF NOT EXISTS model_counts (
40
  id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -47,7 +51,6 @@ class ModelCountTracker:
47
  )
48
  """)
49
 
50
- # Create index for faster queries
51
  cursor.execute("""
52
  CREATE INDEX IF NOT EXISTS idx_timestamp
53
  ON model_counts(timestamp)
@@ -56,27 +59,90 @@ class ModelCountTracker:
56
  conn.commit()
57
  conn.close()
58
 
59
- def get_current_model_count(self) -> Dict:
60
  """
61
- Fetch current model count from Hugging Face Hub API.
62
- Uses efficient pagination to get accurate count.
 
63
 
64
  Returns:
65
- Dictionary with total count and breakdowns
66
  """
67
  try:
68
- # Use pagination to efficiently count models
69
- # The API returns paginated results, so we iterate through pages
70
- # For large counts, we sample and extrapolate for speed
 
71
 
72
  total_count = 0
73
  library_counts = {}
74
  pipeline_counts = {}
75
- page_size = 1000 # Process in batches
76
- max_pages = 100 # Limit to prevent timeout (can adjust)
77
- sample_size = 10000 # Sample size for breakdowns
78
 
79
- # Count total models efficiently
80
  models_iter = self.api.list_models(full=False)
81
  sampled_models = []
82
 
@@ -87,25 +153,18 @@ class ModelCountTracker:
87
  if i < sample_size:
88
  sampled_models.append(model)
89
 
90
- # Safety limit to prevent infinite loops
91
  if i >= max_pages * page_size:
92
- # If we hit the limit, estimate total from sample
93
- # This is a rough estimate - for exact count, increase max_pages
94
  break
95
 
96
- # Calculate breakdowns from sample (extrapolate if needed)
97
  for model in sampled_models:
98
- # Count by library
99
  if hasattr(model, 'library_name') and model.library_name:
100
  lib = model.library_name
101
  library_counts[lib] = library_counts.get(lib, 0) + 1
102
 
103
- # Count by pipeline
104
  if hasattr(model, 'pipeline_tag') and model.pipeline_tag:
105
  pipeline = model.pipeline_tag
106
  pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
107
 
108
- # If we sampled, scale up the breakdowns proportionally
109
  if len(sampled_models) < total_count and len(sampled_models) > 0:
110
  scale_factor = total_count / len(sampled_models)
111
  library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
@@ -118,7 +177,7 @@ class ModelCountTracker:
118
  "timestamp": datetime.utcnow().isoformat()
119
  }
120
  except Exception as e:
121
- print(f"Error fetching model count: {e}")
122
  return {
123
  "total_models": 0,
124
  "models_by_library": {},
@@ -162,7 +221,7 @@ class ModelCountTracker:
162
  conn.close()
163
  return True
164
  except Exception as e:
165
- print(f"Error recording count: {e}")
166
  return False
167
 
168
  def get_historical_counts(
@@ -211,7 +270,7 @@ class ModelCountTracker:
211
  conn.close()
212
  return results
213
  except Exception as e:
214
- print(f"Error fetching historical counts: {e}")
215
  return []
216
 
217
  def get_latest_count(self) -> Optional[Dict]:
@@ -239,7 +298,7 @@ class ModelCountTracker:
239
  }
240
  return None
241
  except Exception as e:
242
- print(f"Error fetching latest count: {e}")
243
  return None
244
 
245
  def get_growth_stats(self, days: int = 7) -> Dict:
 
5
  import os
6
  import json
7
  import sqlite3
8
+ import logging
9
+ import re
10
  from datetime import datetime, timedelta
11
  from typing import Dict, List, Optional, Tuple
12
  from huggingface_hub import HfApi
13
  import pandas as pd
14
  from pathlib import Path
15
+ import httpx
16
+
17
+ logger = logging.getLogger(__name__)
18
 
19
 
20
  class ModelCountTracker:
 
39
  conn = sqlite3.connect(self.db_path)
40
  cursor = conn.cursor()
41
 
 
42
  cursor.execute("""
43
  CREATE TABLE IF NOT EXISTS model_counts (
44
  id INTEGER PRIMARY KEY AUTOINCREMENT,
 
51
  )
52
  """)
53
 
 
54
  cursor.execute("""
55
  CREATE INDEX IF NOT EXISTS idx_timestamp
56
  ON model_counts(timestamp)
 
59
  conn.commit()
60
  conn.close()
61
 
62
+ def get_count_from_models_page(self) -> Optional[Dict]:
63
  """
64
+ Get model count by scraping the Hugging Face models page.
65
+ Extracts the count from window.__hf_deferred["numTotalItems"] in the page script,
66
+ falling back to the div with class "font-normal text-gray-400" on https://huggingface.co/models.
67
 
68
  Returns:
69
+ Dictionary with total_models count, or None if extraction fails
70
  """
71
  try:
72
+ url = "https://huggingface.co/models"
73
+ response = httpx.get(url, timeout=10.0, follow_redirects=True)
74
+ response.raise_for_status()
75
+
76
+ html_content = response.text
77
+
78
+ deferred_pattern = r'window\.__hf_deferred\["numTotalItems"\]\s*=\s*(\d+);'
79
+ deferred_matches = re.findall(deferred_pattern, html_content)
80
+
81
+ if deferred_matches:
82
+ total_models = int(deferred_matches[0])
83
+ logger.info(f"Extracted model count from window.__hf_deferred: {total_models}")
84
+
85
+ return {
86
+ "total_models": total_models,
87
+ "timestamp": datetime.utcnow().isoformat(),
88
+ "source": "hf_models_page",
89
+ "models_by_library": {},
90
+ "models_by_pipeline": {},
91
+ "models_by_author": {}
92
+ }
93
 
94
+ pattern = r'<div[^>]*class="[^"]*font-normal[^"]*text-gray-400[^"]*"[^>]*>([\d,]+)</div>'
95
+ matches = re.findall(pattern, html_content)
96
+
97
+ if matches:
98
+ count_str = matches[0].replace(',', '')
99
+ total_models = int(count_str)
100
+
101
+ logger.info(f"Extracted model count from div: {total_models}")
102
+
103
+ return {
104
+ "total_models": total_models,
105
+ "timestamp": datetime.utcnow().isoformat(),
106
+ "source": "hf_models_page",
107
+ "models_by_library": {},
108
+ "models_by_pipeline": {},
109
+ "models_by_author": {}
110
+ }
111
+
112
+ logger.warning("Could not find model count in HF models page HTML")
113
+ return None
114
+
115
+ except httpx.HTTPError as e:
116
+ logger.error(f"HTTP error fetching HF models page: {e}", exc_info=True)
117
+ return None
118
+ except Exception as e:
119
+ logger.error(f"Error extracting count from HF models page: {e}", exc_info=True)
120
+ return None
121
+
122
+ def get_current_model_count(self, use_models_page: bool = True) -> Dict:
123
+ """
124
+ Fetch current model count from Hugging Face Hub.
125
+ Uses multiple strategies: models page scraping (fastest), then API enumeration.
126
+
127
+ Args:
128
+ use_models_page: Try to get count from HF models page first (default: True)
129
+
130
+ Returns:
131
+ Dictionary with total count and breakdowns
132
+ """
133
+ if use_models_page:
134
+ page_count = self.get_count_from_models_page()
135
+ if page_count:
136
+ return page_count
137
+
138
+ try:
139
  total_count = 0
140
  library_counts = {}
141
  pipeline_counts = {}
142
+ page_size = 1000
143
+ max_pages = 100
144
+ sample_size = 10000
145
 
 
146
  models_iter = self.api.list_models(full=False)
147
  sampled_models = []
148
 
 
153
  if i < sample_size:
154
  sampled_models.append(model)
155
 
 
156
  if i >= max_pages * page_size:
 
 
157
  break
158
 
 
159
  for model in sampled_models:
 
160
  if hasattr(model, 'library_name') and model.library_name:
161
  lib = model.library_name
162
  library_counts[lib] = library_counts.get(lib, 0) + 1
163
 
 
164
  if hasattr(model, 'pipeline_tag') and model.pipeline_tag:
165
  pipeline = model.pipeline_tag
166
  pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
167
 
 
168
  if len(sampled_models) < total_count and len(sampled_models) > 0:
169
  scale_factor = total_count / len(sampled_models)
170
  library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
 
177
  "timestamp": datetime.utcnow().isoformat()
178
  }
179
  except Exception as e:
180
+ logger.error(f"Error fetching model count: {e}", exc_info=True)
181
  return {
182
  "total_models": 0,
183
  "models_by_library": {},
 
221
  conn.close()
222
  return True
223
  except Exception as e:
224
+ logger.error(f"Error recording count: {e}", exc_info=True)
225
  return False
226
 
227
  def get_historical_counts(
 
270
  conn.close()
271
  return results
272
  except Exception as e:
273
+ logger.error(f"Error fetching historical counts: {e}", exc_info=True)
274
  return []
275
 
276
  def get_latest_count(self) -> Optional[Dict]:
 
298
  }
299
  return None
300
  except Exception as e:
301
+ logger.error(f"Error fetching latest count: {e}", exc_info=True)
302
  return None
303
 
304
  def get_growth_stats(self, days: int = 7) -> Dict:
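The scraping path hinges on a single regex against the page source. A self-contained check against the kind of snippet the page embeds (the count is an illustrative value):

```python
import re

html = 'window.__hf_deferred["numTotalItems"] = 2249312;'  # illustrative page snippet
pattern = r'window\.__hf_deferred\["numTotalItems"\]\s*=\s*(\d+);'
match = re.search(pattern, html)
if match:
    print(int(match.group(1)))  # 2249312
```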
backend/services/model_tracker_improved.py CHANGED
@@ -11,12 +11,17 @@ Key improvements:
11
  import os
12
  import json
13
  import sqlite3
 
 
14
  from datetime import datetime, timedelta
15
  from typing import Dict, List, Optional, Tuple
16
  from huggingface_hub import HfApi
17
  import pandas as pd
18
  from pathlib import Path
19
  import time
 
 
 
20
 
21
 
22
  class ImprovedModelCountTracker:
@@ -78,72 +83,73 @@ class ImprovedModelCountTracker:
78
  elapsed = (datetime.utcnow() - self._cache_timestamp).total_seconds()
79
  return elapsed < self.cache_ttl
80
 
81
- def get_current_model_count(self, use_cache: bool = True, force_refresh: bool = False) -> Dict:
82
  """
83
- Fetch current model count from Hugging Face Hub API.
84
- Uses caching and efficient sampling strategies.
85
 
86
  Args:
87
  use_cache: Whether to use cached results if available
88
  force_refresh: Force refresh even if cache is valid
 
89
 
90
  Returns:
91
  Dictionary with total count and breakdowns
92
  """
93
- # Check cache first
94
  if use_cache and not force_refresh and self._is_cache_valid():
95
  return self._cache
96
 
97
  try:
98
- # Strategy 1: Try to get count efficiently using pagination
99
- # The HfApi.list_models() returns an iterator, so we can count efficiently
100
  total_count = 0
101
  library_counts = {}
102
  pipeline_counts = {}
103
  author_counts = {}
104
 
105
- # For breakdowns, we sample a subset for efficiency
106
- sample_size = 20000 # Sample 20K models for breakdowns
107
- max_count_for_full_breakdown = 50000 # If less than this, do full breakdown
108
 
109
  models_iter = self.api.list_models(full=False, sort="created", direction=-1)
110
  sampled_models = []
111
 
112
  start_time = time.time()
113
- timeout_seconds = 30 # Don't spend more than 30 seconds
114
 
115
  for i, model in enumerate(models_iter):
116
- # Check timeout
117
  if time.time() - start_time > timeout_seconds:
118
- # If we hit timeout, use sampling strategy
119
  break
120
 
121
  total_count += 1
122
 
123
- # Sample models for breakdowns
124
  if i < sample_size:
125
  sampled_models.append(model)
126
 
127
- # For smaller datasets, we can do full breakdown
128
  if total_count < max_count_for_full_breakdown:
129
- # Count by library
130
  if hasattr(model, 'library_name') and model.library_name:
131
  lib = model.library_name
132
  library_counts[lib] = library_counts.get(lib, 0) + 1
133
 
134
- # Count by pipeline
135
  if hasattr(model, 'pipeline_tag') and model.pipeline_tag:
136
  pipeline = model.pipeline_tag
137
  pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
138
 
139
- # Count by author (extract from model_id)
140
  if hasattr(model, 'id') and model.id:
141
  author = model.id.split('/')[0] if '/' in model.id else 'unknown'
142
  author_counts[author] = author_counts.get(author, 0) + 1
143
 
144
- # If we sampled, calculate breakdowns from sample and extrapolate
145
  if total_count > len(sampled_models) and len(sampled_models) > 0:
146
- # Calculate breakdowns from sample
147
  for model in sampled_models:
148
  if hasattr(model, 'library_name') and model.library_name:
149
  lib = model.library_name
@@ -157,7 +163,6 @@ class ImprovedModelCountTracker:
157
  author = model.id.split('/')[0] if '/' in model.id else 'unknown'
158
  author_counts[author] = author_counts.get(author, 0) + 1
159
 
160
- # Scale up breakdowns proportionally
161
  if len(sampled_models) > 0:
162
  scale_factor = total_count / len(sampled_models)
163
  library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
@@ -168,20 +173,19 @@ class ImprovedModelCountTracker:
168
  "total_models": total_count,
169
  "models_by_library": library_counts,
170
  "models_by_pipeline": pipeline_counts,
171
- "models_by_author": dict(sorted(author_counts.items(), key=lambda x: x[1], reverse=True)[:20]), # Top 20 authors
172
  "timestamp": datetime.utcnow().isoformat(),
173
  "sampling_used": total_count > len(sampled_models) if sampled_models else False,
174
  "sample_size": len(sampled_models) if sampled_models else total_count
175
  }
176
 
177
- # Update cache
178
  self._cache = result
179
  self._cache_timestamp = datetime.utcnow()
180
 
181
  return result
182
 
183
  except Exception as e:
184
- print(f"Error fetching model count: {e}")
185
  return {
186
  "total_models": 0,
187
  "models_by_library": {},
@@ -191,6 +195,70 @@ class ImprovedModelCountTracker:
191
  "error": str(e)
192
  }
193
 
194
  def get_count_from_dataset_snapshot(self, dataset_name: str = "modelbiome/ai_ecosystem_withmodelcards") -> Optional[Dict]:
195
  """
196
  Alternative method: Get count from dataset snapshot (like ai-ecosystem repo does).
@@ -205,11 +273,9 @@ class ImprovedModelCountTracker:
205
  try:
206
  from datasets import load_dataset
207
 
208
- # Load just metadata to get count quickly
209
  dataset = load_dataset(dataset_name, split="train")
210
  total_count = len(dataset)
211
 
212
- # Sample for breakdowns
213
  sample_size = min(10000, total_count)
214
  sample = dataset.shuffle(seed=42).select(range(sample_size))
215
 
@@ -225,7 +291,6 @@ class ImprovedModelCountTracker:
225
  pipeline = item['pipeline_tag']
226
  pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
227
 
228
- # Scale up
229
  if sample_size < total_count:
230
  scale_factor = total_count / sample_size
231
  library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
@@ -239,7 +304,7 @@ class ImprovedModelCountTracker:
239
  "source": "dataset_snapshot"
240
  }
241
  except Exception as e:
242
- print(f"Error loading from dataset snapshot: {e}")
243
  return None
244
 
245
  def record_count(self, count_data: Optional[Dict] = None, source: str = "api") -> bool:
@@ -279,7 +344,7 @@ class ImprovedModelCountTracker:
279
  conn.close()
280
  return True
281
  except Exception as e:
282
- print(f"Error recording count: {e}")
283
  return False
284
 
285
  def get_historical_counts(
@@ -329,7 +394,7 @@ class ImprovedModelCountTracker:
329
  conn.close()
330
  return results
331
  except Exception as e:
332
- print(f"Error fetching historical counts: {e}")
333
  return []
334
 
335
  def get_latest_count(self) -> Optional[Dict]:
@@ -358,7 +423,7 @@ class ImprovedModelCountTracker:
358
  }
359
  return None
360
  except Exception as e:
361
- print(f"Error fetching latest count: {e}")
362
  return None
363
 
364
  def get_growth_stats(self, days: int = 7) -> Dict:
 
11
  import os
12
  import json
13
  import sqlite3
14
+ import logging
15
+ import re
16
  from datetime import datetime, timedelta
17
  from typing import Dict, List, Optional, Tuple
18
  from huggingface_hub import HfApi
19
  import pandas as pd
20
  from pathlib import Path
21
  import time
22
+ import httpx
23
+
24
+ logger = logging.getLogger(__name__)
25
 
26
 
27
  class ImprovedModelCountTracker:
 
83
  elapsed = (datetime.utcnow() - self._cache_timestamp).total_seconds()
84
  return elapsed < self.cache_ttl
85
 
86
+ def get_current_model_count(self, use_cache: bool = True, force_refresh: bool = False, use_models_page: bool = True) -> Dict:
87
  """
88
+ Fetch current model count from Hugging Face Hub.
89
+ Uses multiple strategies: models page scraping (fastest), API, or dataset snapshot.
90
 
91
  Args:
92
  use_cache: Whether to use cached results if available
93
  force_refresh: Force refresh even if cache is valid
94
+ use_models_page: Try to get count from HF models page first (default: True)
95
 
96
  Returns:
97
  Dictionary with total count and breakdowns
98
  """
 
99
  if use_cache and not force_refresh and self._is_cache_valid():
100
  return self._cache
101
 
102
+ if use_models_page:
103
+ page_count = self.get_count_from_models_page()
104
+ if page_count:
105
+ dataset_count = self.get_count_from_dataset_snapshot()
106
+ if dataset_count and dataset_count.get("models_by_library"):
107
+ page_count["models_by_library"] = dataset_count.get("models_by_library", {})
108
+ page_count["models_by_pipeline"] = dataset_count.get("models_by_pipeline", {})
109
+ page_count["models_by_author"] = dataset_count.get("models_by_author", {})
110
+
111
+ self._cache = page_count
112
+ self._cache_timestamp = datetime.utcnow()
113
+ return page_count
114
+
115
  try:
 
 
116
  total_count = 0
117
  library_counts = {}
118
  pipeline_counts = {}
119
  author_counts = {}
120
 
121
+ sample_size = 20000
122
+ max_count_for_full_breakdown = 50000
 
123
 
124
  models_iter = self.api.list_models(full=False, sort="created", direction=-1)
125
  sampled_models = []
126
 
127
  start_time = time.time()
128
+ timeout_seconds = 30
129
 
130
  for i, model in enumerate(models_iter):
 
131
  if time.time() - start_time > timeout_seconds:
 
132
  break
133
 
134
  total_count += 1
135
 
 
136
  if i < sample_size:
137
  sampled_models.append(model)
138
 
 
139
  if total_count < max_count_for_full_breakdown:
 
140
  if hasattr(model, 'library_name') and model.library_name:
141
  lib = model.library_name
142
  library_counts[lib] = library_counts.get(lib, 0) + 1
143
 
 
144
  if hasattr(model, 'pipeline_tag') and model.pipeline_tag:
145
  pipeline = model.pipeline_tag
146
  pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
147
 
 
148
  if hasattr(model, 'id') and model.id:
149
  author = model.id.split('/')[0] if '/' in model.id else 'unknown'
150
  author_counts[author] = author_counts.get(author, 0) + 1
151
 
 
152
  if total_count > len(sampled_models) and len(sampled_models) > 0:
 
153
  for model in sampled_models:
154
  if hasattr(model, 'library_name') and model.library_name:
155
  lib = model.library_name
 
163
  author = model.id.split('/')[0] if '/' in model.id else 'unknown'
164
  author_counts[author] = author_counts.get(author, 0) + 1
165
 
 
166
  if len(sampled_models) > 0:
167
  scale_factor = total_count / len(sampled_models)
168
  library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
 
173
  "total_models": total_count,
174
  "models_by_library": library_counts,
175
  "models_by_pipeline": pipeline_counts,
176
+ "models_by_author": dict(sorted(author_counts.items(), key=lambda x: x[1], reverse=True)[:20]),
177
  "timestamp": datetime.utcnow().isoformat(),
178
  "sampling_used": total_count > len(sampled_models) if sampled_models else False,
179
  "sample_size": len(sampled_models) if sampled_models else total_count
180
  }
181
 
 
182
  self._cache = result
183
  self._cache_timestamp = datetime.utcnow()
184
 
185
  return result
186
 
187
  except Exception as e:
188
+ logger.error(f"Error fetching model count: {e}", exc_info=True)
189
  return {
190
  "total_models": 0,
191
  "models_by_library": {},
 
195
  "error": str(e)
196
  }
197
 
198
+ def get_count_from_models_page(self) -> Optional[Dict]:
199
+ """
200
+ Get model count by scraping the Hugging Face models page.
201
+ Extracts count from the div with class "font-normal text-gray-400" on https://huggingface.co/models
202
+
203
+ Returns:
204
+ Dictionary with total_models count, or None if extraction fails
205
+ """
206
+ try:
207
+ url = "https://huggingface.co/models"
208
+ response = httpx.get(url, timeout=10.0, follow_redirects=True)
209
+ response.raise_for_status()
210
+
211
+ html_content = response.text
212
+
213
+ # Look for the pattern: <div class="font-normal text-gray-400">2,249,310</div>
214
+ # The number is in the format with commas
215
+ pattern = r'<div[^>]*class="[^"]*font-normal[^"]*text-gray-400[^"]*"[^>]*>([\d,]+)</div>'
216
+ matches = re.findall(pattern, html_content)
217
+
218
+ if matches:
219
+ # Take the first match and remove commas
220
+ count_str = matches[0].replace(',', '')
221
+ total_models = int(count_str)
222
+
223
+ logger.info(f"Extracted model count from HF models page: {total_models}")
224
+
225
+ return {
226
+ "total_models": total_models,
227
+ "timestamp": datetime.utcnow().isoformat(),
228
+ "source": "hf_models_page",
229
+ "models_by_library": {},
230
+ "models_by_pipeline": {},
231
+ "models_by_author": {}
232
+ }
233
+ else:
234
+ # Fallback: try to find the number in the window.__hf_deferred object
235
+ # The page has: window.__hf_deferred["numTotalItems"] = 2249312;
236
+ deferred_pattern = r'window\.__hf_deferred\["numTotalItems"\]\s*=\s*(\d+);'
237
+ deferred_matches = re.findall(deferred_pattern, html_content)
238
+
239
+ if deferred_matches:
240
+ total_models = int(deferred_matches[0])
241
+ logger.info(f"Extracted model count from window.__hf_deferred: {total_models}")
242
+
243
+ return {
244
+ "total_models": total_models,
245
+ "timestamp": datetime.utcnow().isoformat(),
246
+ "source": "hf_models_page_deferred",
247
+ "models_by_library": {},
248
+ "models_by_pipeline": {},
249
+ "models_by_author": {}
250
+ }
251
+
252
+ logger.warning("Could not find model count in HF models page HTML")
253
+ return None
254
+
255
+ except httpx.HTTPError as e:
256
+ logger.error(f"HTTP error fetching HF models page: {e}", exc_info=True)
257
+ return None
258
+ except Exception as e:
259
+ logger.error(f"Error extracting count from HF models page: {e}", exc_info=True)
260
+ return None
261
+
262
  def get_count_from_dataset_snapshot(self, dataset_name: str = "modelbiome/ai_ecosystem_withmodelcards") -> Optional[Dict]:
263
  """
264
  Alternative method: Get count from dataset snapshot (like ai-ecosystem repo does).
 
273
  try:
274
  from datasets import load_dataset
275
 
 
276
  dataset = load_dataset(dataset_name, split="train")
277
  total_count = len(dataset)
278
 
 
279
  sample_size = min(10000, total_count)
280
  sample = dataset.shuffle(seed=42).select(range(sample_size))
281
 
 
291
  pipeline = item['pipeline_tag']
292
  pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1
293
 
 
294
  if sample_size < total_count:
295
  scale_factor = total_count / sample_size
296
  library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()}
 
304
  "source": "dataset_snapshot"
305
  }
306
  except Exception as e:
307
+ logger.error(f"Error loading from dataset snapshot: {e}", exc_info=True)
308
  return None
309
 
310
  def record_count(self, count_data: Optional[Dict] = None, source: str = "api") -> bool:
 
344
  conn.close()
345
  return True
346
  except Exception as e:
347
+ logger.error(f"Error recording count: {e}", exc_info=True)
348
  return False
349
 
350
  def get_historical_counts(
 
394
  conn.close()
395
  return results
396
  except Exception as e:
397
+ logger.error(f"Error fetching historical counts: {e}", exc_info=True)
398
  return []
399
 
400
  def get_latest_count(self) -> Optional[Dict]:
 
423
  }
424
  return None
425
  except Exception as e:
426
+ logger.error(f"Error fetching latest count: {e}", exc_info=True)
427
  return None
428
 
429
  def get_growth_stats(self, days: int = 7) -> Dict:
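Taken together, a caller only touches `get_current_model_count`; the scrape, cache, and API enumeration are ordered internally. A usage sketch, assuming the constructor (defined outside this diff) needs no required arguments and the host has network access:

```python
from services.model_tracker_improved import ImprovedModelCountTracker

tracker = ImprovedModelCountTracker()

snapshot = tracker.get_current_model_count()   # page scrape first, then API fallback
print(snapshot["total_models"], snapshot.get("source"))

cached = tracker.get_current_model_count()     # within cache_ttl: returns the cached dict
```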
backend/utils/data_loader.py CHANGED
@@ -50,18 +50,16 @@ class ModelDataLoader:
50
  else:
51
  df = df.copy()
52
 
53
- # Fill NaN values
54
  text_fields = ['tags', 'pipeline_tag', 'library_name', 'modelCard']
55
  for field in text_fields:
56
  if field in df.columns:
57
  df[field] = df[field].fillna('')
58
 
59
- # Combine text fields for embedding
60
  df['combined_text'] = (
61
  df.get('tags', '').astype(str) + ' ' +
62
  df.get('pipeline_tag', '').astype(str) + ' ' +
63
  df.get('library_name', '').astype(str) + ' ' +
64
- df['modelCard'].astype(str).str[:500] # Limit modelCard to first 500 chars
65
  )
66
 
67
  return df
@@ -94,7 +92,6 @@ class ModelDataLoader:
94
  else:
95
  df = df.copy()
96
 
97
- # Optimized filtering with vectorized operations
98
  if min_downloads is not None:
99
  downloads_col = df.get('downloads', pd.Series([0] * len(df), index=df.index))
100
  df = df[downloads_col >= min_downloads]
 
50
  else:
51
  df = df.copy()
52
 
 
53
  text_fields = ['tags', 'pipeline_tag', 'library_name', 'modelCard']
54
  for field in text_fields:
55
  if field in df.columns:
56
  df[field] = df[field].fillna('')
57
 
 
58
  df['combined_text'] = (
59
  df.get('tags', '').astype(str) + ' ' +
60
  df.get('pipeline_tag', '').astype(str) + ' ' +
61
  df.get('library_name', '').astype(str) + ' ' +
62
+ df['modelCard'].astype(str).str[:500]
63
  )
64
 
65
  return df
 
92
  else:
93
  df = df.copy()
94
 
 
95
  if min_downloads is not None:
96
  downloads_col = df.get('downloads', pd.Series([0] * len(df), index=df.index))
97
  df = df[downloads_col >= min_downloads]
backend/utils/embeddings.py CHANGED
@@ -27,7 +27,7 @@ class ModelEmbedder:
27
  def generate_embeddings(
28
  self,
29
  texts: List[str],
30
- batch_size: int = 128, # Increased default batch size for speed
31
  show_progress: bool = True
32
  ) -> np.ndarray:
33
  """
 
27
  def generate_embeddings(
28
  self,
29
  texts: List[str],
30
+ batch_size: int = 128,
31
  show_progress: bool = True
32
  ) -> np.ndarray:
33
  """
backend/utils/family_tree.py ADDED
@@ -0,0 +1,66 @@
1
+ """Family tree utility functions."""
2
+ import pandas as pd
3
+ from typing import Dict
4
+
5
+ def calculate_family_depths(df: pd.DataFrame) -> Dict[str, int]:
6
+ """Calculate family depth for each model."""
7
+ depths = {}
8
+ computing = set()
9
+
10
+ def get_depth(model_id: str) -> int:
11
+ if model_id in depths:
12
+ return depths[model_id]
13
+ if model_id in computing:
14
+ depths[model_id] = 0
15
+ return 0
16
+
17
+ computing.add(model_id)
18
+
19
+ try:
20
+ if df.index.name == 'model_id':
21
+ row = df.loc[model_id]
22
+ else:
23
+ rows = df[df.get('model_id', '') == model_id]
24
+ if len(rows) == 0:
25
+ depths[model_id] = 0
26
+ computing.remove(model_id)
27
+ return 0
28
+ row = rows.iloc[0]
29
+
30
+ parent_id = row.get('parent_model')
31
+ if parent_id and pd.notna(parent_id):
32
+ parent_str = str(parent_id)
33
+ if parent_str != 'nan' and parent_str != '':
34
+ if df.index.name == 'model_id' and parent_str in df.index:
35
+ depth = get_depth(parent_str) + 1
36
+ elif df.index.name != 'model_id':
37
+ parent_rows = df[df.get('model_id', '') == parent_str]
38
+ if len(parent_rows) > 0:
39
+ depth = get_depth(parent_str) + 1
40
+ else:
41
+ depth = 0
42
+ else:
43
+ depth = 0
44
+ else:
45
+ depth = 0
46
+ else:
47
+ depth = 0
48
+ except (KeyError, IndexError):
49
+ depth = 0
50
+
51
+ depths[model_id] = depth
52
+ computing.remove(model_id)
53
+ return depth
54
+
55
+ if df.index.name == 'model_id':
56
+ for model_id in df.index:
57
+ if model_id not in depths:
58
+ get_depth(str(model_id))
59
+ else:
60
+ for _, row in df.iterrows():
61
+ model_id = str(row.get('model_id', ''))
62
+ if model_id and model_id not in depths:
63
+ get_depth(model_id)
64
+
65
+ return depths
66
+
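A toy run of the depth calculation: a chain `m1 <- m2 <- m3` yields depths 0, 1, 2, and the `computing` set assigns depth 0 on cycles instead of recursing forever:

```python
import pandas as pd
from utils.family_tree import calculate_family_depths

df = pd.DataFrame({
    "model_id": ["m1", "m2", "m3"],
    "parent_model": [None, "m1", "m2"],
})
print(calculate_family_depths(df))  # {'m1': 0, 'm2': 1, 'm3': 2}
```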
backend/utils/graph_embeddings.py ADDED
@@ -0,0 +1,177 @@
1
+ """
2
+ Graph-aware embeddings for hierarchical model relationships.
3
+ Uses Node2Vec to create embeddings that respect family tree structure.
4
+ """
5
+ import numpy as np
6
+ import pandas as pd
7
+ from typing import Dict, List, Optional, Tuple
8
+ import networkx as nx
9
+ import pickle
10
+ import os
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ try:
16
+ from node2vec import Node2Vec
17
+ NODE2VEC_AVAILABLE = True
18
+ except ImportError:
19
+ NODE2VEC_AVAILABLE = False
20
+ logger.warning("node2vec not available. Install with: pip install node2vec")
21
+
22
+
23
+ class GraphEmbedder:
24
+ """
25
+ Generate graph embeddings that respect hierarchical relationships.
26
+ Combines text embeddings with graph structure embeddings.
27
+ """
28
+
29
+ def __init__(self, dimensions: int = 128, walk_length: int = 30, num_walks: int = 200):
30
+ """
31
+ Initialize graph embedder.
32
+
33
+ Args:
34
+ dimensions: Embedding dimensions
35
+ walk_length: Length of random walks
36
+ num_walks: Number of walks per node
37
+ """
38
+ self.dimensions = dimensions
39
+ self.walk_length = walk_length
40
+ self.num_walks = num_walks
41
+ self.graph: Optional[nx.DiGraph] = None
42
+ self.embeddings: Optional[np.ndarray] = None
43
+ self.model = None  # fitted gensim model; left unannotated so a missing node2vec import cannot raise NameError
44
+
45
+ def build_family_graph(self, df: pd.DataFrame) -> nx.DiGraph:
46
+ """
47
+ Build directed graph from family relationships.
48
+
49
+ Args:
50
+ df: DataFrame with model_id and parent_model columns
51
+
52
+ Returns:
53
+ NetworkX DiGraph
54
+ """
55
+ graph = nx.DiGraph()
56
+
57
+ for idx, row in df.iterrows():
58
+ model_id = str(row.get('model_id', idx))
59
+ graph.add_node(model_id)
60
+
61
+ parent_id = row.get('parent_model')
62
+ if parent_id and pd.notna(parent_id):
63
+ parent_str = str(parent_id)
64
+ if parent_str != 'nan' and parent_str != '':
65
+ graph.add_edge(parent_str, model_id)
66
+
67
+ self.graph = graph
68
+ logger.info(f"Built graph with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges")
69
+ return graph
70
+
71
+ def generate_graph_embeddings(
72
+ self,
73
+ graph: Optional[nx.DiGraph] = None,
74
+ workers: int = 4
75
+ ) -> Dict[str, np.ndarray]:
76
+ """
77
+ Generate Node2Vec embeddings for graph nodes.
78
+
79
+ Args:
80
+ graph: NetworkX graph (uses self.graph if None)
81
+ workers: Number of parallel workers
82
+
83
+ Returns:
84
+ Dictionary mapping model_id to embedding vector
85
+ """
86
+ if not NODE2VEC_AVAILABLE:
87
+ logger.warning("Node2Vec not available, returning empty embeddings")
88
+ return {}
89
+
90
+ if graph is None:
91
+ graph = self.graph
92
+
93
+ if graph is None or graph.number_of_nodes() == 0:
94
+ logger.warning("No graph available for embedding generation")
95
+ return {}
96
+
97
+ try:
98
+ node2vec = Node2Vec(
99
+ graph,
100
+ dimensions=self.dimensions,
101
+ walk_length=self.walk_length,
102
+ num_walks=self.num_walks,
103
+ workers=workers
104
+ )
105
+
106
+ model = node2vec.fit(window=10, min_count=1, batch_words=4)
107
+ self.model = model
108
+
109
+ embeddings_dict = {}
110
+ for node in graph.nodes():
111
+ if node in model.wv:
112
+ embeddings_dict[node] = model.wv[node]
113
+
114
+ logger.info(f"Generated graph embeddings for {len(embeddings_dict)} nodes")
115
+ return embeddings_dict
116
+
117
+ except Exception as e:
118
+ logger.error(f"Error generating graph embeddings: {e}", exc_info=True)
119
+ return {}
120
+
121
+ def combine_embeddings(
122
+ self,
123
+ text_embeddings: np.ndarray,
124
+ graph_embeddings: Dict[str, np.ndarray],
125
+ model_ids: List[str],
126
+ text_weight: float = 0.7,
127
+ graph_weight: float = 0.3
128
+ ) -> np.ndarray:
129
+ """
130
+ Combine text and graph embeddings by concatenating the L2-normalized, weighted parts.
131
+
132
+ Args:
133
+ text_embeddings: Text-based embeddings (n_samples, text_dim)
134
+ graph_embeddings: Graph embeddings dictionary
135
+ model_ids: List of model IDs corresponding to text_embeddings
136
+ text_weight: Weight for text embeddings
137
+ graph_weight: Weight for graph embeddings
138
+
139
+ Returns:
140
+ Combined embeddings (n_samples, combined_dim)
141
+ """
142
+ if not graph_embeddings:
143
+ return text_embeddings
144
+
145
+ text_dim = text_embeddings.shape[1]
146
+ graph_dim = next(iter(graph_embeddings.values())).shape[0]
147
+
148
+ combined = np.zeros((len(model_ids), text_dim + graph_dim))
149
+
150
+ for i, model_id in enumerate(model_ids):
151
+ model_id_str = str(model_id)
152
+
153
+ text_emb = text_embeddings[i]
154
+ graph_emb = graph_embeddings.get(model_id_str, np.zeros(graph_dim))
155
+
156
+ normalized_text = text_emb / (np.linalg.norm(text_emb) + 1e-8)
157
+ normalized_graph = graph_emb / (np.linalg.norm(graph_emb) + 1e-8)
158
+
159
+ combined[i] = np.concatenate([
160
+ normalized_text * text_weight,
161
+ normalized_graph * graph_weight
162
+ ])
163
+
164
+ return combined
165
+
166
+ def save_embeddings(self, embeddings: Dict[str, np.ndarray], filepath: str):
167
+ """Save graph embeddings to disk."""
168
+ os.makedirs(os.path.dirname(filepath) if os.path.dirname(filepath) else '.', exist_ok=True)
169
+ with open(filepath, 'wb') as f:
170
+ pickle.dump(embeddings, f)
171
+
172
+ def load_embeddings(self, filepath: str) -> Dict[str, np.ndarray]:
173
+ """Load graph embeddings from disk."""
174
+ with open(filepath, 'rb') as f:
175
+ return pickle.load(f)
176
+
177
+
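End to end, `combine_embeddings` concatenates the normalized text and graph parts, so the output width is `text_dim + graph_dim` regardless of the weights. A small sketch with random vectors (no Node2Vec run required; missing graph entries fall back to zeros):

```python
import numpy as np
from utils.graph_embeddings import GraphEmbedder

rng = np.random.default_rng(42)
model_ids = ["org/a", "org/b", "org/c"]

text_embeddings = rng.normal(size=(3, 384))   # e.g. sentence-transformer vectors
graph_embeddings = {
    "org/a": rng.normal(size=128),
    "org/b": rng.normal(size=128),            # "org/c" intentionally missing
}

embedder = GraphEmbedder(dimensions=128)
combined = embedder.combine_embeddings(text_embeddings, graph_embeddings, model_ids)
print(combined.shape)  # (3, 512): 384 text dims + 128 graph dims
```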
backend/utils/network_analysis.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  Network analysis module inspired by Open Syllabus Project.
3
  Builds co-occurrence networks for models based on shared contexts.
 
4
  """
5
  import pandas as pd
6
  import numpy as np
@@ -8,12 +9,66 @@ from collections import Counter
8
  from itertools import combinations
9
  from typing import List, Dict, Tuple, Optional, Set
10
  import networkx as nx
11
 
12
 
13
  class ModelNetworkBuilder:
14
  """
15
  Build network graphs for models based on co-occurrence patterns.
16
  Similar to Open Syllabus approach of connecting texts that appear together.
 
17
  """
18
 
19
  def __init__(self, df: pd.DataFrame):
@@ -22,13 +77,13 @@ class ModelNetworkBuilder:
22
 
23
  Args:
24
  df: DataFrame with model data including model_id, library_name,
25
- pipeline_tag, tags, parent_model, downloads, likes
 
26
  """
27
  self.df = df.copy()
28
  if 'model_id' not in self.df.columns:
29
  raise ValueError("DataFrame must contain 'model_id' column")
30
 
31
- # Ensure model_id is index for fast lookups
32
  if self.df.index.name != 'model_id':
33
  if 'model_id' in self.df.columns:
34
  self.df.set_index('model_id', drop=False, inplace=True)
@@ -208,23 +263,41 @@ class ModelNetworkBuilder:
208
  def build_family_tree_network(
209
  self,
210
  root_model_id: str,
211
- max_depth: int = 5
 
 
212
  ) -> nx.DiGraph:
213
  """
214
- Build directed graph of model family tree.
215
 
216
  Args:
217
  root_model_id: Root model to start from
218
- max_depth: Maximum depth to traverse
 
 
 
219
 
220
  Returns:
221
- NetworkX DiGraph representing family tree
222
  """
223
  graph = nx.DiGraph()
224
  visited = set()
225
 
226
- def add_family(current_id: str, depth: int):
227
- if depth <= 0 or current_id in visited:
 
228
  return
229
  visited.add(current_id)
230
 
@@ -233,28 +306,98 @@ class ModelNetworkBuilder:
233
 
234
  row = self.df.loc[current_id]
235
 
236
- # Add node
237
  graph.add_node(str(current_id))
238
  graph.nodes[str(current_id)]['title'] = self._format_title(current_id)
239
  graph.nodes[str(current_id)]['freq'] = int(row.get('downloads', 0))
240
 
241
- # Add edge to parent
242
- parent_id = row.get('parent_model')
243
- if parent_id and pd.notna(parent_id) and str(parent_id) != 'nan':
244
- parent_id_str = str(parent_id)
245
- graph.add_edge(parent_id_str, str(current_id))
246
- add_family(parent_id_str, depth - 1)
247
 
248
- # Add edges to children
249
- children = self.df[self.df.get('parent_model', '') == current_id]
250
- for child_id, child_row in children.iterrows():
251
  if str(child_id) not in visited:
252
- graph.add_edge(str(current_id), str(child_id))
253
- add_family(str(child_id), depth - 1)
254
 
255
  add_family(root_model_id, max_depth)
256
  return graph
257
 
258
  def export_graphml(self, graph: nx.Graph, filename: str):
259
  """Export graph to GraphML format (like Open Syllabus)."""
260
  nx.write_graphml(graph, filename)
 
1
  """
2
  Network analysis module inspired by Open Syllabus Project.
3
  Builds co-occurrence networks for models based on shared contexts.
4
+ Supports multiple relationship types: finetune, quantized, adapter, merge.
5
  """
6
  import pandas as pd
7
  import numpy as np
 
9
  from itertools import combinations
10
  from typing import List, Dict, Tuple, Optional, Set
11
  import networkx as nx
12
+ import ast
13
+ from datetime import datetime
14
+
15
+
16
+ def _parse_parent_list(value) -> List[str]:
17
+ """
18
+ Parse parent model list from string/eval format.
19
+ Handles both string representations and actual lists.
20
+ """
21
+ if pd.isna(value) or value == '' or str(value) == 'nan':
22
+ return []
23
+
24
+ try:
25
+ if isinstance(value, str):
26
+ if value.startswith('[') or value.startswith('('):
27
+ parsed = ast.literal_eval(value)
28
+ else:
29
+ parsed = [value]
30
+ else:
31
+ parsed = value
32
+
33
+ if isinstance(parsed, list):
34
+ return [str(p) for p in parsed if p and str(p) != 'nan']
35
+ elif parsed:
36
+ return [str(parsed)]
37
+ else:
38
+ return []
39
+ except (ValueError, SyntaxError):
40
+ return []
41
+
42
+
43
+ def _get_all_parents(row: pd.Series) -> Dict[str, List[str]]:
44
+ """
45
+ Extract all parent types from a row.
46
+ Returns dict mapping relationship type to list of parent IDs.
47
+ """
48
+ parents = {}
49
+
50
+ parent_columns = {
51
+ 'parent_model': 'parent',
52
+ 'finetune_parent': 'finetune',
53
+ 'quantized_parent': 'quantized',
54
+ 'adapter_parent': 'adapter',
55
+ 'merge_parent': 'merge'
56
+ }
57
+
58
+ for col, rel_type in parent_columns.items():
59
+ if col in row:
60
+ parent_list = _parse_parent_list(row.get(col))
61
+ if parent_list:
62
+ parents[rel_type] = parent_list
63
+
64
+ return parents
65
 
66
 
67
  class ModelNetworkBuilder:
68
  """
69
  Build network graphs for models based on co-occurrence patterns.
70
  Similar to Open Syllabus approach of connecting texts that appear together.
71
+ Supports multiple relationship types: finetune, quantized, adapter, merge.
72
  """
73
 
74
  def __init__(self, df: pd.DataFrame):
 
77
 
78
  Args:
79
  df: DataFrame with model data including model_id, library_name,
80
+ pipeline_tag, tags, parent_model, finetune_parent, quantized_parent,
81
+ adapter_parent, merge_parent, downloads, likes, createdAt
82
  """
83
  self.df = df.copy()
84
  if 'model_id' not in self.df.columns:
85
  raise ValueError("DataFrame must contain 'model_id' column")
86
 
 
87
  if self.df.index.name != 'model_id':
88
  if 'model_id' in self.df.columns:
89
  self.df.set_index('model_id', drop=False, inplace=True)
 
263
  def build_family_tree_network(
264
  self,
265
  root_model_id: str,
266
+ max_depth: Optional[int] = 5,
267
+ include_edge_attributes: bool = True,
268
+ filter_edge_types: Optional[List[str]] = None
269
  ) -> nx.DiGraph:
270
  """
271
+ Build a directed graph of a model's family tree, with multiple relationship types.
272
 
273
  Args:
274
  root_model_id: Root model to start from
275
+ max_depth: Maximum depth to traverse. If None, traverses entire tree without limit.
276
+ include_edge_attributes: Whether to calculate edge attributes (change in likes, downloads, etc.)
277
+ filter_edge_types: List of edge types to include (e.g., ['finetune', 'quantized']).
278
+ If None, includes all types.
279
 
280
  Returns:
281
+ NetworkX DiGraph representing family tree with edge types and attributes
282
  """
283
  graph = nx.DiGraph()
284
  visited = set()
285
 
286
+ children_index: Dict[str, List[Tuple[str, str]]] = {}
287
+ for idx, row in self.df.iterrows():
288
+ model_id = str(row.get('model_id', idx))
289
+ all_parents = _get_all_parents(row)
290
+
291
+ for rel_type, parent_list in all_parents.items():
292
+ for parent_id in parent_list:
293
+ if parent_id not in children_index:
294
+ children_index[parent_id] = []
295
+ children_index[parent_id].append((model_id, rel_type))
296
+
297
+ def add_family(current_id: str, depth: Optional[int]):
298
+ if current_id in visited:
299
+ return
300
+ if depth is not None and depth <= 0:
301
  return
302
  visited.add(current_id)
303
 
 
306
 
307
  row = self.df.loc[current_id]
308
 
 
309
  graph.add_node(str(current_id))
310
  graph.nodes[str(current_id)]['title'] = self._format_title(current_id)
311
  graph.nodes[str(current_id)]['freq'] = int(row.get('downloads', 0))
312
+ graph.nodes[str(current_id)]['likes'] = int(row.get('likes', 0))
313
+ graph.nodes[str(current_id)]['downloads'] = int(row.get('downloads', 0))
314
+ graph.nodes[str(current_id)]['library'] = str(row.get('library_name', '')) if pd.notna(row.get('library_name')) else ''
315
+ graph.nodes[str(current_id)]['pipeline'] = str(row.get('pipeline_tag', '')) if pd.notna(row.get('pipeline_tag')) else ''
316
 
317
+ createdAt = row.get('createdAt')
318
+ if pd.notna(createdAt):
319
+ graph.nodes[str(current_id)]['createdAt'] = str(createdAt)
320
 
321
+ all_parents = _get_all_parents(row)
322
+ for rel_type, parent_list in all_parents.items():
323
+ if filter_edge_types and rel_type not in filter_edge_types:
324
+ continue
325
+
326
+ for parent_id in parent_list:
327
+ if parent_id in self.df.index:
328
+ if not graph.has_edge(parent_id, str(current_id)):
329
+ graph.add_edge(parent_id, str(current_id))
+ graph[parent_id][str(current_id)]['edge_types'] = [rel_type]
330
+ graph[parent_id][str(current_id)]['edge_type'] = rel_type
+ elif rel_type not in graph[parent_id][str(current_id)].get('edge_types', []):
+ graph[parent_id][str(current_id)]['edge_types'].append(rel_type)
331
+
332
+ next_depth = depth - 1 if depth is not None else None
333
+ add_family(parent_id, next_depth)
334
+
335
+ children = children_index.get(current_id, [])
336
+ for child_id, rel_type in children:
337
+ if filter_edge_types and rel_type not in filter_edge_types:
338
+ continue
339
+
340
  if str(child_id) not in visited:
341
+ if not graph.has_edge(str(current_id), child_id):
342
+ graph.add_edge(str(current_id), child_id)
343
+ graph[str(current_id)][child_id]['edge_types'] = [rel_type]
344
+ graph[str(current_id)][child_id]['edge_type'] = rel_type
345
+ else:
346
+ if rel_type not in graph[str(current_id)][child_id].get('edge_types', []):
347
+ graph[str(current_id)][child_id]['edge_types'].append(rel_type)
348
+
349
+ next_depth = depth - 1 if depth is not None else None
350
+ add_family(child_id, next_depth)
351
 
352
  add_family(root_model_id, max_depth)
353
+
354
+ if include_edge_attributes:
355
+ self._add_edge_attributes(graph)
356
+
357
  return graph
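Continuing the hypothetical two-model example from the constructor sketch, building and inspecting a tree looks like this (`max_depth=None` walks the whole family):

```python
# Build the family tree rooted at 'org/base', keeping only finetune edges.
tree = builder.build_family_tree_network(
    'org/base',
    max_depth=None,                  # None = no depth limit
    filter_edge_types=['finetune'],  # drop quantized/adapter/merge edges
)

for u, v, attrs in tree.edges(data=True):
    print(u, '->', v, attrs['edge_types'])  # org/base -> org/base-ft ['finetune']
```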
358
 
359
+ def _add_edge_attributes(self, graph: nx.DiGraph):
360
+ """
361
+ Add edge attributes like change in likes, downloads, time difference.
362
+ Similar to the notebook's edge attribute calculation.
363
+ """
364
+ for edge in graph.edges():
365
+ parent_model = edge[0]
366
+ model_id = edge[1]
367
+
368
+ if parent_model not in graph.nodes() or model_id not in graph.nodes():
369
+ continue
370
+
371
+ parent_likes = graph.nodes[parent_model].get('likes', 0)
372
+ model_likes = graph.nodes[model_id].get('likes', 0)
373
+ parent_downloads = graph.nodes[parent_model].get('downloads', 0)
374
+ model_downloads = graph.nodes[model_id].get('downloads', 0)
375
+
376
+ graph.edges[edge]['change_in_likes'] = model_likes - parent_likes
377
+ if parent_likes != 0:
378
+ graph.edges[edge]['percentage_change_in_likes'] = (model_likes - parent_likes) / parent_likes
379
+ else:
380
+ graph.edges[edge]['percentage_change_in_likes'] = np.nan
381
+
382
+ graph.edges[edge]['change_in_downloads'] = model_downloads - parent_downloads
383
+ if parent_downloads != 0:
384
+ graph.edges[edge]['percentage_change_in_downloads'] = (model_downloads - parent_downloads) / parent_downloads
385
+ else:
386
+ graph.edges[edge]['percentage_change_in_downloads'] = np.nan
387
+
388
+ parent_created = graph.nodes[parent_model].get('createdAt')
389
+ model_created = graph.nodes[model_id].get('createdAt')
390
+
391
+ if parent_created and model_created:
392
+ try:
393
+ parent_dt = datetime.strptime(str(parent_created), '%Y-%m-%dT%H:%M:%S.%fZ')
394
+ model_dt = datetime.strptime(str(model_created), '%Y-%m-%dT%H:%M:%S.%fZ')
395
+ graph.edges[edge]['change_in_createdAt_days'] = (model_dt - parent_dt).days
396
+ except (ValueError, TypeError):
397
+ graph.edges[edge]['change_in_createdAt_days'] = np.nan
398
+ else:
399
+ graph.edges[edge]['change_in_createdAt_days'] = np.nan
400
+
401
  def export_graphml(self, graph: nx.Graph, filename: str):
402
  """Export graph to GraphML format (like Open Syllabus)."""
403
  nx.write_graphml(graph, filename)
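Closing out the same hypothetical example: the attributes computed by `_add_edge_attributes` are plain child-minus-parent deltas (ratios are left as `np.nan` when the parent value is zero), and the finished graph can be exported for tools such as Gephi:

```python
# Inspect the derived attributes on the parent -> child edge.
attrs = tree.edges[('org/base', 'org/base-ft')]
print(attrs['change_in_likes'])             # 5 - 40 = -35
print(attrs['percentage_change_in_likes'])  # -35 / 40 = -0.875
print(attrs['change_in_createdAt_days'])    # 31

builder.export_graphml(tree, 'org_base_family.graphml')
```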
frontend/.npmrc CHANGED
@@ -1,2 +1,4 @@
1
  legacy-peer-deps=true
2
 
 
 
 
1
  legacy-peer-deps=true
2
 
3
+
4
+
frontend/package-lock.json CHANGED
@@ -32,7 +32,8 @@
32
  "react-dom": "^18.2.0",
33
  "react-scripts": "5.0.1",
34
  "three": "^0.160.1",
35
- "typescript": "^5.0.0"
 
36
  }
37
  },
38
  "node_modules/@alloc/quick-lru": {
 
32
  "react-dom": "^18.2.0",
33
  "react-scripts": "5.0.1",
34
  "three": "^0.160.1",
35
+ "typescript": "^5.0.0",
36
+ "zustand": "^5.0.8"
37
  }
38
  },
39
  "node_modules/@alloc/quick-lru": {
frontend/package.json CHANGED
@@ -28,7 +28,8 @@
28
  "react-dom": "^18.2.0",
29
  "react-scripts": "5.0.1",
30
  "three": "^0.160.1",
31
- "typescript": "^5.0.0"
 
32
  },
33
  "scripts": {
34
  "start": "react-scripts start",
 
28
  "react-dom": "^18.2.0",
29
  "react-scripts": "5.0.1",
30
  "three": "^0.160.1",
31
+ "typescript": "^5.0.0",
32
+ "zustand": "^5.0.8"
33
  },
34
  "scripts": {
35
  "start": "react-scripts start",
frontend/public/index.html CHANGED
@@ -10,7 +10,7 @@
10
  />
11
  <link rel="preconnect" href="https://fonts.googleapis.com">
12
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
13
- <link href="https://fonts.googleapis.com/css2?family=Vend+Sans:wght@300;400;500;600;700&display=swap" rel="stylesheet">
14
  <title>Anatomy of a Machine Learning Ecosystem: 2 Million Models on Hugging Face</title>
15
  </head>
16
  <body>
 
10
  />
11
  <link rel="preconnect" href="https://fonts.googleapis.com">
12
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
13
+ <link href="https://fonts.googleapis.com/css2?family=Instrument+Sans:wght@400;500;600;700&display=swap" rel="stylesheet">
14
  <title>Anatomy of a Machine Learning Ecosystem: 2 Million Models on Hugging Face</title>
15
  </head>
16
  <body>
frontend/src/App.css CHANGED
@@ -7,86 +7,24 @@
7
  }
8
 
9
  .App-header {
10
- background: linear-gradient(135deg, #1a237e 0%, #283593 20%, #3949ab 40%, #5e35b1 60%, #7b1fa2 80%, #6a1b9a 100%);
11
- background-size: 200% 200%;
12
- animation: gradientShift 20s ease infinite;
13
  color: #ffffff;
14
- padding: 3rem 2.5rem;
15
  text-align: center;
16
- border-bottom: 2px solid rgba(100, 181, 246, 0.3);
17
- box-shadow: 0 4px 20px rgba(0, 0, 0, 0.25), 0 2px 10px rgba(123, 31, 162, 0.3);
18
  position: relative;
19
- overflow: hidden;
20
  }
21
 
22
- .App-header::before {
23
- content: '';
24
- position: absolute;
25
- top: 0;
26
- left: 0;
27
- right: 0;
28
- bottom: 0;
29
- background:
30
- radial-gradient(circle at 20% 50%, rgba(100, 181, 246, 0.15) 0%, transparent 50%),
31
- radial-gradient(circle at 80% 80%, rgba(156, 39, 176, 0.1) 0%, transparent 50%),
32
- radial-gradient(circle at 40% 20%, rgba(33, 150, 243, 0.1) 0%, transparent 50%);
33
- pointer-events: none;
34
- animation: pulse 8s ease-in-out infinite;
35
- }
36
-
37
- .App-header::after {
38
- content: '';
39
- position: absolute;
40
- top: 0;
41
- left: 0;
42
- right: 0;
43
- bottom: 0;
44
- background-image:
45
- repeating-linear-gradient(
46
- 0deg,
47
- transparent,
48
- transparent 2px,
49
- rgba(255, 255, 255, 0.03) 2px,
50
- rgba(255, 255, 255, 0.03) 4px
51
- );
52
- pointer-events: none;
53
- opacity: 0.5;
54
- }
55
-
56
- @keyframes gradientShift {
57
- 0% {
58
- background-position: 0% 50%;
59
- }
60
- 50% {
61
- background-position: 100% 50%;
62
- }
63
- 100% {
64
- background-position: 0% 50%;
65
- }
66
- }
67
-
68
- @keyframes pulse {
69
- 0%, 100% {
70
- opacity: 1;
71
- }
72
- 50% {
73
- opacity: 0.8;
74
- }
75
- }
76
 
77
  .App-header h1 {
78
  margin: 0 0 1rem 0;
79
- font-size: 2.25rem;
80
- font-weight: 700;
81
- letter-spacing: -0.02em;
82
- line-height: 1.2;
83
- position: relative;
84
- z-index: 1;
85
- text-shadow: 0 2px 8px rgba(0, 0, 0, 0.4), 0 4px 16px rgba(123, 31, 162, 0.3);
86
- background: linear-gradient(180deg, #ffffff 0%, #e1bee7 100%);
87
- -webkit-background-clip: text;
88
- -webkit-text-fill-color: transparent;
89
- background-clip: text;
90
  }
91
 
92
  .App-header p {
@@ -122,23 +60,17 @@
122
  }
123
 
124
  .stats span {
125
- padding: 0.75rem 1.5rem;
126
- background: rgba(255, 255, 255, 0.15);
127
- border-radius: 12px;
128
- backdrop-filter: blur(20px);
129
- -webkit-backdrop-filter: blur(20px);
130
- border: 2px solid rgba(255, 255, 255, 0.25);
131
- transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
132
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1), inset 0 1px 0 rgba(255, 255, 255, 0.3);
133
- font-weight: 600;
134
- letter-spacing: 0.02em;
135
  }
136
 
137
  .stats span:hover {
138
- background: rgba(255, 255, 255, 0.25);
139
- transform: translateY(-2px) scale(1.05);
140
- box-shadow: 0 6px 20px rgba(0, 0, 0, 0.15), inset 0 1px 0 rgba(255, 255, 255, 0.4);
141
- border-color: rgba(255, 255, 255, 0.4);
142
  }
143
 
144
  .main-content {
@@ -149,10 +81,9 @@
149
  .sidebar {
150
  width: 340px;
151
  padding: 1.5rem;
152
- background: linear-gradient(to bottom, #fafafa 0%, #ffffff 100%);
153
  overflow-y: auto;
154
- border-right: 2px solid #e0e0e0;
155
- box-shadow: 2px 0 8px rgba(0, 0, 0, 0.05);
156
  }
157
 
158
  .sidebar h2 {
@@ -164,12 +95,11 @@
164
  }
165
 
166
  .sidebar h3 {
167
- font-size: 0.95rem;
168
- font-weight: 700;
169
- color: #5e35b1;
170
- margin: 0 0 1rem 0;
171
  letter-spacing: -0.01em;
172
- text-transform: none;
173
  }
174
 
175
  .sidebar label {
@@ -202,9 +132,8 @@
202
  .sidebar input[type="text"]:focus,
203
  .sidebar select:focus {
204
  outline: none;
205
- border-color: #5e35b1;
206
- box-shadow: 0 0 0 3px rgba(94, 53, 177, 0.12), 0 2px 6px rgba(0, 0, 0, 0.1);
207
- transform: translateY(-1px);
208
  }
209
 
210
  .sidebar input[type="range"] {
@@ -227,20 +156,20 @@
227
  .sidebar input[type="range"]::-webkit-slider-thumb {
228
  -webkit-appearance: none;
229
  appearance: none;
230
- width: 20px;
231
- height: 20px;
232
  border-radius: 50%;
233
- background: linear-gradient(135deg, #5e35b1 0%, #7b1fa2 100%);
234
  cursor: pointer;
235
- box-shadow: 0 2px 6px rgba(94, 53, 177, 0.3), 0 4px 12px rgba(94, 53, 177, 0.2);
236
- transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
237
- border: 3px solid #ffffff;
238
  }
239
 
240
  .sidebar input[type="range"]::-webkit-slider-thumb:hover {
241
- background: linear-gradient(135deg, #512da8 0%, #6a1b9a 100%);
242
- transform: scale(1.2);
243
- box-shadow: 0 3px 8px rgba(94, 53, 177, 0.4), 0 6px 16px rgba(94, 53, 177, 0.3);
244
  }
245
 
246
  .sidebar input[type="range"]::-webkit-slider-thumb:active {
@@ -248,20 +177,20 @@
248
  }
249
 
250
  .sidebar input[type="range"]::-moz-range-thumb {
251
- width: 20px;
252
- height: 20px;
253
  border-radius: 50%;
254
- background: linear-gradient(135deg, #5e35b1 0%, #7b1fa2 100%);
255
  cursor: pointer;
256
- border: 3px solid #ffffff;
257
- box-shadow: 0 2px 6px rgba(94, 53, 177, 0.3), 0 4px 12px rgba(94, 53, 177, 0.2);
258
- transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
259
  }
260
 
261
  .sidebar input[type="range"]::-moz-range-thumb:hover {
262
- background: linear-gradient(135deg, #512da8 0%, #6a1b9a 100%);
263
- transform: scale(1.2);
264
- box-shadow: 0 3px 8px rgba(94, 53, 177, 0.4), 0 6px 16px rgba(94, 53, 177, 0.3);
265
  }
266
 
267
  .sidebar input[type="range"]::-moz-range-thumb:active {
@@ -288,17 +217,16 @@
288
 
289
  .sidebar-section {
290
  background: #ffffff;
291
- border-radius: 8px;
292
  padding: 1.25rem;
293
- margin-bottom: 1.25rem;
294
  border: 1px solid #e0e0e0;
295
- box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
296
- transition: all 0.3s ease;
297
  }
298
 
299
  .sidebar-section:hover {
300
- box-shadow: 0 2px 8px rgba(0, 0, 0, 0.12);
301
  border-color: #d0d0d0;
 
302
  }
303
 
304
  .filter-chip {
@@ -380,22 +308,21 @@
380
  }
381
 
382
  .loading {
383
- color: #5e35b1;
384
  font-weight: 600;
385
- background: linear-gradient(135deg, #f5f3ff 0%, #ede7f6 100%);
386
- border: 2px solid #d1c4e9;
387
- box-shadow: 0 4px 12px rgba(94, 53, 177, 0.1);
388
  }
389
 
390
  .loading::after {
391
  content: '';
392
- width: 48px;
393
- height: 48px;
394
- border: 5px solid #e1bee7;
395
- border-top-color: #5e35b1;
396
- border-right-color: #7b1fa2;
397
  border-radius: 50%;
398
- animation: spin 0.8s cubic-bezier(0.68, -0.55, 0.265, 1.55) infinite;
399
  }
400
 
401
  @keyframes spin {
@@ -403,101 +330,62 @@
403
  }
404
 
405
  .error {
406
- color: #c62828;
407
- background: linear-gradient(135deg, #ffebee 0%, #ffcdd2 100%);
408
- border-radius: 12px;
409
- border: 2px solid #ef5350;
410
  max-width: 550px;
411
  margin: 0 auto;
412
- box-shadow: 0 4px 12px rgba(198, 40, 40, 0.15);
413
  font-weight: 500;
414
  }
415
 
416
- .error::before {
417
- content: '⚠️';
418
- font-size: 2.5rem;
419
- display: block;
420
- margin-bottom: 0.5rem;
421
- }
422
-
423
  .empty {
424
- color: #616161;
425
- background: linear-gradient(135deg, #fafafa 0%, #f5f5f5 100%);
426
- border-radius: 12px;
427
- border: 2px solid #e0e0e0;
428
  max-width: 550px;
429
  margin: 0 auto;
430
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.08);
431
  font-weight: 500;
432
  }
433
 
434
- .empty::before {
435
- content: 'πŸ”';
436
- font-size: 2.5rem;
437
- display: block;
438
- margin-bottom: 0.5rem;
439
- }
440
-
441
  .btn {
442
  padding: 0.625rem 1.25rem;
443
- border-radius: 6px;
444
  border: none;
445
  font-size: 0.9rem;
446
  font-weight: 600;
447
  cursor: pointer;
448
- transition: all 0.25s cubic-bezier(0.4, 0, 0.2, 1);
449
  font-family: 'Instrument Sans', sans-serif;
450
  display: inline-flex;
451
  align-items: center;
452
  justify-content: center;
453
  gap: 0.5rem;
454
- position: relative;
455
- overflow: hidden;
456
- }
457
-
458
- .btn::before {
459
- content: '';
460
- position: absolute;
461
- top: 50%;
462
- left: 50%;
463
- width: 0;
464
- height: 0;
465
- border-radius: 50%;
466
- background: rgba(255, 255, 255, 0.3);
467
- transform: translate(-50%, -50%);
468
- transition: width 0.6s, height 0.6s;
469
  }
470
 
471
- .btn:hover::before {
472
- width: 300px;
473
- height: 300px;
474
- }
475
 
476
  .btn-primary {
477
- background: linear-gradient(135deg, #5e35b1 0%, #7b1fa2 100%);
478
  color: white;
479
- box-shadow: 0 2px 4px rgba(94, 53, 177, 0.3);
480
  }
481
 
482
  .btn-primary:hover {
483
- background: linear-gradient(135deg, #512da8 0%, #6a1b9a 100%);
484
- transform: translateY(-2px);
485
- box-shadow: 0 4px 12px rgba(94, 53, 177, 0.4);
486
  }
487
 
488
  .btn-secondary {
489
  background: #f5f5f5;
490
- color: #1a1a1a;
491
- border: 2px solid #e0e0e0;
492
- box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
493
  }
494
 
495
  .btn-secondary:hover {
496
- background: #ffffff;
497
- border-color: #5e35b1;
498
- color: #5e35b1;
499
- transform: translateY(-1px);
500
- box-shadow: 0 2px 6px rgba(0, 0, 0, 0.12);
501
  }
502
 
503
  .btn-small {
@@ -585,7 +473,7 @@
585
  --text-primary: #ffffff;
586
  --text-secondary: #cccccc;
587
  --border-color: #444444;
588
- --accent-color: #64b5f6;
589
  }
590
 
591
  [data-theme="light"] {
@@ -595,32 +483,31 @@
595
  --text-primary: #1a1a1a;
596
  --text-secondary: #666666;
597
  --border-color: #d0d0d0;
598
- --accent-color: #1976d2;
599
  }
600
 
601
  /* Random Model Button */
602
  .random-model-btn {
603
  display: flex;
604
  align-items: center;
605
- gap: 0.5rem;
606
- padding: 0.5rem 1rem;
607
- background: var(--accent-color, #4a90e2);
608
  color: white;
609
  border: none;
610
  border-radius: 4px;
611
  cursor: pointer;
612
  font-size: 0.9rem;
613
  font-family: 'Instrument Sans', sans-serif;
614
- font-weight: 500;
615
  transition: all 0.2s;
616
  width: 100%;
617
- justify-content: center;
618
  }
619
 
620
  .random-model-btn:hover:not(:disabled) {
621
- background: var(--accent-color, #357abd);
622
  transform: translateY(-1px);
623
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
624
  }
625
 
626
  .random-model-btn:disabled {
@@ -628,10 +515,6 @@
628
  cursor: not-allowed;
629
  }
630
 
631
- .random-icon {
632
- font-size: 1.1rem;
633
- }
634
-
635
  /* Zoom Slider */
636
  .zoom-slider-container {
637
  margin-bottom: 1rem;
@@ -859,7 +742,7 @@
859
  width: 18px;
860
  height: 18px;
861
  cursor: pointer;
862
- accent-color: #5e35b1;
863
  margin-right: 0.5rem;
864
  }
865
 
 
7
  }
8
 
9
  .App-header {
10
+ background: #2d2d2d;
 
 
11
  color: #ffffff;
12
+ padding: 2.5rem 2rem;
13
  text-align: center;
14
+ border-bottom: 1px solid #404040;
15
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
16
  position: relative;
 
17
  }
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  .App-header h1 {
21
  margin: 0 0 1rem 0;
22
+ font-size: 2rem;
23
+ font-weight: 600;
24
+ letter-spacing: -0.01em;
25
+ line-height: 1.3;
26
+ color: #ffffff;
27
+ text-shadow: 0 1px 3px rgba(0, 0, 0, 0.3);
28
  }
29
 
30
  .App-header p {
 
60
  }
61
 
62
  .stats span {
63
+ padding: 0.625rem 1.25rem;
64
+ background: rgba(255, 255, 255, 0.1);
65
+ border-radius: 6px;
66
+ border: 1px solid rgba(255, 255, 255, 0.2);
67
+ transition: all 0.2s ease;
68
+ font-weight: 500
69
  }
70
 
71
  .stats span:hover {
72
+ background: rgba(255, 255, 255, 0.15);
73
+ transform: translateY(-1px);
 
 
74
  }
75
 
76
  .main-content {
 
81
  .sidebar {
82
  width: 340px;
83
  padding: 1.5rem;
84
+ background: #fafafa;
85
  overflow-y: auto;
86
+ border-right: 1px solid #e0e0e0;
 
87
  }
88
 
89
  .sidebar h2 {
 
95
  }
96
 
97
  .sidebar h3 {
98
+ font-size: 0.9rem;
99
+ font-weight: 600;
100
+ color: #2d2d2d;
101
+ margin: 0 0 0.875rem 0;
102
  letter-spacing: -0.01em;
 
103
  }
104
 
105
  .sidebar label {
 
132
  .sidebar input[type="text"]:focus,
133
  .sidebar select:focus {
134
  outline: none;
135
+ border-color: #4a4a4a;
136
+ box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.08);
 
137
  }
138
 
139
  .sidebar input[type="range"] {
 
156
  .sidebar input[type="range"]::-webkit-slider-thumb {
157
  -webkit-appearance: none;
158
  appearance: none;
159
+ width: 18px;
160
+ height: 18px;
161
  border-radius: 50%;
162
+ background: #4a4a4a;
163
  cursor: pointer;
164
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
165
+ transition: all 0.2s ease;
166
+ border: 2px solid #ffffff;
167
  }
168
 
169
  .sidebar input[type="range"]::-webkit-slider-thumb:hover {
170
+ background: #2d2d2d;
171
+ transform: scale(1.1);
172
+ box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3);
173
  }
174
 
175
  .sidebar input[type="range"]::-webkit-slider-thumb:active {
 
177
  }
178
 
179
  .sidebar input[type="range"]::-moz-range-thumb {
180
+ width: 18px;
181
+ height: 18px;
182
  border-radius: 50%;
183
+ background: #4a4a4a;
184
  cursor: pointer;
185
+ border: 2px solid #ffffff;
186
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
187
+ transition: all 0.2s ease;
188
  }
189
 
190
  .sidebar input[type="range"]::-moz-range-thumb:hover {
191
+ background: #2d2d2d;
192
+ transform: scale(1.1);
193
+ box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3);
194
  }
195
 
196
  .sidebar input[type="range"]::-moz-range-thumb:active {
 
217
 
218
  .sidebar-section {
219
  background: #ffffff;
220
+ border-radius: 6px;
221
  padding: 1.25rem;
222
+ margin-bottom: 1rem;
223
  border: 1px solid #e0e0e0;
224
+ transition: all 0.2s ease;
 
225
  }
226
 
227
  .sidebar-section:hover {
 
228
  border-color: #d0d0d0;
229
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.05);
230
  }
231
 
232
  .filter-chip {
 
308
  }
309
 
310
  .loading {
311
+ color: #2d2d2d;
312
  font-weight: 600;
313
+ background: #f5f5f5;
314
+ border: 1px solid #d0d0d0;
315
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
316
  }
317
 
318
  .loading::after {
319
  content: '';
320
+ width: 40px;
321
+ height: 40px;
322
+ border: 4px solid #e0e0e0;
323
+ border-top-color: #4a4a4a;
 
324
  border-radius: 50%;
325
+ animation: spin 0.8s linear infinite;
326
  }
327
 
328
  @keyframes spin {
 
330
  }
331
 
332
  .error {
333
+ color: #d32f2f;
334
+ background: #ffebee;
335
+ border-radius: 8px;
336
+ border: 1px solid #ffcdd2;
337
  max-width: 550px;
338
  margin: 0 auto;
 
339
  font-weight: 500;
340
  }
341
 
342
  .empty {
343
+ color: #6a6a6a;
344
+ background: #f5f5f5;
345
+ border-radius: 8px;
346
+ border: 1px solid #e0e0e0;
347
  max-width: 550px;
348
  margin: 0 auto;
 
349
  font-weight: 500;
350
  }
351
 
352
  .btn {
353
  padding: 0.625rem 1.25rem;
354
+ border-radius: 4px;
355
  border: none;
356
  font-size: 0.9rem;
357
  font-weight: 600;
358
  cursor: pointer;
359
+ transition: all 0.2s ease;
360
  font-family: 'Instrument Sans', sans-serif;
361
  display: inline-flex;
362
  align-items: center;
363
  justify-content: center;
364
  gap: 0.5rem;
365
  }
366
 
367
 
368
  .btn-primary {
369
+ background: #2d2d2d;
370
  color: white;
371
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.15);
372
  }
373
 
374
  .btn-primary:hover {
375
+ background: #1a1a1a;
376
+ transform: translateY(-1px);
377
+ box-shadow: 0 3px 8px rgba(0, 0, 0, 0.2);
378
  }
379
 
380
  .btn-secondary {
381
  background: #f5f5f5;
382
+ color: #2d2d2d;
383
+ border: 1px solid #d0d0d0;
 
384
  }
385
 
386
  .btn-secondary:hover {
387
+ background: #e8e8e8;
388
+ border-color: #b0b0b0;
 
 
 
389
  }
390
 
391
  .btn-small {
 
473
  --text-primary: #ffffff;
474
  --text-secondary: #cccccc;
475
  --border-color: #444444;
476
+ --accent-color: #4a4a4a;
477
  }
478
 
479
  [data-theme="light"] {
 
483
  --text-primary: #1a1a1a;
484
  --text-secondary: #666666;
485
  --border-color: #d0d0d0;
486
+ --accent-color: #4a4a4a;
487
  }
488
 
489
  /* Random Model Button */
490
  .random-model-btn {
491
  display: flex;
492
  align-items: center;
493
+ justify-content: center;
494
+ padding: 0.625rem 1.25rem;
495
+ background: #2d2d2d;
496
  color: white;
497
  border: none;
498
  border-radius: 4px;
499
  cursor: pointer;
500
  font-size: 0.9rem;
501
  font-family: 'Instrument Sans', sans-serif;
502
+ font-weight: 600;
503
  transition: all 0.2s;
504
  width: 100%;
 
505
  }
506
 
507
  .random-model-btn:hover:not(:disabled) {
508
+ background: #1a1a1a;
509
  transform: translateY(-1px);
510
+ box-shadow: 0 2px 6px rgba(0, 0, 0, 0.2);
511
  }
512
 
513
  .random-model-btn:disabled {
 
515
  cursor: not-allowed;
516
  }
517
 
518
  /* Zoom Slider */
519
  .zoom-slider-container {
520
  margin-bottom: 1rem;
 
742
  width: 18px;
743
  height: 18px;
744
  cursor: pointer;
745
+ accent-color: #4a4a4a;
746
  margin-right: 0.5rem;
747
  }
748
 
frontend/src/App.tsx CHANGED
@@ -506,28 +506,24 @@ function App() {
506
  alignItems: 'center',
507
  marginBottom: '1.5rem',
508
  paddingBottom: '1rem',
509
- borderBottom: '2px solid #e8e8e8'
510
  }}>
511
  <h2 style={{
512
  margin: 0,
513
  fontSize: '1.5rem',
514
- fontWeight: '700',
515
- background: 'linear-gradient(135deg, #5e35b1 0%, #7b1fa2 100%)',
516
- WebkitBackgroundClip: 'text',
517
- WebkitTextFillColor: 'transparent',
518
- backgroundClip: 'text'
519
  }}>
520
  Filters & Controls
521
  </h2>
522
  {activeFilterCount > 0 && (
523
  <div style={{
524
  fontSize: '0.75rem',
525
- background: 'linear-gradient(135deg, #5e35b1 0%, #7b1fa2 100%)',
526
  color: 'white',
527
- padding: '0.4rem 0.75rem',
528
- borderRadius: '16px',
529
- fontWeight: '600',
530
- boxShadow: '0 2px 6px rgba(94, 53, 177, 0.3)'
531
  }}>
532
  {activeFilterCount} active
533
  </div>
@@ -537,40 +533,40 @@ function App() {
537
  {/* Filter Results Count */}
538
  {!loading && data.length > 0 && (
539
  <div className="sidebar-section" style={{
540
- background: 'linear-gradient(135deg, #f3e5f5 0%, #e1bee7 100%)',
541
- border: '2px solid #ce93d8',
542
  fontSize: '0.9rem',
543
  marginBottom: '1.5rem'
544
  }}>
545
  <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
546
  <div>
547
- <strong style={{ fontSize: '1.1rem', color: '#6a1b9a' }}>
548
  {data.length.toLocaleString()}
549
  </strong>
550
- <span style={{ marginLeft: '0.4rem', color: '#4a148c' }}>
551
  {data.length === 1 ? 'model' : 'models'}
552
  </span>
553
  </div>
554
  {embeddingType === 'graph-aware' && (
555
  <span style={{
556
  fontSize: '0.7rem',
557
- background: '#7b1fa2',
558
  color: 'white',
559
  padding: '0.3rem 0.6rem',
560
  borderRadius: '12px',
561
  fontWeight: '600'
562
  }}>
563
- 🌐 Graph
564
  </span>
565
  )}
566
  </div>
567
  {filteredCount !== null && filteredCount !== data.length && (
568
- <div style={{ fontSize: '0.8rem', color: '#6a1b9a', marginTop: '0.25rem' }}>
569
  of {filteredCount.toLocaleString()} matching
570
  </div>
571
  )}
572
  {stats && filteredCount !== null && filteredCount < stats.total_models && (
573
- <div style={{ fontSize: '0.75rem', color: '#8e24aa', marginTop: '0.25rem' }}>
574
  from {stats.total_models.toLocaleString()} total
575
  </div>
576
  )}
@@ -579,15 +575,7 @@ function App() {
579
 
580
  {/* Search Section */}
581
  <div className="sidebar-section">
582
- <h3 style={{
583
- display: 'flex',
584
- alignItems: 'center',
585
- gap: '0.5rem',
586
- color: '#5e35b1',
587
- marginBottom: '0.75rem'
588
- }}>
589
- πŸ” Search Models
590
- </h3>
591
  <input
592
  type="text"
593
  value={searchQuery}
@@ -602,14 +590,7 @@ function App() {
602
 
603
  {/* Popularity Filters */}
604
  <div className="sidebar-section">
605
- <h3 style={{
606
- display: 'flex',
607
- alignItems: 'center',
608
- gap: '0.5rem',
609
- color: '#5e35b1'
610
- }}>
611
- πŸ“Š Popularity Filters
612
- </h3>
613
 
614
  <label style={{ marginBottom: '1rem', display: 'block' }}>
615
  <div style={{ display: 'flex', justifyContent: 'space-between', marginBottom: '0.5rem' }}>
@@ -706,14 +687,7 @@ function App() {
706
 
707
  {/* Discovery */}
708
  <div className="sidebar-section">
709
- <h3 style={{
710
- display: 'flex',
711
- alignItems: 'center',
712
- gap: '0.5rem',
713
- color: '#5e35b1'
714
- }}>
715
- 🎲 Discovery
716
- </h3>
717
  <RandomModelButton
718
  data={data}
719
  onSelect={(model: ModelPoint) => {
@@ -727,15 +701,7 @@ function App() {
727
  {/* Visualization Options */}
728
  <div className="sidebar-section">
729
  <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '1rem' }}>
730
- <h3 style={{
731
- margin: 0,
732
- display: 'flex',
733
- alignItems: 'center',
734
- gap: '0.5rem',
735
- color: '#5e35b1'
736
- }}>
737
- 🎨 Visualization
738
- </h3>
739
  <ThemeToggle />
740
  </div>
741
 
@@ -862,10 +828,10 @@ function App() {
862
  </select>
863
  </label>
864
 
865
- <div className="sidebar-section" style={{ background: '#fff3cd', borderColor: '#ffc107', marginBottom: '1rem', padding: '0.75rem', borderRadius: '4px', border: '1px solid' }}>
866
  <label style={{ display: 'block', marginBottom: '0' }}>
867
- <span style={{ fontWeight: '600', display: 'block', marginBottom: '0.5rem', color: '#856404' }}>
868
- βš™οΈ Projection Method
869
  </span>
870
  <select
871
  value={projectionMethod}
@@ -875,7 +841,7 @@ function App() {
875
  <option value="umap">UMAP (better global structure)</option>
876
  <option value="tsne">t-SNE (better local clusters)</option>
877
  </select>
878
- <div style={{ fontSize: '0.75rem', color: '#856404', marginTop: '0.5rem', lineHeight: '1.4' }}>
879
  <strong>UMAP:</strong> Preserves global structure, better for exploring relationships<br/>
880
  <strong>t-SNE:</strong> Emphasizes local clusters, better for finding groups
881
  </div>
@@ -884,15 +850,8 @@ function App() {
884
  </div>
885
 
886
  {/* View Modes */}
887
- <div className="sidebar-section" style={{ background: 'linear-gradient(135deg, #f3e5f5 0%, #fce4ec 100%)', border: '2px solid #f48fb1' }}>
888
- <h3 style={{
889
- display: 'flex',
890
- alignItems: 'center',
891
- gap: '0.5rem',
892
- color: '#5e35b1'
893
- }}>
894
- ⚑ View Modes
895
- </h3>
896
 
897
  <label style={{ marginBottom: '1rem', display: 'flex', alignItems: 'center', cursor: 'pointer' }}>
898
  <input
@@ -937,7 +896,7 @@ function App() {
937
  style={{ marginRight: '0.5rem', cursor: 'pointer' }}
938
  />
939
  <div>
940
- <span style={{ fontWeight: '500' }}>🌐 Graph-Aware Embeddings</span>
941
  <div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
942
  Use embeddings that respect family tree structure. Models in the same family will be closer together.
943
  </div>
@@ -955,11 +914,11 @@ function App() {
955
  color: '#666'
956
  }}>
957
  <div style={{ display: 'flex', alignItems: 'center', gap: '0.5rem', marginBottom: '0.25rem' }}>
958
- <strong style={{ color: embeddingType === 'graph-aware' ? '#2e7d32' : '#666' }}>
959
- {embeddingType === 'graph-aware' ? '🌐 Graph-Aware' : 'πŸ“ Text-Only'} Embeddings
960
  </strong>
961
  </div>
962
- <div style={{ fontSize: '0.7rem', color: '#888', lineHeight: '1.4' }}>
963
  {embeddingType === 'graph-aware'
964
  ? 'Models in the same family tree are positioned closer together, revealing hierarchical relationships.'
965
  : 'Standard text-based embeddings showing semantic similarity from model descriptions and tags.'}
@@ -1006,15 +965,8 @@ function App() {
1006
 
1007
  {/* Structural Visualization Options */}
1008
  {viewMode === '3d' && (
1009
- <div className="sidebar-section" style={{ background: 'linear-gradient(135deg, #e8f5e9 0%, #f1f8e9 100%)', border: '2px solid #aed581' }}>
1010
- <h3 style={{
1011
- display: 'flex',
1012
- alignItems: 'center',
1013
- gap: '0.5rem',
1014
- color: '#5e35b1'
1015
- }}>
1016
- πŸ”— Network Structure
1017
- </h3>
1018
  <div style={{ fontSize: '0.75rem', color: '#666', marginBottom: '1rem', lineHeight: '1.4' }}>
1019
  Explore relationships and structure in the model ecosystem
1020
  </div>
@@ -1026,12 +978,12 @@ function App() {
1026
  onChange={(e) => setOverviewMode(e.target.checked)}
1027
  style={{ marginRight: '0.5rem', cursor: 'pointer' }}
1028
  />
1029
- <div>
1030
- <span style={{ fontWeight: '500' }}>πŸ” Overview Mode</span>
1031
- <div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
1032
- Zoom out to see full ecosystem structure with all relationships visible. Camera will automatically adjust.
1033
- </div>
1034
  </div>
 
1035
  </label>
1036
 
1037
  <label style={{ marginBottom: '1rem', display: 'flex', alignItems: 'center', cursor: 'pointer' }}>
@@ -1041,12 +993,12 @@ function App() {
1041
  onChange={(e) => setShowNetworkEdges(e.target.checked)}
1042
  style={{ marginRight: '0.5rem', cursor: 'pointer' }}
1043
  />
1044
- <div>
1045
- <span style={{ fontWeight: '500' }}>🌐 Network Relationships</span>
1046
- <div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
1047
- Show connections between related models (same library, pipeline, or tags). Blue = library, Pink = pipeline.
1048
- </div>
1049
  </div>
 
1050
  </label>
1051
 
1052
  {showNetworkEdges && (
@@ -1073,26 +1025,19 @@ function App() {
1073
  onChange={(e) => setShowStructuralGroups(e.target.checked)}
1074
  style={{ marginRight: '0.5rem', cursor: 'pointer' }}
1075
  />
1076
- <div>
1077
- <span style={{ fontWeight: '500' }}>πŸ“¦ Structural Groupings</span>
1078
- <div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
1079
- Highlight clusters and groups with wireframe boundaries. Shows top library and pipeline clusters.
1080
- </div>
1081
  </div>
 
1082
  </label>
1083
  </div>
1084
  )}
1085
 
1086
  {/* Quick Filters */}
1087
  <div className="sidebar-section">
1088
- <h3 style={{
1089
- display: 'flex',
1090
- alignItems: 'center',
1091
- gap: '0.5rem',
1092
- color: '#5e35b1'
1093
- }}>
1094
- ⚑ Quick Actions
1095
- </h3>
1096
  <div style={{ display: 'flex', flexWrap: 'wrap', gap: '0.5rem' }}>
1097
  <button
1098
  onClick={() => {
@@ -1137,14 +1082,7 @@ function App() {
1137
  </div>
1138
 
1139
  <div className="sidebar-section">
1140
- <h3 style={{
1141
- display: 'flex',
1142
- alignItems: 'center',
1143
- gap: '0.5rem',
1144
- color: '#5e35b1'
1145
- }}>
1146
- 🌳 Hierarchy Navigation
1147
- </h3>
1148
  <label style={{ marginBottom: '1rem', display: 'block' }}>
1149
  <span style={{ fontWeight: '500', display: 'block', marginBottom: '0.5rem' }}>
1150
  Max Hierarchy Depth
@@ -1224,14 +1162,7 @@ function App() {
1224
  </div>
1225
 
1226
  <div className="sidebar-section">
1227
- <h3 style={{
1228
- display: 'flex',
1229
- alignItems: 'center',
1230
- gap: '0.5rem',
1231
- color: '#5e35b1'
1232
- }}>
1233
- πŸ‘₯ Family Tree Explorer
1234
- </h3>
1235
  <div style={{ position: 'relative' }}>
1236
  <input
1237
  type="text"
 
506
  alignItems: 'center',
507
  marginBottom: '1.5rem',
508
  paddingBottom: '1rem',
509
+ borderBottom: '1px solid #e0e0e0'
510
  }}>
511
  <h2 style={{
512
  margin: 0,
513
  fontSize: '1.5rem',
514
+ fontWeight: '600',
515
+ color: '#2d2d2d'
516
  }}>
517
  Filters & Controls
518
  </h2>
519
  {activeFilterCount > 0 && (
520
  <div style={{
521
  fontSize: '0.75rem',
522
+ background: '#4a4a4a',
523
  color: 'white',
524
+ padding: '0.35rem 0.7rem',
525
+ borderRadius: '12px',
526
+ fontWeight: '600'
 
527
  }}>
528
  {activeFilterCount} active
529
  </div>
 
533
  {/* Filter Results Count */}
534
  {!loading && data.length > 0 && (
535
  <div className="sidebar-section" style={{
536
+ background: '#f5f5f5',
537
+ border: '1px solid #d0d0d0',
538
  fontSize: '0.9rem',
539
  marginBottom: '1.5rem'
540
  }}>
541
  <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
542
  <div>
543
+ <strong style={{ fontSize: '1.1rem', color: '#2d2d2d' }}>
544
  {data.length.toLocaleString()}
545
  </strong>
546
+ <span style={{ marginLeft: '0.4rem', color: '#4a4a4a' }}>
547
  {data.length === 1 ? 'model' : 'models'}
548
  </span>
549
  </div>
550
  {embeddingType === 'graph-aware' && (
551
  <span style={{
552
  fontSize: '0.7rem',
553
+ background: '#4a4a4a',
554
  color: 'white',
555
  padding: '0.3rem 0.6rem',
556
  borderRadius: '12px',
557
  fontWeight: '600'
558
  }}>
559
+ Graph
560
  </span>
561
  )}
562
  </div>
563
  {filteredCount !== null && filteredCount !== data.length && (
564
+ <div style={{ fontSize: '0.8rem', color: '#666', marginTop: '0.25rem' }}>
565
  of {filteredCount.toLocaleString()} matching
566
  </div>
567
  )}
568
  {stats && filteredCount !== null && filteredCount < stats.total_models && (
569
+ <div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
570
  from {stats.total_models.toLocaleString()} total
571
  </div>
572
  )}
 
575
 
576
  {/* Search Section */}
577
  <div className="sidebar-section">
578
+ <h3>Search Models</h3>
579
  <input
580
  type="text"
581
  value={searchQuery}
 
590
 
591
  {/* Popularity Filters */}
592
  <div className="sidebar-section">
593
+ <h3>Popularity Filters</h3>
594
 
595
  <label style={{ marginBottom: '1rem', display: 'block' }}>
596
  <div style={{ display: 'flex', justifyContent: 'space-between', marginBottom: '0.5rem' }}>
 
687
 
688
  {/* Discovery */}
689
  <div className="sidebar-section">
690
+ <h3>Discovery</h3>
691
  <RandomModelButton
692
  data={data}
693
  onSelect={(model: ModelPoint) => {
 
701
  {/* Visualization Options */}
702
  <div className="sidebar-section">
703
  <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '1rem' }}>
704
+ <h3 style={{ margin: 0 }}>Visualization Options</h3>
705
  <ThemeToggle />
706
  </div>
707
 
 
828
  </select>
829
  </label>
830
 
831
+ <div className="sidebar-section" style={{ background: '#f5f5f5', borderColor: '#d0d0d0', marginBottom: '1rem', padding: '0.75rem', borderRadius: '4px', border: '1px solid' }}>
832
  <label style={{ display: 'block', marginBottom: '0' }}>
833
+ <span style={{ fontWeight: '600', display: 'block', marginBottom: '0.5rem', color: '#2d2d2d' }}>
834
+ Projection Method
835
  </span>
836
  <select
837
  value={projectionMethod}
 
841
  <option value="umap">UMAP (better global structure)</option>
842
  <option value="tsne">t-SNE (better local clusters)</option>
843
  </select>
844
+ <div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.5rem', lineHeight: '1.4' }}>
845
  <strong>UMAP:</strong> Preserves global structure, better for exploring relationships<br/>
846
  <strong>t-SNE:</strong> Emphasizes local clusters, better for finding groups
847
  </div>
 
850
  </div>
851
 
852
  {/* View Modes */}
853
+ <div className="sidebar-section">
854
+ <h3>View Modes</h3>
855
 
856
  <label style={{ marginBottom: '1rem', display: 'flex', alignItems: 'center', cursor: 'pointer' }}>
857
  <input
 
896
  style={{ marginRight: '0.5rem', cursor: 'pointer' }}
897
  />
898
  <div>
899
+ <span style={{ fontWeight: '500' }}>Graph-Aware Embeddings</span>
900
  <div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
901
  Use embeddings that respect family tree structure. Models in the same family will be closer together.
902
  </div>
 
914
  color: '#666'
915
  }}>
916
  <div style={{ display: 'flex', alignItems: 'center', gap: '0.5rem', marginBottom: '0.25rem' }}>
917
+ <strong style={{ color: '#2d2d2d' }}>
918
+ {embeddingType === 'graph-aware' ? 'Graph-Aware' : 'Text-Only'} Embeddings
919
  </strong>
920
  </div>
921
+ <div style={{ fontSize: '0.7rem', color: '#666', lineHeight: '1.4' }}>
922
  {embeddingType === 'graph-aware'
923
  ? 'Models in the same family tree are positioned closer together, revealing hierarchical relationships.'
924
  : 'Standard text-based embeddings showing semantic similarity from model descriptions and tags.'}
 
965
 
966
  {/* Structural Visualization Options */}
967
  {viewMode === '3d' && (
968
+ <div className="sidebar-section">
969
+ <h3>Network Structure</h3>
970
  <div style={{ fontSize: '0.75rem', color: '#666', marginBottom: '1rem', lineHeight: '1.4' }}>
971
  Explore relationships and structure in the model ecosystem
972
  </div>
 
978
  onChange={(e) => setOverviewMode(e.target.checked)}
979
  style={{ marginRight: '0.5rem', cursor: 'pointer' }}
980
  />
981
+ <div>
982
+ <span style={{ fontWeight: '500' }}>Overview Mode</span>
983
+ <div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
984
+ Zoom out to see full ecosystem structure with all relationships visible. Camera will automatically adjust.
 
985
  </div>
986
+ </div>
987
  </label>
988
 
989
  <label style={{ marginBottom: '1rem', display: 'flex', alignItems: 'center', cursor: 'pointer' }}>
 
993
  onChange={(e) => setShowNetworkEdges(e.target.checked)}
994
  style={{ marginRight: '0.5rem', cursor: 'pointer' }}
995
  />
996
+ <div>
997
+ <span style={{ fontWeight: '500' }}>Network Relationships</span>
998
+ <div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
999
+ Show connections between related models (same library, pipeline, or tags). Blue = library, Pink = pipeline.
 
1000
  </div>
1001
+ </div>
1002
  </label>
1003
 
1004
  {showNetworkEdges && (
 
1025
  onChange={(e) => setShowStructuralGroups(e.target.checked)}
1026
  style={{ marginRight: '0.5rem', cursor: 'pointer' }}
1027
  />
1028
+ <div>
1029
+ <span style={{ fontWeight: '500' }}>Structural Groupings</span>
1030
+ <div style={{ fontSize: '0.75rem', color: '#666', marginTop: '0.25rem' }}>
1031
+ Highlight clusters and groups with wireframe boundaries. Shows top library and pipeline clusters.
 
1032
  </div>
1033
+ </div>
1034
  </label>
1035
  </div>
1036
  )}
1037
 
1038
  {/* Quick Filters */}
1039
  <div className="sidebar-section">
1040
+ <h3>Quick Actions</h3>
1041
  <div style={{ display: 'flex', flexWrap: 'wrap', gap: '0.5rem' }}>
1042
  <button
1043
  onClick={() => {
 
1082
  </div>
1083
 
1084
  <div className="sidebar-section">
1085
+ <h3>Hierarchy Navigation</h3>
1086
  <label style={{ marginBottom: '1rem', display: 'block' }}>
1087
  <span style={{ fontWeight: '500', display: 'block', marginBottom: '0.5rem' }}>
1088
  Max Hierarchy Depth
 
1162
  </div>
1163
 
1164
  <div className="sidebar-section">
1165
+ <h3>Family Tree Explorer</h3>
1166
  <div style={{ position: 'relative' }}>
1167
  <input
1168
  type="text"
frontend/src/components/PaperPlots.css DELETED
@@ -1,92 +0,0 @@
1
- .paper-plots {
2
- display: flex;
3
- flex-direction: column;
4
- gap: 1rem;
5
- padding: 1rem;
6
- background: white;
7
- border-radius: 8px;
8
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
9
- }
10
-
11
- .plot-selector {
12
- border-bottom: 1px solid #e0e0e0;
13
- padding-bottom: 1rem;
14
- }
15
-
16
- .plot-selector h3 {
17
- margin: 0 0 0.75rem 0;
18
- font-size: 1.25rem;
19
- color: #333;
20
- }
21
-
22
- .plot-buttons {
23
- display: flex;
24
- flex-wrap: wrap;
25
- gap: 0.5rem;
26
- }
27
-
28
- .plot-button {
29
- padding: 0.5rem 1rem;
30
- border: 1px solid #ccc;
31
- background: white;
32
- border-radius: 4px;
33
- cursor: pointer;
34
- font-size: 0.875rem;
35
- transition: all 0.2s;
36
- color: #333;
37
- }
38
-
39
- .plot-button:hover {
40
- background: #f5f5f5;
41
- border-color: #999;
42
- }
43
-
44
- .plot-button.active {
45
- background: #4a90e2;
46
- color: white;
47
- border-color: #4a90e2;
48
- }
49
-
50
- .plot-container {
51
- position: relative;
52
- min-height: 600px;
53
- display: flex;
54
- align-items: center;
55
- justify-content: center;
56
- }
57
-
58
- .plot-loading {
59
- position: absolute;
60
- top: 50%;
61
- left: 50%;
62
- transform: translate(-50%, -50%);
63
- color: #666;
64
- font-size: 1rem;
65
- }
66
-
67
- .plot-tooltip {
68
- position: absolute;
69
- padding: 0.5rem;
70
- background: rgba(0, 0, 0, 0.8);
71
- color: white;
72
- border-radius: 4px;
73
- pointer-events: none;
74
- font-size: 0.875rem;
75
- z-index: 1000;
76
- }
77
-
78
- .plot-container svg {
79
- display: block;
80
- margin: 0 auto;
81
- }
82
-
83
- @media (max-width: 768px) {
84
- .plot-buttons {
85
- flex-direction: column;
86
- }
87
-
88
- .plot-button {
89
- width: 100%;
90
- }
91
- }
92
-
 
frontend/src/components/PaperPlots.tsx DELETED
@@ -1,755 +0,0 @@
1
- /**
2
- * Interactive D3.js visualizations based on plots from the research paper.
3
- * "Anatomy of a Machine Learning Ecosystem: 2 Million Models on Hugging Face"
4
- */
5
- import React, { useRef, useEffect, useState, useMemo } from 'react';
6
- import * as d3 from 'd3';
7
- import { ModelPoint } from '../types';
8
- import './PaperPlots.css';
9
-
10
- const API_BASE = process.env.REACT_APP_API_URL || 'http://localhost:8000';
11
-
12
- interface PaperPlotsProps {
13
- data: ModelPoint[];
14
- width?: number;
15
- height?: number;
16
- }
17
-
18
- type PlotType = 'family-size' | 'similarity-comparison' | 'license-drift' | 'model-card-length' | 'growth-timeline';
19
-
20
- export default function PaperPlots({ data, width = 800, height = 600 }: PaperPlotsProps) {
21
- const [activePlot, setActivePlot] = useState<PlotType>('family-size');
22
- const familySizeRef = useRef<SVGSVGElement>(null);
23
- const similarityRef = useRef<SVGSVGElement>(null);
24
- const licenseDriftRef = useRef<SVGSVGElement>(null);
25
- const modelCardLengthRef = useRef<SVGSVGElement>(null);
26
- const growthTimelineRef = useRef<SVGSVGElement>(null);
27
- const [familyTreeData, setFamilyTreeData] = useState<any>(null);
28
- const [loading, setLoading] = useState(false);
29
-
30
- // Fetch family tree statistics
31
- useEffect(() => {
32
- const fetchFamilyStats = async () => {
33
- setLoading(true);
34
- try {
35
- const response = await fetch(`${API_BASE}/api/family/stats`);
36
- if (response.ok) {
37
- const stats = await response.json();
38
- setFamilyTreeData(stats);
39
- }
40
- } catch (err) {
41
- console.error('Error fetching family stats:', err);
42
- } finally {
43
- setLoading(false);
44
- }
45
- };
46
- fetchFamilyStats();
47
- }, []);
48
-
49
- // Plot 1: Family Size Distribution
50
- useEffect(() => {
51
- if (activePlot !== 'family-size' || !familySizeRef.current) return;
52
-
53
- const svg = d3.select(familySizeRef.current);
54
- svg.selectAll('*').remove();
55
-
56
- const margin = { top: 40, right: 40, bottom: 60, left: 60 };
57
- const innerWidth = width - margin.left - margin.right;
58
- const innerHeight = height - margin.top - margin.bottom;
59
-
60
- // Use API data if available, otherwise calculate from current data
61
- let binData: Array<{ x0: number; x1: number; count: number }>;
62
-
63
- if (familyTreeData && familyTreeData.family_size_distribution) {
64
- const sizeDist = familyTreeData.family_size_distribution;
65
- const sizes = Object.keys(sizeDist).map(Number);
66
- const counts = Object.values(sizeDist) as number[];
67
-
68
- // Create histogram bins from distribution
69
- const maxSize = d3.max(sizes) || 1;
70
- const bins = d3.bin().thresholds(20).domain([0, maxSize])(sizes);
71
-
72
- binData = bins.map(bin => {
73
- let count = 0;
74
- sizes.forEach((size, i) => {
75
- if (size >= (bin.x0 || 0) && size < (bin.x1 || maxSize)) {
76
- count += counts[i];
77
- }
78
- });
79
- return {
80
- x0: bin.x0 || 0,
81
- x1: bin.x1 || maxSize,
82
- count: count
83
- };
84
- }).filter(d => d.count > 0);
85
- } else {
86
- // Fallback: Calculate from current data
87
- const familySizes = new Map<string, number>();
88
- data.forEach(model => {
89
- const familyKey = model.parent_model || model.model_id;
90
- familySizes.set(familyKey, (familySizes.get(familyKey) || 0) + 1);
91
- });
92
-
93
- const sizes = Array.from(familySizes.values());
94
- const bins = d3.bin().thresholds(20)(sizes);
95
- binData = bins.map(bin => ({
96
- x0: bin.x0 || 0,
97
- x1: bin.x1 || 0,
98
- count: bin.length
99
- }));
100
- }
101
-
102
- const g = svg.append('g')
103
- .attr('transform', `translate(${margin.left},${margin.top})`);
104
-
105
- // Scales
106
- const xScale = d3.scaleLinear()
107
- .domain([0, d3.max(binData, d => d.x1) || 1])
108
- .range([0, innerWidth])
109
- .nice();
110
-
111
- const yScale = d3.scaleLinear()
112
- .domain([0, d3.max(binData, d => d.count) || 1])
113
- .range([innerHeight, 0])
114
- .nice();
115
-
116
- // Bars
117
- g.selectAll('rect')
118
- .data(binData)
119
- .enter()
120
- .append('rect')
121
- .attr('x', d => xScale(d.x0))
122
- .attr('width', d => Math.max(0, xScale(d.x1) - xScale(d.x0) - 1))
123
- .attr('y', d => yScale(d.count))
124
- .attr('height', d => innerHeight - yScale(d.count))
125
- .attr('fill', '#4a90e2')
126
- .attr('opacity', 0.7)
127
- .on('mouseover', function(event, d) {
128
- d3.select(this).attr('opacity', 1);
129
- const tooltip = d3.select('body').append('div')
130
- .attr('class', 'plot-tooltip')
131
- .style('opacity', 0);
132
- tooltip.transition().duration(200).style('opacity', 0.9);
133
- tooltip.html(`Family Size: ${d.x0.toFixed(0)}-${d.x1.toFixed(0)}<br/>Count: ${d.count}`)
134
- .style('left', (event.pageX + 10) + 'px')
135
- .style('top', (event.pageY - 28) + 'px');
136
- })
137
- .on('mouseout', function() {
138
- d3.select(this).attr('opacity', 0.7);
139
- d3.selectAll('.plot-tooltip').remove();
140
- });
141
-
142
- // Axes
143
- const xAxis = d3.axisBottom(xScale).tickFormat(d3.format('d'));
144
- const yAxis = d3.axisLeft(yScale);
145
-
146
- g.append('g')
147
- .attr('transform', `translate(0,${innerHeight})`)
148
- .call(xAxis)
149
- .append('text')
150
- .attr('x', innerWidth / 2)
151
- .attr('y', 45)
152
- .attr('fill', 'currentColor')
153
- .style('text-anchor', 'middle')
154
- .style('font-size', '14px')
155
- .text('Family Size (number of models)');
156
-
157
- g.append('g')
158
- .call(yAxis)
159
- .append('text')
160
- .attr('transform', 'rotate(-90)')
161
- .attr('y', -45)
162
- .attr('x', -innerHeight / 2)
163
- .attr('fill', 'currentColor')
164
- .style('text-anchor', 'middle')
165
- .style('font-size', '14px')
166
- .text('Number of Families');
167
-
168
- // Title
169
- svg.append('text')
170
- .attr('x', width / 2)
171
- .attr('y', 20)
172
- .attr('text-anchor', 'middle')
173
- .style('font-size', '16px')
174
- .style('font-weight', 'bold')
175
- .text('Family Size Distribution');
176
-
177
- }, [activePlot, data, width, height, familyTreeData]);
178
-
179
- // Plot 2: Similarity Comparison (Sibling vs Parent-Child)
180
- useEffect(() => {
181
- if (activePlot !== 'similarity-comparison' || !similarityRef.current || !data.length) return;
182
-
183
- const svg = d3.select(similarityRef.current);
184
- svg.selectAll('*').remove();
185
-
186
- const margin = { top: 40, right: 40, bottom: 60, left: 60 };
187
- const innerWidth = width - margin.left - margin.right;
188
- const innerHeight = height - margin.top - margin.bottom;
189
-
190
- // This would require similarity data - for now, create a placeholder visualization
191
- // In the paper, this shows that siblings are more similar than parent-child pairs
192
- const g = svg.append('g')
193
- .attr('transform', `translate(${margin.left},${margin.top})`);
194
-
195
- // Placeholder: Box plot or violin plot showing similarity distributions
196
- // Sibling similarity (higher)
197
- const siblingData = Array.from({ length: 100 }, () => 0.6 + Math.random() * 0.3);
198
- // Parent-child similarity (lower)
199
- const parentChildData = Array.from({ length: 100 }, () => 0.3 + Math.random() * 0.3);
200
-
201
- const xScale = d3.scaleBand()
202
- .domain(['Sibling Pairs', 'Parent-Child Pairs'])
203
- .range([0, innerWidth])
204
- .padding(0.3);
205
-
206
- const yScale = d3.scaleLinear()
207
- .domain([0, 1])
208
- .range([innerHeight, 0])
209
- .nice();
210
-
211
- // Box plot visualization
212
- [siblingData, parentChildData].forEach((dataset, i) => {
213
- const label = i === 0 ? 'Sibling Pairs' : 'Parent-Child Pairs';
214
- const x = xScale(label);
215
- const bandWidth = xScale.bandwidth();
216
-
217
- if (x === undefined) return;
218
-
219
- // Calculate quartiles
220
- const sorted = dataset.sort((a, b) => a - b);
221
- const q1 = d3.quantile(sorted, 0.25) || 0;
222
- const q2 = d3.quantile(sorted, 0.5) || 0;
223
- const q3 = d3.quantile(sorted, 0.75) || 0;
224
- const min = sorted[0];
225
- const max = sorted[sorted.length - 1];
226
-
227
- // Box
228
- g.append('rect')
229
- .attr('x', x)
230
- .attr('y', yScale(q3))
231
- .attr('width', bandWidth)
232
- .attr('height', yScale(q1) - yScale(q3))
233
- .attr('fill', i === 0 ? '#4a90e2' : '#e24a4a')
234
- .attr('opacity', 0.6)
235
- .attr('stroke', '#333')
236
- .attr('stroke-width', 1);
237
-
238
- // Median line
239
- g.append('line')
240
- .attr('x1', x)
241
- .attr('x2', x + bandWidth)
242
- .attr('y1', yScale(q2))
243
- .attr('y2', yScale(q2))
244
- .attr('stroke', '#333')
245
- .attr('stroke-width', 2);
246
-
247
- // Whiskers
248
- g.append('line')
249
- .attr('x1', x + bandWidth / 2)
250
- .attr('x2', x + bandWidth / 2)
251
- .attr('y1', yScale(min))
252
- .attr('y2', yScale(q1))
253
- .attr('stroke', '#333')
254
- .attr('stroke-width', 1);
255
-
256
- g.append('line')
257
- .attr('x1', x + bandWidth / 2)
258
- .attr('x2', x + bandWidth / 2)
259
- .attr('y1', yScale(q3))
260
- .attr('y2', yScale(max))
261
- .attr('stroke', '#333')
262
- .attr('stroke-width', 1);
263
-
-      // Min/Max lines
-      g.append('line')
-        .attr('x1', x + bandWidth * 0.25)
-        .attr('x2', x + bandWidth * 0.75)
-        .attr('y1', yScale(min))
-        .attr('y2', yScale(min))
-        .attr('stroke', '#333')
-        .attr('stroke-width', 1);
-
-      g.append('line')
-        .attr('x1', x + bandWidth * 0.25)
-        .attr('x2', x + bandWidth * 0.75)
-        .attr('y1', yScale(max))
-        .attr('y2', yScale(max))
-        .attr('stroke', '#333')
-        .attr('stroke-width', 1);
-    });
-
-    // Axes
-    const yAxis = d3.axisLeft(yScale);
-    g.append('g').call(yAxis);
-
-    g.append('text')
-      .attr('transform', 'rotate(-90)')
-      .attr('y', -45)
-      .attr('x', -innerHeight / 2)
-      .attr('fill', 'currentColor')
-      .style('text-anchor', 'middle')
-      .style('font-size', '14px')
-      .text('Similarity Score');
-
-    // Title
-    svg.append('text')
-      .attr('x', width / 2)
-      .attr('y', 20)
-      .attr('text-anchor', 'middle')
-      .style('font-size', '16px')
-      .style('font-weight', 'bold')
-      .text('Similarity: Siblings vs Parent-Child Pairs');
-
-  }, [activePlot, data, width, height, familyTreeData]);
-
-  // Plot 3: License Drift (over family depth)
-  useEffect(() => {
-    if (activePlot !== 'license-drift' || !licenseDriftRef.current || !data.length) return;
-
-    const svg = d3.select(licenseDriftRef.current);
-    svg.selectAll('*').remove();
-
-    const margin = { top: 40, right: 40, bottom: 60, left: 80 };
-    const innerWidth = width - margin.left - margin.right;
-    const innerHeight = height - margin.top - margin.bottom;
-
-    // Group by family depth and license type
-    const depthGroups = new Map<number, Map<string, number>>();
-    data.forEach(model => {
-      const depth = model.family_depth || 0;
-      const license = model.licenses ? (model.licenses.split(',')[0].trim() || 'unknown') : 'unknown';
-
-      if (!depthGroups.has(depth)) {
-        depthGroups.set(depth, new Map());
-      }
-      const licenseMap = depthGroups.get(depth)!;
-      licenseMap.set(license, (licenseMap.get(license) || 0) + 1);
-    });
-
-    const depths = Array.from(depthGroups.keys()).sort((a, b) => a - b);
-    const allLicenses = new Set<string>();
-    depthGroups.forEach(licenseMap => {
-      licenseMap.forEach((_, license) => allLicenses.add(license));
-    });
-
-    const licenseTypes = Array.from(allLicenses).slice(0, 5); // Top 5 licenses
-    const colorScale = d3.scaleOrdinal(d3.schemeCategory10).domain(licenseTypes);
-
-    const g = svg.append('g')
-      .attr('transform', `translate(${margin.left},${margin.top})`);
-
-    const xScale = d3.scaleBand()
-      .domain(depths.map(d => d.toString()))
-      .range([0, innerWidth])
-      .padding(0.1);
-
-    const yScale = d3.scaleLinear()
-      .domain([0, 1])
-      .range([innerHeight, 0]);
-
-    // Stacked area or bars showing license distribution
-    licenseTypes.forEach((license, i) => {
-      const stack = depths.map(depth => {
-        const licenseMap = depthGroups.get(depth) || new Map();
-        const total = Array.from(licenseMap.values()).reduce((a, b) => a + b, 0);
-        const count = licenseMap.get(license) || 0;
-        return { depth, proportion: total > 0 ? count / total : 0 };
-      });
-
-      // Draw as line chart showing proportion over depth
-      const line = d3.line<{ depth: number; proportion: number }>()
-        .x(d => (xScale(d.depth.toString()) || 0) + xScale.bandwidth() / 2)
-        .y(d => yScale(d.proportion))
-        .curve(d3.curveMonotoneX);
-
-      g.append('path')
-        .datum(stack)
-        .attr('fill', 'none')
-        .attr('stroke', colorScale(license))
-        .attr('stroke-width', 2)
-        .attr('d', line);
-
-      // Add circles for data points
-      g.selectAll(`.dot-${i}`)
-        .data(stack)
-        .enter()
-        .append('circle')
-        .attr('cx', d => (xScale(d.depth.toString()) || 0) + xScale.bandwidth() / 2)
-        .attr('cy', d => yScale(d.proportion))
-        .attr('r', 4)
-        .attr('fill', colorScale(license));
-    });
-
-    // Axes
-    const xAxis = d3.axisBottom(xScale);
-    const yAxis = d3.axisLeft(yScale).tickFormat(d3.format('.0%'));
-
-    g.append('g')
-      .attr('transform', `translate(0,${innerHeight})`)
-      .call(xAxis)
-      .append('text')
-      .attr('x', innerWidth / 2)
-      .attr('y', 45)
-      .attr('fill', 'currentColor')
-      .style('text-anchor', 'middle')
-      .style('font-size', '14px')
-      .text('Family Depth (generation)');
-
-    g.append('g').call(yAxis)
-      .append('text')
-      .attr('transform', 'rotate(-90)')
-      .attr('y', -60)
-      .attr('x', -innerHeight / 2)
-      .attr('fill', 'currentColor')
-      .style('text-anchor', 'middle')
-      .style('font-size', '14px')
-      .text('Proportion of Models');
-
-    // Legend
-    const legend = g.append('g')
-      .attr('transform', `translate(${innerWidth - 150}, 20)`);
-
-    licenseTypes.forEach((license, i) => {
-      const legendRow = legend.append('g')
-        .attr('transform', `translate(0, ${i * 20})`);
-
-      legendRow.append('rect')
-        .attr('width', 15)
-        .attr('height', 15)
-        .attr('fill', colorScale(license));
-
-      legendRow.append('text')
-        .attr('x', 20)
-        .attr('y', 12)
-        .style('font-size', '12px')
-        .text(license.length > 15 ? license.substring(0, 15) + '...' : license);
-    });
-
-    // Title
-    svg.append('text')
-      .attr('x', width / 2)
-      .attr('y', 20)
-      .attr('text-anchor', 'middle')
-      .style('font-size', '16px')
-      .style('font-weight', 'bold')
-      .text('License Distribution Across Family Generations');
-
-  }, [activePlot, data, width, height, familyTreeData]);
-
-  // Plot 4: Model Card Length Distribution
-  useEffect(() => {
-    if (activePlot !== 'model-card-length' || !modelCardLengthRef.current || !data.length) return;
-
-    const svg = d3.select(modelCardLengthRef.current);
-    svg.selectAll('*').remove();
-
-    const margin = { top: 40, right: 40, bottom: 60, left: 60 };
-    const innerWidth = width - margin.left - margin.right;
-    const innerHeight = height - margin.top - margin.bottom;
-
-    // Placeholder: Would need model card length data
-    // In the paper, this shows model cards getting shorter and more standardized
-    const g = svg.append('g')
-      .attr('transform', `translate(${margin.left},${margin.top})`);
-
-    // Use real model card length data from API if available
-    let depthData = new Map<number, number[]>();
-
-    if (familyTreeData && familyTreeData.model_card_length_by_depth) {
-      // Use real data from API
-      const cardStats = familyTreeData.model_card_length_by_depth;
-      Object.keys(cardStats).forEach(depthStr => {
-        const depth = parseInt(depthStr);
-        const stats = cardStats[depthStr];
-        // Create synthetic distribution from stats (mean, q1, q3)
-        const lengths: number[] = [];
-        const count = Math.min(stats.count, 100); // Limit for performance
-        for (let i = 0; i < count; i++) {
-          // Generate values around the mean with spread based on quartiles
-          const spread = (stats.q3 - stats.q1) / 2;
-          const length = stats.mean + (Math.random() - 0.5) * spread * 2;
-          lengths.push(Math.max(0, length));
-        }
-        depthData.set(depth, lengths);
-      });
-    } else {
-      // Fallback: Calculate from current data
-      const depthGroups = new Map<number, number[]>();
-      data.forEach(model => {
-        const depth = model.family_depth || 0;
-        // We don't have model card length in ModelPoint, so use placeholder
-        // In a real implementation, this would come from the API
-        if (!depthGroups.has(depth)) {
-          depthGroups.set(depth, []);
-        }
-      });
-      depthData = depthGroups;
-    }
-
-    // If still no data, use simulated data
-    if (depthData.size === 0) {
-      for (let depth = 0; depth <= 5; depth++) {
-        const lengths = Array.from({ length: 50 }, () => {
-          const baseLength = 2000 - depth * 200;
-          return baseLength + (Math.random() - 0.5) * 500;
-        });
-        depthData.set(depth, lengths);
-      }
-    }
-
-    const depths = Array.from(depthData.keys()).sort((a, b) => a - b);
-    const maxDepth = d3.max(depths) || 5;
-    const allLengths = Array.from(depthData.values()).flat();
-    const maxLength = d3.max(allLengths) || 3000;
-
-    const xScale = d3.scaleBand()
-      .domain(depths.map(d => d.toString()))
-      .range([0, innerWidth])
-      .padding(0.2);
-
-    const yScale = d3.scaleLinear()
-      .domain([0, maxLength])
-      .range([innerHeight, 0])
-      .nice();
-
-    // Violin plot or box plot
-    depthData.forEach((lengths, depth) => {
-      const x = xScale(depth.toString());
-      const bandWidth = xScale.bandwidth();
-
-      if (x === undefined) return;
-
-      // Simple box plot
-      const sorted = lengths.sort((a, b) => a - b);
-      const q1 = d3.quantile(sorted, 0.25) || 0;
-      const q2 = d3.quantile(sorted, 0.5) || 0;
-      const q3 = d3.quantile(sorted, 0.75) || 0;
-
-      g.append('rect')
-        .attr('x', x)
-        .attr('y', yScale(q3))
-        .attr('width', bandWidth)
-        .attr('height', yScale(q1) - yScale(q3))
-        .attr('fill', '#4a90e2')
-        .attr('opacity', 0.6)
-        .attr('stroke', '#333');
-
-      g.append('line')
-        .attr('x1', x)
-        .attr('x2', x + bandWidth)
-        .attr('y1', yScale(q2))
-        .attr('y2', yScale(q2))
-        .attr('stroke', '#333')
-        .attr('stroke-width', 2);
-    });
-
-    const yAxis = d3.axisLeft(yScale);
-    g.append('g').call(yAxis)
-      .append('text')
-      .attr('transform', 'rotate(-90)')
-      .attr('y', -45)
-      .attr('x', -innerHeight / 2)
-      .attr('fill', 'currentColor')
-      .style('text-anchor', 'middle')
-      .style('font-size', '14px')
-      .text('Model Card Length (characters)');
-
-    // Title
-    svg.append('text')
-      .attr('x', width / 2)
-      .attr('y', 20)
-      .attr('text-anchor', 'middle')
-      .style('font-size', '16px')
-      .style('font-weight', 'bold')
-      .text('Model Card Length by Family Generation');
-
-  }, [activePlot, data, width, height, familyTreeData]);
-
-  // Plot 5: Growth Timeline
-  useEffect(() => {
-    if (activePlot !== 'growth-timeline' || !growthTimelineRef.current) return;
-
-    const svg = d3.select(growthTimelineRef.current);
-    svg.selectAll('*').remove();
-
-    const margin = { top: 40, right: 40, bottom: 60, left: 60 };
-    const innerWidth = width - margin.left - margin.right;
-    const innerHeight = height - margin.top - margin.bottom;
-
-    // Fetch growth data from model tracker API
-    fetch(`${API_BASE}/api/model-count/historical?days=365`)
-      .then(res => res.json())
-      .then(data => {
-        if (!data.counts || data.counts.length === 0) {
-          svg.append('text')
-            .attr('x', width / 2)
-            .attr('y', height / 2)
-            .attr('text-anchor', 'middle')
-            .text('No historical data available');
-          return;
-        }
-
-        const g = svg.append('g')
-          .attr('transform', `translate(${margin.left},${margin.top})`);
-
-        const counts = data.counts.map((d: any) => ({
-          date: new Date(d.timestamp),
-          count: d.total_models
-        })).sort((a: any, b: any) => a.date - b.date);
-
-        const extent = d3.extent(counts, (d: any) => d.date) as [Date | undefined, Date | undefined];
-        const minDate = extent[0];
-        const maxDate = extent[1];
-        if (!minDate || !maxDate) return;
-
-        const xScale = d3.scaleTime()
-          .domain([minDate, maxDate])
-          .range([0, innerWidth]);
-
-        const yScale = d3.scaleLinear()
-          .domain([0, d3.max(counts, (d: any) => d.count) || 0] as [number, number])
-          .range([innerHeight, 0])
-          .nice();
-
-        const line = d3.line<any>()
-          .x(d => xScale(d.date))
-          .y(d => yScale(d.count))
-          .curve(d3.curveMonotoneX);
-
-        g.append('path')
-          .datum(counts)
-          .attr('fill', 'none')
-          .attr('stroke', '#4a90e2')
-          .attr('stroke-width', 2)
-          .attr('d', line);
-
-        g.selectAll('circle')
-          .data(counts)
-          .enter()
-          .append('circle')
-          .attr('cx', (d: any) => xScale(d.date))
-          .attr('cy', (d: any) => yScale(d.count))
-          .attr('r', 3)
-          .attr('fill', '#4a90e2')
-          .on('mouseover', function(event, d: any) {
-            d3.select(this).attr('r', 5);
-            const tooltip = d3.select('body').append('div')
-              .attr('class', 'plot-tooltip')
-              .style('opacity', 0);
-            tooltip.transition().duration(200).style('opacity', 0.9);
-            tooltip.html(`${d.date.toLocaleDateString()}<br/>Models: ${d.count.toLocaleString()}`)
-              .style('left', (event.pageX + 10) + 'px')
-              .style('top', (event.pageY - 28) + 'px');
-          })
-          .on('mouseout', function() {
-            d3.select(this).attr('r', 3);
-            d3.selectAll('.plot-tooltip').remove();
-          });
-
-        const xAxis = d3.axisBottom(xScale).ticks(6);
-        const yAxis = d3.axisLeft(yScale).tickFormat(d3.format('.2s'));
-
-        g.append('g')
-          .attr('transform', `translate(0,${innerHeight})`)
-          .call(xAxis)
-          .append('text')
-          .attr('x', innerWidth / 2)
-          .attr('y', 45)
-          .attr('fill', 'currentColor')
-          .style('text-anchor', 'middle')
-          .style('font-size', '14px')
-          .text('Date');
-
-        g.append('g').call(yAxis)
-          .append('text')
-          .attr('transform', 'rotate(-90)')
-          .attr('y', -45)
-          .attr('x', -innerHeight / 2)
-          .attr('fill', 'currentColor')
-          .style('text-anchor', 'middle')
-          .style('font-size', '14px')
-          .text('Total Models');
-
-        svg.append('text')
-          .attr('x', width / 2)
-          .attr('y', 20)
-          .attr('text-anchor', 'middle')
-          .style('font-size', '16px')
-          .style('font-weight', 'bold')
-          .text('Model Count Growth Over Time');
-      })
-      .catch(err => {
-        console.error('Error fetching growth data:', err);
-        svg.append('text')
-          .attr('x', width / 2)
-          .attr('y', height / 2)
-          .attr('text-anchor', 'middle')
-          .text('Error loading growth data');
-      });
-
-  }, [activePlot, width, height]);
-
-  const plotOptions: { value: PlotType; label: string; description: string }[] = [
-    { value: 'family-size', label: 'Family Size Distribution', description: 'Distribution of family tree sizes' },
-    { value: 'similarity-comparison', label: 'Similarity Comparison', description: 'Sibling vs parent-child similarity' },
-    { value: 'license-drift', label: 'License Drift', description: 'License changes across generations' },
-    { value: 'model-card-length', label: 'Model Card Length', description: 'Model card length by generation' },
-    { value: 'growth-timeline', label: 'Growth Timeline', description: 'Model count over time' },
-  ];
-
-  return (
-    <div className="paper-plots">
-      <div className="plot-selector">
-        <h3>Paper Visualizations</h3>
-        <div className="plot-buttons">
-          {plotOptions.map(option => (
-            <button
-              key={option.value}
-              className={`plot-button ${activePlot === option.value ? 'active' : ''}`}
-              onClick={() => setActivePlot(option.value)}
-              title={option.description}
-            >
-              {option.label}
-            </button>
-          ))}
-        </div>
-      </div>
-
-      <div className="plot-container">
-        {loading && <div className="plot-loading">Loading data...</div>}
-        <svg
-          ref={familySizeRef}
-          width={width}
-          height={height}
-          style={{ display: activePlot === 'family-size' ? 'block' : 'none' }}
-        />
-        <svg
-          ref={similarityRef}
-          width={width}
-          height={height}
-          style={{ display: activePlot === 'similarity-comparison' ? 'block' : 'none' }}
-        />
-        <svg
-          ref={licenseDriftRef}
-          width={width}
-          height={height}
-          style={{ display: activePlot === 'license-drift' ? 'block' : 'none' }}
-        />
-        <svg
-          ref={modelCardLengthRef}
-          width={width}
-          height={height}
-          style={{ display: activePlot === 'model-card-length' ? 'block' : 'none' }}
-        />
-        <svg
-          ref={growthTimelineRef}
-          width={width}
-          height={height}
-          style={{ display: activePlot === 'growth-timeline' ? 'block' : 'none' }}
-        />
-      </div>
-    </div>
-  );
-}
-
frontend/src/components/ScatterPlot.tsx DELETED
@@ -1,7 +0,0 @@
-/**
- * Legacy Visx scatter plot - kept for reference.
- * Use EnhancedScatterPlot.tsx for D3.js implementation.
- */
-// This file is kept for compatibility but EnhancedScatterPlot is preferred
-
-export { default } from './EnhancedScatterPlot';
frontend/src/components/controls/ClusterFilter.css ADDED
@@ -0,0 +1,122 @@
+.cluster-filter {
+  margin-bottom: 1.5rem;
+}
+
+.cluster-filter-header {
+  margin-bottom: 0.75rem;
+}
+
+.cluster-filter-header h3 {
+  margin: 0;
+  font-size: 1rem;
+  font-weight: 600;
+  color: var(--text-primary, #1a1a1a);
+}
+
+.cluster-filter-search {
+  margin-bottom: 0.75rem;
+}
+
+.cluster-search-input {
+  width: 100%;
+  padding: 0.5rem;
+  border: 1px solid var(--border-color, #e0e0e0);
+  border-radius: 4px;
+  font-size: 0.9rem;
+  background: var(--bg-primary, #ffffff);
+  color: var(--text-primary, #1a1a1a);
+}
+
+.cluster-search-input:focus {
+  outline: none;
+  border-color: var(--accent-color, #4a90e2);
+}
+
+.cluster-filter-actions {
+  display: flex;
+  gap: 0.5rem;
+  margin-bottom: 0.75rem;
+}
+
+.cluster-action-btn {
+  flex: 1;
+  padding: 0.4rem 0.6rem;
+  border: 1px solid var(--border-color, #e0e0e0);
+  border-radius: 4px;
+  background: var(--bg-primary, #ffffff);
+  color: var(--text-primary, #1a1a1a);
+  font-size: 0.85rem;
+  cursor: pointer;
+  transition: all 0.2s;
+}
+
+.cluster-action-btn:hover:not(:disabled) {
+  background: var(--bg-secondary, #f5f5f5);
+  border-color: var(--accent-color, #4a90e2);
+}
+
+.cluster-action-btn:disabled {
+  opacity: 0.5;
+  cursor: not-allowed;
+}
+
+.cluster-list {
+  max-height: 300px;
+  overflow-y: auto;
+  border: 1px solid var(--border-color, #e0e0e0);
+  border-radius: 4px;
+  padding: 0.5rem;
+  background: var(--bg-primary, #ffffff);
+}
+
+.cluster-item {
+  display: flex;
+  align-items: center;
+  gap: 0.5rem;
+  padding: 0.5rem;
+  cursor: pointer;
+  border-radius: 3px;
+  transition: background 0.15s;
+  font-size: 0.85rem;
+}
+
+.cluster-item:hover {
+  background: var(--bg-secondary, #f5f5f5);
+}
+
+.cluster-item.selected {
+  background: var(--bg-secondary, #f5f5f5);
+}
+
+.cluster-checkbox {
+  margin: 0;
+  cursor: pointer;
+}
+
+.cluster-color-indicator {
+  width: 12px;
+  height: 12px;
+  border-radius: 2px;
+  flex-shrink: 0;
+  border: 1px solid var(--border-color, #e0e0e0);
+}
+
+.cluster-label {
+  flex: 1;
+  color: var(--text-primary, #1a1a1a);
+}
+
+.cluster-count {
+  font-size: 0.75rem;
+  color: var(--text-secondary, #666);
+  margin-left: auto;
+}
+
+.cluster-filter-loading,
+.cluster-filter-empty {
+  padding: 1rem;
+  text-align: center;
+  color: var(--text-secondary, #666);
+  font-size: 0.9rem;
+}
+
frontend/src/components/controls/ClusterFilter.tsx ADDED
@@ -0,0 +1,142 @@
+/**
+ * Enhanced cluster filter component with search, Select All/Clear All/Random buttons.
+ * Inspired by LAION's cluster filtering UI.
+ */
+import React, { useState, useMemo } from 'react';
+import { useFilterStore } from '../../stores/filterStore';
+import './ClusterFilter.css';
+
+export interface Cluster {
+  cluster_id: number;
+  cluster_label: string;
+  count: number;
+  color?: string;
+}
+
+interface ClusterFilterProps {
+  clusters: Cluster[];
+  loading?: boolean;
+}
+
+export default function ClusterFilter({ clusters, loading = false }: ClusterFilterProps) {
+  const { selectedClusters, setSelectedClusters } = useFilterStore();
+  const [searchTerm, setSearchTerm] = useState('');
+
+  const filteredClusters = useMemo(() => {
+    if (!searchTerm) return clusters;
+    const lowerSearch = searchTerm.toLowerCase();
+    return clusters.filter(c =>
+      c.cluster_label.toLowerCase().includes(lowerSearch) ||
+      c.cluster_id.toString().includes(lowerSearch)
+    );
+  }, [clusters, searchTerm]);
+
+  const handleSelectAll = () => {
+    setSelectedClusters(clusters.map(c => c.cluster_id));
+  };
+
+  const handleClearAll = () => {
+    setSelectedClusters([]);
+  };
+
+  const handleRandom = () => {
+    if (clusters.length === 0) return;
+    const randomCluster = clusters[Math.floor(Math.random() * clusters.length)];
+    setSelectedClusters([randomCluster.cluster_id]);
+  };
+
+  const handleToggleCluster = (clusterId: number) => {
+    if (selectedClusters.includes(clusterId)) {
+      setSelectedClusters(selectedClusters.filter(id => id !== clusterId));
+    } else {
+      setSelectedClusters([...selectedClusters, clusterId]);
+    }
+  };
+
+  if (loading) {
+    return (
+      <div className="cluster-filter">
+        <div className="cluster-filter-loading">Loading clusters...</div>
+      </div>
+    );
+  }
+
+  if (clusters.length === 0) {
+    return (
+      <div className="cluster-filter">
+        <div className="cluster-filter-empty">No clusters available</div>
+      </div>
+    );
+  }
+
+  return (
+    <div className="cluster-filter">
+      <div className="cluster-filter-header">
+        <h3>Dataset Clusters</h3>
+      </div>
+
+      <div className="cluster-filter-search">
+        <input
+          type="text"
+          placeholder={`Search ${clusters.length} clusters...`}
+          value={searchTerm}
+          onChange={(e) => setSearchTerm(e.target.value)}
+          className="cluster-search-input"
+        />
+      </div>
+
+      <div className="cluster-filter-actions">
+        <button
+          onClick={handleSelectAll}
+          className="cluster-action-btn"
+          disabled={clusters.length === 0}
+        >
+          Select All
+        </button>
+        <button
+          onClick={handleClearAll}
+          className="cluster-action-btn"
+          disabled={selectedClusters.length === 0}
+        >
+          Clear All
+        </button>
+        <button
+          onClick={handleRandom}
+          className="cluster-action-btn"
+          disabled={clusters.length === 0}
+        >
+          Random
+        </button>
+      </div>
+
+      <div className="cluster-list">
+        {filteredClusters.length === 0 ? (
+          <div className="cluster-filter-empty">No clusters match your search</div>
+        ) : (
+          filteredClusters.map(cluster => (
+            <label
+              key={cluster.cluster_id}
+              className={`cluster-item ${selectedClusters.includes(cluster.cluster_id) ? 'selected' : ''}`}
+            >
+              <input
+                type="checkbox"
+                checked={selectedClusters.includes(cluster.cluster_id)}
+                onChange={() => handleToggleCluster(cluster.cluster_id)}
+                className="cluster-checkbox"
+              />
+              {cluster.color && (
+                <span
+                  className="cluster-color-indicator"
+                  style={{ backgroundColor: cluster.color }}
+                />
+              )}
+              <span className="cluster-label">{cluster.cluster_label}</span>
+              <span className="cluster-count">({cluster.count.toLocaleString()})</span>
+            </label>
+          ))
+        )}
+      </div>
+    </div>
+  );
+}
+
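
For context, a minimal sketch of how this component might be wired up, assuming the backend exposes a cluster-listing endpoint; the `/api/clusters` path and response shape here are assumptions, not confirmed by this diff:

// Hypothetical usage sketch for ClusterFilter (endpoint and payload shape are assumed).
import React, { useEffect, useState } from 'react';
import ClusterFilter, { Cluster } from './components/controls/ClusterFilter';
import { API_BASE } from './config/api';

function SidebarClusters() {
  const [clusters, setClusters] = useState<Cluster[]>([]);
  const [loading, setLoading] = useState(true);

  useEffect(() => {
    // Assumed endpoint; adjust to the actual clusters route.
    fetch(`${API_BASE}/api/clusters`)
      .then(res => res.json())
      .then(data => setClusters(data.clusters ?? []))
      .catch(() => setClusters([]))
      .finally(() => setLoading(false));
  }, []);

  return <ClusterFilter clusters={clusters} loading={loading} />;
}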
frontend/src/components/controls/NodeDensitySlider.css ADDED
@@ -0,0 +1,31 @@
+.node-density-slider {
+  margin-bottom: 1rem;
+}
+
+.node-density-label {
+  display: block;
+}
+
+.node-density-title {
+  font-weight: 500;
+  display: block;
+  margin-bottom: 0.5rem;
+  color: var(--text-primary, #1a1a1a);
+}
+
+.node-density-input {
+  width: 100%;
+  cursor: pointer;
+}
+
+.node-density-input:disabled {
+  opacity: 0.5;
+  cursor: not-allowed;
+}
+
+.node-density-hint {
+  font-size: 0.75rem;
+  color: var(--text-secondary, #666);
+  margin-top: 0.25rem;
+}
+
frontend/src/components/controls/NodeDensitySlider.tsx ADDED
@@ -0,0 +1,39 @@
+/**
+ * Node density slider for controlling rendering performance.
+ * Lower density improves performance for large datasets.
+ */
+import React from 'react';
+import { useFilterStore } from '../../stores/filterStore';
+import './NodeDensitySlider.css';
+
+interface NodeDensitySliderProps {
+  disabled?: boolean;
+}
+
+export default function NodeDensitySlider({ disabled = false }: NodeDensitySliderProps) {
+  const { nodeDensity, setNodeDensity } = useFilterStore();
+
+  return (
+    <div className="node-density-slider">
+      <label className="node-density-label">
+        <span className="node-density-title">
+          Node Density ({nodeDensity}%)
+        </span>
+        <input
+          type="range"
+          min="10"
+          max="100"
+          step="10"
+          value={nodeDensity}
+          onChange={(e) => setNodeDensity(parseInt(e.target.value))}
+          disabled={disabled}
+          className="node-density-input"
+        />
+        <div className="node-density-hint">
+          Lower density improves performance for large datasets
+        </div>
+      </label>
+    </div>
+  );
+}
+
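
The slider only stores a percentage; one plausible way a renderer could apply it is deterministic downsampling before drawing. A minimal sketch, assuming a stride-based sampling strategy (the component itself does not prescribe one):

// Hypothetical consumer of nodeDensity: keep every k-th point so that roughly
// (density / 100) of the data survives. Deterministic, so the visible subset
// is stable across re-renders.
function downsample<T>(points: T[], densityPercent: number): T[] {
  if (densityPercent >= 100) return points;
  const keepRatio = Math.max(densityPercent, 1) / 100;
  const stride = Math.max(1, Math.round(1 / keepRatio));
  return points.filter((_, i) => i % stride === 0);
}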
frontend/src/components/controls/RandomModelButton.tsx ADDED
@@ -0,0 +1,32 @@
+/**
+ * Button to select a random model from the dataset for discovery.
+ */
+import React from 'react';
+import { ModelPoint } from '../../types';
+
+interface RandomModelButtonProps {
+  data: ModelPoint[];
+  onSelect: (model: ModelPoint) => void;
+  disabled?: boolean;
+}
+
+export default function RandomModelButton({ data, onSelect, disabled }: RandomModelButtonProps) {
+  const handleRandomSelect = () => {
+    if (data.length === 0) return;
+    const randomIndex = Math.floor(Math.random() * data.length);
+    onSelect(data[randomIndex]);
+  };
+
+  return (
+    <button
+      onClick={handleRandomSelect}
+      disabled={disabled || data.length === 0}
+      className="random-model-btn"
+      title="Select a random model"
+      aria-label="Select random model"
+    >
+      <span>Select Random Model</span>
+    </button>
+  );
+}
+
frontend/src/components/controls/RenderingStyleSelector.css ADDED
@@ -0,0 +1,37 @@
+.rendering-style-selector {
+  margin-bottom: 1rem;
+}
+
+.rendering-style-label {
+  display: block;
+}
+
+.rendering-style-title {
+  font-weight: 500;
+  display: block;
+  margin-bottom: 0.5rem;
+  color: var(--text-primary, #1a1a1a);
+}
+
+.rendering-style-select {
+  width: 100%;
+  padding: 0.5rem;
+  border: 1px solid var(--border-color, #e0e0e0);
+  border-radius: 4px;
+  background: var(--bg-primary, #ffffff);
+  color: var(--text-primary, #1a1a1a);
+  font-size: 0.9rem;
+  cursor: pointer;
+}
+
+.rendering-style-select:focus {
+  outline: none;
+  border-color: var(--accent-color, #4a90e2);
+}
+
+.rendering-style-hint {
+  font-size: 0.75rem;
+  color: var(--text-secondary, #666);
+  margin-top: 0.25rem;
+}
+
frontend/src/components/controls/RenderingStyleSelector.tsx ADDED
@@ -0,0 +1,43 @@
+/**
+ * Rendering style selector dropdown.
+ * Allows users to choose different 3D layout/geometry styles.
+ */
+import React from 'react';
+import { useFilterStore, RenderingStyle } from '../../stores/filterStore';
+import './RenderingStyleSelector.css';
+
+const STYLES: { value: RenderingStyle; label: string; description: string }[] = [
+  { value: 'embeddings', label: 'Embeddings', description: 'Standard embedding-based layout' },
+  { value: 'sphere', label: 'Sphere', description: 'Spherical arrangement of points' },
+  { value: 'galaxy', label: 'Galaxy', description: 'Spiral galaxy-like layout' },
+  { value: 'wave', label: 'Wave', description: 'Wave pattern arrangement' },
+  { value: 'helix', label: 'Helix', description: 'Helical/spiral arrangement' },
+  { value: 'torus', label: 'Torus', description: 'Torus/donut-shaped layout' },
+];
+
+export default function RenderingStyleSelector() {
+  const { renderingStyle, setRenderingStyle } = useFilterStore();
+
+  return (
+    <div className="rendering-style-selector">
+      <label className="rendering-style-label">
+        <span className="rendering-style-title">Rendering Style</span>
+        <select
+          value={renderingStyle}
+          onChange={(e) => setRenderingStyle(e.target.value as RenderingStyle)}
+          className="rendering-style-select"
+        >
+          {STYLES.map(style => (
+            <option key={style.value} value={style.value}>
+              {style.label}
+            </option>
+          ))}
+        </select>
+        <div className="rendering-style-hint">
+          {STYLES.find(s => s.value === renderingStyle)?.description}
+        </div>
+      </label>
+    </div>
+  );
+}
+
frontend/src/components/controls/ThemeToggle.tsx ADDED
@@ -0,0 +1,22 @@
+/**
+ * Toggle button for switching between light and dark themes.
+ */
+import React from 'react';
+import { useFilterStore } from '../../stores/filterStore';
+
+export default function ThemeToggle() {
+  const theme = useFilterStore((state) => state.theme);
+  const toggleTheme = useFilterStore((state) => state.toggleTheme);
+
+  return (
+    <button
+      onClick={toggleTheme}
+      className="theme-toggle"
+      title={`Switch to ${theme === 'light' ? 'dark' : 'light'} mode`}
+      aria-label={`Current theme: ${theme}. Click to switch to ${theme === 'light' ? 'dark' : 'light'} mode`}
+    >
+      {theme === 'light' ? 'πŸŒ™' : 'β˜€οΈ'}
+    </button>
+  );
+}
+
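
Several of these controls read from a shared Zustand store (`useFilterStore`), which is not itself part of this diff. A minimal sketch of the slice they assume, with field names inferred from usage and defaults that are guesses:

// Hypothetical filterStore slice inferred from the components above.
import { create } from 'zustand';

type Theme = 'light' | 'dark';

interface FilterState {
  theme: Theme;
  toggleTheme: () => void;
  nodeDensity: number;
  setNodeDensity: (value: number) => void;
  selectedClusters: number[];
  setSelectedClusters: (ids: number[]) => void;
}

export const useFilterStore = create<FilterState>((set) => ({
  theme: 'light',
  toggleTheme: () => set((s) => ({ theme: s.theme === 'light' ? 'dark' : 'light' })),
  nodeDensity: 100,
  setNodeDensity: (value) => set({ nodeDensity: value }),
  selectedClusters: [],
  setSelectedClusters: (ids) => set({ selectedClusters: ids }),
}));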
frontend/src/components/controls/VisualizationModeButtons.css ADDED
@@ -0,0 +1,65 @@
+.visualization-mode-buttons {
+  position: sticky;
+  top: 0;
+  z-index: 100;
+  background: var(--bg-primary, #ffffff);
+  border-bottom: 1px solid var(--border-color, #e0e0e0);
+  padding: 0.75rem 1rem;
+  margin-bottom: 1rem;
+}
+
+.mode-buttons-container {
+  display: flex;
+  gap: 0.5rem;
+  flex-wrap: wrap;
+  justify-content: center;
+}
+
+.mode-button {
+  display: flex;
+  align-items: center;
+  gap: 0.5rem;
+  padding: 0.5rem 1rem;
+  border: 1px solid var(--border-color, #e0e0e0);
+  border-radius: 6px;
+  background: var(--bg-primary, #ffffff);
+  color: var(--text-primary, #1a1a1a);
+  font-size: 0.9rem;
+  cursor: pointer;
+  transition: all 0.2s;
+}
+
+.mode-button:hover {
+  background: var(--bg-secondary, #f5f5f5);
+  border-color: var(--accent-color, #4a90e2);
+}
+
+.mode-button.active {
+  background: var(--accent-color, #4a90e2);
+  color: white;
+  border-color: var(--accent-color, #4a90e2);
+}
+
+.mode-icon {
+  font-size: 1.1rem;
+}
+
+.mode-label {
+  font-weight: 500;
+}
+
+@media (max-width: 768px) {
+  .mode-buttons-container {
+    gap: 0.25rem;
+  }
+
+  .mode-button {
+    padding: 0.4rem 0.6rem;
+    font-size: 0.8rem;
+  }
+
+  .mode-label {
+    display: none;
+  }
+}
+
frontend/src/components/controls/VisualizationModeButtons.tsx ADDED
@@ -0,0 +1,46 @@
+/**
+ * Visualization mode buttons with sticky header.
+ * Inspired by LAION's mode selection UI.
+ */
+import React from 'react';
+import { useFilterStore, ViewMode } from '../../stores/filterStore';
+import './VisualizationModeButtons.css';
+
+interface ModeOption {
+  value: ViewMode;
+  label: string;
+  icon: string;
+  description: string;
+}
+
+const MODES: ModeOption[] = [
+  { value: '3d', label: '3D Embedding', icon: '🎯', description: 'Interactive 3D exploration' },
+  { value: 'scatter', label: '2D Scatter', icon: 'πŸ“Š', description: '2D projection view' },
+  { value: 'network', label: 'Network', icon: 'πŸ•ΈοΈ', description: 'Network graph view' },
+  { value: 'distribution', label: 'Distribution', icon: 'πŸ“ˆ', description: 'Statistical distributions' },
+  { value: 'stacked', label: 'Stacked', icon: 'πŸ“š', description: 'Hierarchical view' },
+  { value: 'heatmap', label: 'Heatmap', icon: 'πŸ”₯', description: 'Density heatmap' },
+];
+
+export default function VisualizationModeButtons() {
+  const { viewMode, setViewMode } = useFilterStore();
+
+  return (
+    <div className="visualization-mode-buttons">
+      <div className="mode-buttons-container">
+        {MODES.map(mode => (
+          <button
+            key={mode.value}
+            className={`mode-button ${viewMode === mode.value ? 'active' : ''}`}
+            onClick={() => setViewMode(mode.value)}
+            title={mode.description}
+          >
+            <span className="mode-icon">{mode.icon}</span>
+            <span className="mode-label">{mode.label}</span>
+          </button>
+        ))}
+      </div>
+    </div>
+  );
+}
+
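
A sketch of how the selected `viewMode` might drive which visualization renders; only the mode values come from this diff, and the rendered elements here are placeholders standing in for the real views:

// Hypothetical view switch keyed on the store's viewMode.
import React from 'react';
import { useFilterStore } from '../../stores/filterStore';

export default function VisualizationArea() {
  const viewMode = useFilterStore((state) => state.viewMode);
  if (viewMode === '3d') return <div>3D embedding view</div>;       // placeholder
  if (viewMode === 'network') return <div>Network graph view</div>; // placeholder
  return <div>2D scatter view</div>;                                // placeholder
}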
frontend/src/components/controls/ZoomSlider.tsx ADDED
@@ -0,0 +1,43 @@
+/**
+ * Slider control for zoom level in 3D visualization.
+ */
+import React from 'react';
+
+interface ZoomSliderProps {
+  value: number;
+  onChange: (value: number) => void;
+  min?: number;
+  max?: number;
+  step?: number;
+  disabled?: boolean;
+}
+
+export default function ZoomSlider({
+  value,
+  onChange,
+  min = 0.1,
+  max = 5,
+  step = 0.1,
+  disabled = false,
+}: ZoomSliderProps) {
+  return (
+    <div className="zoom-slider-container">
+      <label className="zoom-slider-label">
+        <span>Zoom Level</span>
+        <span className="zoom-value">{value.toFixed(1)}x</span>
+      </label>
+      <input
+        type="range"
+        min={min}
+        max={max}
+        step={step}
+        value={value}
+        onChange={(e) => onChange(parseFloat(e.target.value))}
+        disabled={disabled}
+        className="zoom-slider"
+        aria-label="Zoom level"
+      />
+    </div>
+  );
+}
+
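
Since ZoomSlider is fully controlled, the parent owns the zoom state. A minimal usage sketch; how the value feeds the camera is left to the caller and is an assumption here:

import React, { useState } from 'react';
import ZoomSlider from './components/controls/ZoomSlider';

function ViewportControls() {
  const [zoom, setZoom] = useState(1.0);
  // The parent would typically forward `zoom` to the 3D camera or a scale transform.
  return <ZoomSlider value={zoom} onChange={setZoom} min={0.5} max={3} step={0.1} />;
}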
frontend/src/components/layout/SearchBar.css ADDED
@@ -0,0 +1,181 @@
+.search-bar-container {
+  position: relative;
+  width: 100%;
+  max-width: 600px;
+  z-index: 1000;
+}
+
+.search-bar {
+  position: relative;
+  display: flex;
+  align-items: center;
+  background: white;
+  border: 2px solid #e0e0e0;
+  border-radius: 8px;
+  padding: 8px 12px;
+  transition: border-color 0.2s;
+}
+
+.search-bar:focus-within {
+  border-color: #4a90e2;
+  box-shadow: 0 0 0 3px rgba(74, 144, 226, 0.1);
+}
+
+.search-input {
+  flex: 1;
+  border: none;
+  outline: none;
+  font-size: 14px;
+  font-family: 'Instrument Sans', sans-serif;
+  color: #333;
+  background: transparent;
+}
+
+.search-input::placeholder {
+  color: #999;
+}
+
+.search-loading {
+  margin-left: 8px;
+  color: #4a90e2;
+  animation: spin 1s linear infinite;
+  font-size: 16px;
+}
+
+.search-clear {
+  margin-left: 8px;
+  background: none;
+  border: none;
+  color: #999;
+  cursor: pointer;
+  font-size: 20px;
+  line-height: 1;
+  padding: 0;
+  width: 20px;
+  height: 20px;
+  display: flex;
+  align-items: center;
+  justify-content: center;
+  transition: color 0.2s;
+}
+
+.search-clear:hover {
+  color: #333;
+}
+
+.search-results {
+  position: absolute;
+  top: 100%;
+  left: 0;
+  right: 0;
+  margin-top: 4px;
+  background: white;
+  border: 1px solid #e0e0e0;
+  border-radius: 8px;
+  box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
+  max-height: 400px;
+  overflow-y: auto;
+  z-index: 1001;
+}
+
+.search-result {
+  padding: 12px;
+  cursor: pointer;
+  border-bottom: 1px solid #f0f0f0;
+  transition: background-color 0.15s;
+}
+
+.search-result:last-child {
+  border-bottom: none;
+}
+
+.search-result:hover,
+.search-result.selected {
+  background-color: #f5f5f5;
+}
+
+.result-header {
+  display: flex;
+  align-items: center;
+  gap: 8px;
+  margin-bottom: 6px;
+}
+
+.result-model-id {
+  font-size: 14px;
+  font-weight: 600;
+  color: #333;
+  font-family: 'Instrument Sans', sans-serif;
+}
+
+.result-org {
+  font-size: 12px;
+  color: #666;
+  background: #f0f0f0;
+  padding: 2px 6px;
+  border-radius: 4px;
+}
+
+.result-meta {
+  display: flex;
+  flex-wrap: wrap;
+  gap: 6px;
+  margin-bottom: 4px;
+}
+
+.result-tag {
+  font-size: 11px;
+  color: #666;
+  background: #e8e8e8;
+  padding: 2px 6px;
+  border-radius: 3px;
+}
+
+.result-snippet {
+  font-size: 12px;
+  color: #666;
+  margin-top: 4px;
+  line-height: 1.4;
+}
+
+.result-snippet mark {
+  background: #fff3cd;
+  padding: 1px 2px;
+  border-radius: 2px;
+}
+
+.search-no-results {
+  padding: 16px;
+  text-align: center;
+  color: #999;
+  font-size: 14px;
+}
+
+@keyframes spin {
+  from {
+    transform: rotate(0deg);
+  }
+  to {
+    transform: rotate(360deg);
+  }
+}
+
+/* Scrollbar styling */
+.search-results::-webkit-scrollbar {
+  width: 8px;
+}
+
+.search-results::-webkit-scrollbar-track {
+  background: #f1f1f1;
+  border-radius: 4px;
+}
+
+.search-results::-webkit-scrollbar-thumb {
+  background: #c1c1c1;
+  border-radius: 4px;
+}
+
+.search-results::-webkit-scrollbar-thumb:hover {
+  background: #a8a8a8;
+}
+
frontend/src/components/layout/SearchBar.tsx ADDED
@@ -0,0 +1,201 @@
+/**
+ * Enhanced search bar with autocomplete and keyboard navigation.
+ * Integrates with filter store and triggers map zoom/modal open.
+ */
+import React, { useState, useEffect, useRef, useCallback } from 'react';
+import { useFilterStore } from '../../stores/filterStore';
+import './SearchBar.css';
+
+import { API_BASE } from '../../config/api';
+
+interface SearchResult {
+  model_id: string;
+  x: number;
+  y: number;
+  z: number;
+  org: string;
+  library?: string;
+  pipeline?: string;
+  license?: string;
+  snippet?: string;
+  match_score?: number;
+}
+
+interface SearchBarProps {
+  onSelect?: (result: SearchResult) => void;
+  onZoomTo?: (x: number, y: number, z: number) => void;
+}
+
+export default function SearchBar({ onSelect, onZoomTo }: SearchBarProps) {
+  const [query, setQuery] = useState('');
+  const [results, setResults] = useState<SearchResult[]>([]);
+  const [selectedIndex, setSelectedIndex] = useState(-1);
+  const [isOpen, setIsOpen] = useState(false);
+  const [isLoading, setIsLoading] = useState(false);
+  const inputRef = useRef<HTMLInputElement>(null);
+  const resultsRef = useRef<HTMLDivElement>(null);
+
+  const setSearchQuery = useFilterStore((state) => state.setSearchQuery);
+
+  // Debounced search
+  useEffect(() => {
+    if (query.length < 2) {
+      setResults([]);
+      setIsOpen(false);
+      return;
+    }
+
+    setIsLoading(true);
+    const timer = setTimeout(async () => {
+      try {
+        const response = await fetch(
+          `${API_BASE}/api/search?q=${encodeURIComponent(query)}&limit=20`
+        );
+        if (!response.ok) throw new Error('Search failed');
+        const data = await response.json();
+        setResults(data.results || []);
+        setIsOpen(true);
+        setSelectedIndex(-1);
+      } catch (err) {
+        console.error('Search error:', err);
+        setResults([]);
+      } finally {
+        setIsLoading(false);
+      }
+    }, 150);
+
+    return () => clearTimeout(timer);
+  }, [query]);
+
+  const handleSelect = useCallback((result: SearchResult) => {
+    setSearchQuery(result.model_id);
+
+    // Trigger zoom if coordinates available
+    if (onZoomTo && result.x !== undefined && result.y !== undefined) {
+      onZoomTo(result.x, result.y, result.z || 0);
+    }
+
+    // Trigger select callback
+    if (onSelect) {
+      onSelect(result);
+    }
+
+    setIsOpen(false);
+    setQuery('');
+    inputRef.current?.blur();
+  }, [onSelect, onZoomTo, setSearchQuery]);
+
+  const handleKeyDown = (e: React.KeyboardEvent) => {
+    if (!isOpen || results.length === 0) return;
+
+    if (e.key === 'ArrowDown') {
+      e.preventDefault();
+      // Compute the next index first so the scroll targets the newly
+      // selected row instead of the stale state value
+      const nextIndex = selectedIndex < results.length - 1 ? selectedIndex + 1 : selectedIndex;
+      setSelectedIndex(nextIndex);
+      if (resultsRef.current && nextIndex >= 0) {
+        const selectedElement = resultsRef.current.children[nextIndex] as HTMLElement;
+        selectedElement?.scrollIntoView({ block: 'nearest' });
+      }
+    } else if (e.key === 'ArrowUp') {
+      e.preventDefault();
+      setSelectedIndex(prev => prev > 0 ? prev - 1 : -1);
+    } else if (e.key === 'Enter') {
+      e.preventDefault();
+      if (selectedIndex >= 0 && results[selectedIndex]) {
+        handleSelect(results[selectedIndex]);
+      } else if (results.length > 0) {
+        handleSelect(results[0]);
+      }
+    } else if (e.key === 'Escape') {
+      setIsOpen(false);
+      inputRef.current?.blur();
+    }
+  };
+
+  const handleFocus = () => {
+    if (results.length > 0) {
+      setIsOpen(true);
+    }
+  };
+
+  const handleBlur = () => {
+    // Delay to allow click events on results
+    setTimeout(() => {
+      if (!resultsRef.current?.contains(document.activeElement)) {
+        setIsOpen(false);
+      }
+    }, 200);
+  };
+
+  return (
+    <div className="search-bar-container">
+      <div className="search-bar">
+        <input
+          ref={inputRef}
+          type="text"
+          value={query}
+          onChange={(e) => setQuery(e.target.value)}
+          onKeyDown={handleKeyDown}
+          onFocus={handleFocus}
+          onBlur={handleBlur}
+          placeholder="Search models, orgs, tasks, licenses..."
+          className="search-input"
+          aria-label="Search models"
+          aria-expanded={isOpen}
+          aria-haspopup="listbox"
+        />
+        {isLoading && <div className="search-loading">⟳</div>}
+        {query.length > 0 && !isLoading && (
+          <button
+            className="search-clear"
+            onClick={() => {
+              setQuery('');
+              setResults([]);
+              setIsOpen(false);
+            }}
+            aria-label="Clear search"
+          >
+            Γ—
+          </button>
+        )}
+      </div>
+      {isOpen && results.length > 0 && (
+        <div ref={resultsRef} className="search-results" role="listbox">
+          {results.map((result, idx) => (
+            <div
+              key={result.model_id}
+              className={`search-result ${idx === selectedIndex ? 'selected' : ''}`}
+              onClick={() => handleSelect(result)}
+              role="option"
+              aria-selected={idx === selectedIndex}
+            >
+              <div className="result-header">
+                <strong className="result-model-id">{result.model_id}</strong>
+                {result.org && <span className="result-org">{result.org}</span>}
+              </div>
+              <div className="result-meta">
+                {result.library && <span className="result-tag">{result.library}</span>}
+                {result.pipeline && <span className="result-tag">{result.pipeline}</span>}
+                {result.license && <span className="result-tag">{result.license}</span>}
+              </div>
+              {result.snippet && (
+                <div
+                  className="result-snippet"
+                  dangerouslySetInnerHTML={{ __html: result.snippet }}
+                />
+              )}
+            </div>
+          ))}
+        </div>
+      )}
+      {isOpen && query.length >= 2 && results.length === 0 && !isLoading && (
+        <div className="search-results">
+          <div className="search-no-results">No results found</div>
+        </div>
+      )}
+    </div>
+  );
+}
+
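
A usage sketch for SearchBar, wiring the two callbacks; the camera handler is a stand-in, and only the prop names come from this diff. Since `result.snippet` is injected with `dangerouslySetInnerHTML`, it is worth ensuring the backend escapes user content before adding highlight markup:

import React from 'react';
import SearchBar from './components/layout/SearchBar';

function Header({ flyTo }: { flyTo: (x: number, y: number, z: number) => void }) {
  return (
    <SearchBar
      onZoomTo={(x, y, z) => flyTo(x, y, z)} // e.g. animate the camera to the model's position
      onSelect={(result) => console.log('selected', result.model_id)}
    />
  );
}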
frontend/src/components/{FileTree.css β†’ modals/FileTree.css} RENAMED
@@ -3,8 +3,11 @@
   border: 1px solid #e0e0e0;
   border-radius: 4px;
   background: #fafafa;
-  max-height: 400px;
+  max-height: 600px;
   overflow-y: auto;
+  overflow-x: hidden;
+  display: flex;
+  flex-direction: column;
 }
 
 .file-tree-header {
@@ -16,6 +19,19 @@
   border-bottom: 1px solid #e0e0e0;
   font-size: 0.9rem;
   font-weight: 600;
+  flex-shrink: 0;
+  position: sticky;
+  top: 0;
+  z-index: 10;
+}
+
+.file-count-badge {
+  background: #e3f2fd;
+  color: #1976d2;
+  padding: 0.2rem 0.5rem;
+  border-radius: 12px;
+  font-size: 0.75rem;
+  font-weight: 500;
 }
 
 .file-tree-link {
@@ -23,16 +39,110 @@
   text-decoration: none;
   font-size: 0.85rem;
   font-weight: 400;
+  white-space: nowrap;
 }
 
 .file-tree-link:hover {
   text-decoration: underline;
 }
 
+.file-tree-button {
+  background: #f0f0f0;
+  border: 1px solid #d0d0d0;
+  border-radius: 3px;
+  padding: 0.25rem 0.5rem;
+  font-size: 0.75rem;
+  cursor: pointer;
+  color: #333;
+  font-family: 'Instrument Sans', sans-serif;
+  transition: background 0.15s;
+}
+
+.file-tree-button:hover {
+  background: #e0e0e0;
+}
+
+.file-tree-button:active {
+  background: #d0d0d0;
+}
+
+.file-tree-filters {
+  padding: 0.75rem 1rem;
+  background: #ffffff;
+  border-bottom: 1px solid #e0e0e0;
+  display: flex;
+  gap: 0.75rem;
+  flex-shrink: 0;
+  position: sticky;
+  top: 48px;
+  z-index: 9;
+}
+
+.file-tree-search {
+  flex: 1;
+  position: relative;
+  display: flex;
+  align-items: center;
+}
+
+.file-tree-search-input {
+  width: 100%;
+  padding: 0.5rem 2rem 0.5rem 0.75rem;
+  border: 1px solid #d0d0d0;
+  border-radius: 4px;
+  font-size: 0.85rem;
+  font-family: 'Instrument Sans', sans-serif;
+}
+
+.file-tree-search-input:focus {
+  outline: none;
+  border-color: #4a90e2;
+  box-shadow: 0 0 0 2px rgba(74, 144, 226, 0.1);
+}
+
+.file-tree-clear {
+  position: absolute;
+  right: 0.5rem;
+  background: none;
+  border: none;
+  cursor: pointer;
+  color: #666;
+  font-size: 1rem;
+  padding: 0.25rem;
+  display: flex;
+  align-items: center;
+  justify-content: center;
+  border-radius: 2px;
+}
+
+.file-tree-clear:hover {
+  background: #f0f0f0;
+  color: #1a1a1a;
+}
+
+.file-tree-type-filter {
+  padding: 0.5rem 0.75rem;
+  border: 1px solid #d0d0d0;
+  border-radius: 4px;
+  font-size: 0.85rem;
+  font-family: 'Instrument Sans', sans-serif;
+  background: white;
+  cursor: pointer;
+  min-width: 150px;
+}
+
+.file-tree-type-filter:focus {
+  outline: none;
+  border-color: #4a90e2;
+  box-shadow: 0 0 0 2px rgba(74, 144, 226, 0.1);
+}
+
 .file-tree {
   padding: 0.5rem;
   font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
   font-size: 0.85rem;
+  flex: 1;
+  overflow-y: auto;
 }
 
 .file-tree-node {
@@ -43,9 +153,14 @@
   display: flex;
   align-items: center;
   gap: 0.5rem;
-  padding: 0.25rem 0.5rem;
-  border-radius: 2px;
+  padding: 0.375rem 0.5rem;
+  border-radius: 3px;
   transition: background 0.15s;
+  user-select: none;
+}
+
+.file-tree-item.directory {
+  cursor: pointer;
 }
 
 .file-tree-item.directory:hover {
@@ -56,6 +171,37 @@
   background: #f0f0f0;
 }
 
+.file-actions {
+  display: flex;
+  gap: 0.25rem;
+  margin-left: auto;
+  opacity: 0;
+  transition: opacity 0.2s;
+}
+
+.file-tree-item:hover .file-actions {
+  opacity: 1;
+}
+
+.file-action-btn {
+  background: none;
+  border: none;
+  cursor: pointer;
+  font-size: 0.9rem;
+  padding: 0.25rem;
+  border-radius: 2px;
+  display: flex;
+  align-items: center;
+  justify-content: center;
+  transition: background 0.15s;
+  text-decoration: none;
+  color: inherit;
+}
+
+.file-action-btn:hover {
+  background: rgba(0, 0, 0, 0.1);
+}
+
 .file-icon {
   font-size: 1rem;
   width: 1.25rem;
@@ -83,6 +229,9 @@
 
 .file-tree-children {
   margin-left: 0.5rem;
+  border-left: 1px solid #e8e8e8;
+  padding-left: 0.5rem;
+  margin-top: 0.125rem;
 }
 
 .file-tree-loading,
@@ -98,3 +247,22 @@
   color: #d32f2f;
 }
 
+/* Scrollbar styling */
+.file-tree-container::-webkit-scrollbar {
+  width: 8px;
+}
+
+.file-tree-container::-webkit-scrollbar-track {
+  background: #f1f1f1;
+  border-radius: 4px;
+}
+
+.file-tree-container::-webkit-scrollbar-thumb {
+  background: #c1c1c1;
+  border-radius: 4px;
+}
+
+.file-tree-container::-webkit-scrollbar-thumb:hover {
+  background: #a8a8a8;
+}
+
frontend/src/components/{FileTree.tsx β†’ modals/FileTree.tsx} RENAMED
@@ -2,9 +2,12 @@
2
  * File tree component for displaying model file structure.
3
  * Fetches and displays files from Hugging Face model repository.
4
  */
5
- import React, { useState, useEffect } from 'react';
 
6
  import './FileTree.css';
7
 
 
 
8
  interface FileNode {
9
  path: string;
10
  type: 'file' | 'directory';
@@ -21,50 +24,84 @@ export default function FileTree({ modelId }: FileTreeProps) {
21
  const [loading, setLoading] = useState(true);
22
  const [error, setError] = useState<string | null>(null);
23
  const [expandedPaths, setExpandedPaths] = useState<Set<string>>(new Set());
 
 
 
 
24
 
25
  useEffect(() => {
 
 
 
 
 
 
26
  const fetchFiles = async () => {
27
  setLoading(true);
28
  setError(null);
29
  try {
30
- // Fetch file tree through our backend API (avoids CORS issues)
31
- // Use same API base as main app
32
- const apiBase = (window as any).__API_BASE__ || process.env.REACT_APP_API_URL || 'http://localhost:8000';
33
  const response = await fetch(
34
- `${apiBase}/api/model/${encodeURIComponent(modelId)}/files?branch=main`
35
  );
36
 
37
- if (!response.ok) {
38
  throw new Error('File tree not available for this model');
39
  }
 
 
 
 
 
 
 
 
 
40
 
41
  const data = await response.json();
42
 
 
 
 
 
43
  // Convert flat list to tree structure
44
  const tree = buildFileTree(data);
45
  setFiles(tree);
46
  } catch (err: any) {
47
- setError(err instanceof Error ? err.message : 'Failed to load files');
48
- console.error('Error fetching file tree:', err);
 
 
 
 
49
  } finally {
50
  setLoading(false);
51
  }
52
  };
53
 
54
- if (modelId) {
55
- fetchFiles();
56
- }
57
  }, [modelId]);
58
 
59
  const buildFileTree = (fileList: any[]): FileNode[] => {
 
 
 
 
60
  const tree: FileNode[] = [];
61
  const pathMap = new Map<string, FileNode>();
62
 
63
- // Sort files by path
64
- const sortedFiles = [...fileList].sort((a, b) => a.path.localeCompare(b.path));
 
 
 
 
65
 
66
  for (const file of sortedFiles) {
67
- const parts = file.path.split('/');
 
 
 
 
68
  let currentPath = '';
69
  let parent: FileNode | null = null;
70
 
@@ -77,7 +114,7 @@ export default function FileTree({ modelId }: FileTreeProps) {
77
  const node: FileNode = {
78
  path: currentPath,
79
  type: isDirectory ? 'directory' : 'file',
80
- size: file.size,
81
  children: isDirectory ? [] : undefined,
82
  };
83
 
@@ -111,6 +148,26 @@ export default function FileTree({ modelId }: FileTreeProps) {
111
  });
112
  };
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  const formatFileSize = (bytes?: number): string => {
115
  if (!bytes) return '';
116
  if (bytes < 1024) return `${bytes} B`;
@@ -119,6 +176,109 @@ export default function FileTree({ modelId }: FileTreeProps) {
119
  return `${(bytes / (1024 * 1024 * 1024)).toFixed(1)} GB`;
120
  };
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  const getFileIcon = (node: FileNode): string => {
123
  if (node.type === 'directory') {
124
  return expandedPaths.has(node.path) ? 'πŸ“‚' : 'πŸ“';
@@ -141,9 +301,28 @@ export default function FileTree({ modelId }: FileTreeProps) {
141
  return iconMap[ext || ''] || 'πŸ“„';
142
  };
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  const renderNode = (node: FileNode, depth: number = 0): React.ReactNode => {
145
  const isExpanded = expandedPaths.has(node.path);
146
  const hasChildren = node.children && node.children.length > 0;
 
147
 
148
  return (
149
  <div key={node.path} className="file-tree-node" style={{ paddingLeft: `${depth * 1.5}rem` }}>
@@ -153,13 +332,37 @@ export default function FileTree({ modelId }: FileTreeProps) {
153
  style={{ cursor: node.type === 'directory' ? 'pointer' : 'default' }}
154
  >
155
  <span className="file-icon">{getFileIcon(node)}</span>
156
- <span className="file-name">{node.path.split('/').pop()}</span>
157
  {node.type === 'file' && node.size && (
158
  <span className="file-size">{formatFileSize(node.size)}</span>
159
  )}
160
  {node.type === 'directory' && (
161
  <span className="file-expand">{isExpanded ? 'β–Ό' : 'β–Ά'}</span>
162
  )}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  </div>
164
  {isExpanded && hasChildren && (
165
  <div className="file-tree-children">
@@ -199,21 +402,106 @@ export default function FileTree({ modelId }: FileTreeProps) {
199
  );
200
  }
201
 
 
 
202
  return (
203
  <div className="file-tree-container">
204
  <div className="file-tree-header">
205
- <strong>Repository Files</strong>
206
- <a
207
- href={`https://huggingface.co/${modelId}/tree/main`}
208
- target="_blank"
209
- rel="noopener noreferrer"
210
- className="file-tree-link"
211
- >
212
- View on Hugging Face β†’
213
- </a>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  <div className="file-tree">
216
- {files.map((node) => renderNode(node))}
 
 
 
 
 
 
 
 
217
  </div>
218
  </div>
219
  );
 
2
  * File tree component for displaying model file structure.
3
  * Fetches and displays files from Hugging Face model repository.
4
  */
5
+ import React, { useState, useEffect, useMemo } from 'react';
6
+ import { getHuggingFaceFileTreeUrl } from '../../utils/api/hfUrl';
7
  import './FileTree.css';
8
 
9
+ import { API_BASE } from '../../config/api';
10
+
11
  interface FileNode {
12
  path: string;
13
  type: 'file' | 'directory';
 
   const [loading, setLoading] = useState(true);
   const [error, setError] = useState<string | null>(null);
   const [expandedPaths, setExpandedPaths] = useState<Set<string>>(new Set());
+  const [searchQuery, setSearchQuery] = useState('');
+  const [fileTypeFilter, setFileTypeFilter] = useState<string>('all');
+  const [showSearch, setShowSearch] = useState(false);
+  const searchInputRef = React.useRef<HTMLInputElement>(null);
 
   useEffect(() => {
+    if (!modelId) {
+      setLoading(false);
+      setError('No model ID provided');
+      return;
+    }
+
     const fetchFiles = async () => {
       setLoading(true);
       setError(null);
       try {
         const response = await fetch(
+          `${API_BASE}/api/model/${encodeURIComponent(modelId)}/files?branch=main`
         );
 
+        if (response.status === 404) {
           throw new Error('File tree not available for this model');
         }
+
+        if (response.status === 503) {
+          throw new Error('Backend service unavailable');
+        }
+
+        if (!response.ok) {
+          const errorText = await response.text();
+          throw new Error(`Failed to load file tree: ${response.status} ${errorText}`);
+        }
 
         const data = await response.json();
 
+        if (!Array.isArray(data)) {
+          throw new Error('Invalid response format');
+        }
+
         // Convert flat list to tree structure
         const tree = buildFileTree(data);
         setFiles(tree);
       } catch (err: any) {
+        const errorMessage = err instanceof Error ? err.message : 'Failed to load files';
+        setError(errorMessage);
+        // Only log in development
+        if (process.env.NODE_ENV === 'development') {
+          console.error('Error fetching file tree:', err);
+        }
       } finally {
         setLoading(false);
       }
     };
 
+    fetchFiles();
   }, [modelId]);
 
   const buildFileTree = (fileList: any[]): FileNode[] => {
+    if (!Array.isArray(fileList) || fileList.length === 0) {
+      return [];
+    }
+
     const tree: FileNode[] = [];
     const pathMap = new Map<string, FileNode>();
 
+    // Sort files by path for consistent ordering
+    const sortedFiles = [...fileList].sort((a, b) => {
+      const pathA = a.path || '';
+      const pathB = b.path || '';
+      return pathA.localeCompare(pathB);
+    });
 
     for (const file of sortedFiles) {
+      if (!file.path) continue;
+
+      const parts = file.path.split('/').filter((p: string) => p.length > 0);
+      if (parts.length === 0) continue;
+
       let currentPath = '';
       let parent: FileNode | null = null;
 
 
       const node: FileNode = {
         path: currentPath,
         type: isDirectory ? 'directory' : 'file',
+        size: isDirectory ? undefined : (file.size || undefined), // Only set size for files
         children: isDirectory ? [] : undefined,
       };
 
 
     });
   };
 
+  const expandAll = () => {
+    const allPaths = new Set<string>();
+    const collectPaths = (nodes: FileNode[]) => {
+      nodes.forEach(node => {
+        if (node.type === 'directory' && node.children) {
+          allPaths.add(node.path);
+          if (node.children.length > 0) {
+            collectPaths(node.children);
+          }
+        }
+      });
+    };
+    collectPaths(files);
+    setExpandedPaths(allPaths);
+  };
+
+  const collapseAll = () => {
+    setExpandedPaths(new Set());
+  };
+
   const formatFileSize = (bytes?: number): string => {
     if (!bytes) return '';
     if (bytes < 1024) return `${bytes} B`;
176
  return `${(bytes / (1024 * 1024 * 1024)).toFixed(1)} GB`;
177
  };
178
 
179
+ // Get all file extensions from the tree
180
+ const getAllFileExtensions = useMemo(() => {
181
+ const extensions = new Set<string>();
182
+ const collectExtensions = (nodes: FileNode[]) => {
183
+ nodes.forEach(node => {
184
+ if (node.type === 'file') {
185
+ const ext = node.path.split('.').pop()?.toLowerCase();
186
+ if (ext) extensions.add(ext);
187
+ }
188
+ if (node.children) {
189
+ collectExtensions(node.children);
190
+ }
191
+ });
192
+ };
193
+ collectExtensions(files);
194
+ return Array.from(extensions).sort();
195
+ }, [files]);
196
+
197
+ // Auto-expand directories when searching
198
+ useEffect(() => {
199
+ if (searchQuery) {
200
+ const pathsToExpand = new Set<string>();
201
+ const findMatchingPaths = (nodes: FileNode[], query: string) => {
202
+ nodes.forEach(node => {
203
+ if (node.path.toLowerCase().includes(query.toLowerCase())) {
204
+ // Expand all parent directories
205
+ const parts = node.path.split('/');
206
+ let currentPath = '';
207
+ for (let i = 0; i < parts.length - 1; i++) {
208
+ currentPath = currentPath ? `${currentPath}/${parts[i]}` : parts[i];
209
+ pathsToExpand.add(currentPath);
210
+ }
211
+ }
212
+ if (node.children) {
213
+ findMatchingPaths(node.children, query);
214
+ }
215
+ });
216
+ };
217
+ findMatchingPaths(files, searchQuery);
218
+ setExpandedPaths(pathsToExpand);
219
+ }
220
+ }, [searchQuery, files]);
221
+
222
+ // Filter files based on search and file type
223
+ const filterNodes = (nodes: FileNode[]): FileNode[] => {
224
+ return nodes
225
+ .map(node => {
226
+ const matchesSearch = !searchQuery ||
227
+ node.path.toLowerCase().includes(searchQuery.toLowerCase());
228
+ const matchesType = fileTypeFilter === 'all' ||
229
+ (node.type === 'file' && node.path.toLowerCase().endsWith(`.${fileTypeFilter}`)) ||
230
+ (node.type === 'directory');
231
+
232
+ if (!matchesSearch || !matchesType) {
233
+ return null;
234
+ }
235
+
236
+ const filteredChildren = node.children ? filterNodes(node.children) : undefined;
237
+ const result: FileNode | null = filteredChildren && filteredChildren.length > 0
238
+ ? { ...node, children: filteredChildren }
239
+ : filteredChildren === undefined && matchesSearch && matchesType
240
+ ? { ...node }
241
+ : null;
242
+ return result;
243
+ })
244
+ .filter((node): node is FileNode => node !== null);
245
+ };
246
+
247
+ const filteredFiles = useMemo(() => {
248
+ if (!searchQuery && fileTypeFilter === 'all') return files;
249
+ return filterNodes(files);
250
+ }, [files, searchQuery, fileTypeFilter]);
251
+
252
+ // Count total files
253
+ const countFiles = (nodes: FileNode[]): number => {
254
+ let count = 0;
255
+ nodes.forEach(node => {
256
+ if (node.type === 'file') count++;
257
+ if (node.children) count += countFiles(node.children);
258
+ });
259
+ return count;
260
+ };
261
+
262
+ const totalFileCount = useMemo(() => countFiles(files), [files]);
263
+ const visibleFileCount = useMemo(() => countFiles(filteredFiles), [filteredFiles]);
264
+
265
+ // Keyboard shortcut for search (Cmd+K / Ctrl+K)
266
+ useEffect(() => {
267
+ const handleKeyDown = (e: KeyboardEvent) => {
268
+ if ((e.metaKey || e.ctrlKey) && e.key === 'k') {
269
+ e.preventDefault();
270
+ setShowSearch(true);
271
+ setTimeout(() => searchInputRef.current?.focus(), 0);
272
+ }
273
+ if (e.key === 'Escape' && showSearch) {
274
+ setShowSearch(false);
275
+ setSearchQuery('');
276
+ }
277
+ };
278
+ window.addEventListener('keydown', handleKeyDown);
279
+ return () => window.removeEventListener('keydown', handleKeyDown);
280
+ }, [showSearch]);
281
+
282
  const getFileIcon = (node: FileNode): string => {
283
  if (node.type === 'directory') {
284
  return expandedPaths.has(node.path) ? 'πŸ“‚' : 'πŸ“';
 
     return iconMap[ext || ''] || '📄';
   };
 
+  const copyFilePath = (path: string) => {
+    navigator.clipboard.writeText(path).then(() => {
+      // Show temporary feedback
+      const button = document.querySelector(`[data-file-path="${path}"]`) as HTMLElement;
+      if (button) {
+        const originalText = button.textContent;
+        button.textContent = 'Copied!';
+        setTimeout(() => {
+          if (button) button.textContent = originalText;
+        }, 1000);
+      }
+    });
+  };
+
+  const getFileUrl = (path: string) => {
+    return `https://huggingface.co/${modelId}/resolve/main/${path}`;
+  };
+
   const renderNode = (node: FileNode, depth: number = 0): React.ReactNode => {
     const isExpanded = expandedPaths.has(node.path);
     const hasChildren = node.children && node.children.length > 0;
+    const fileName = node.path.split('/').pop() || node.path;
 
     return (
       <div key={node.path} className="file-tree-node" style={{ paddingLeft: `${depth * 1.5}rem` }}>
 
           style={{ cursor: node.type === 'directory' ? 'pointer' : 'default' }}
         >
           <span className="file-icon">{getFileIcon(node)}</span>
+          <span className="file-name" title={node.path}>{fileName}</span>
           {node.type === 'file' && node.size && (
             <span className="file-size">{formatFileSize(node.size)}</span>
           )}
           {node.type === 'directory' && (
             <span className="file-expand">{isExpanded ? '▼' : '▶'}</span>
           )}
+          {node.type === 'file' && (
+            <div className="file-actions" onClick={(e) => e.stopPropagation()}>
+              <button
+                className="file-action-btn"
+                onClick={() => copyFilePath(node.path)}
+                data-file-path={node.path}
+                title="Copy file path"
+                aria-label="Copy path"
+              >
+                📋
+              </button>
+              <a
+                href={getFileUrl(node.path)}
+                target="_blank"
+                rel="noopener noreferrer"
+                className="file-action-btn"
+                title="Download file"
+                aria-label="Download"
+                onClick={(e) => e.stopPropagation()}
+              >
+                ⬇️
+              </a>
+            </div>
+          )}
         </div>
         {isExpanded && hasChildren && (
           <div className="file-tree-children">
 
     );
   }
 
+  const hasDirectories = files.some(node => node.type === 'directory');
+
   return (
     <div className="file-tree-container">
       <div className="file-tree-header">
+        <div style={{ display: 'flex', alignItems: 'center', gap: '0.5rem' }}>
+          <strong>Repository Files</strong>
+          <span className="file-count-badge">
+            {visibleFileCount === totalFileCount
+              ? `${totalFileCount} file${totalFileCount !== 1 ? 's' : ''}`
+              : `${visibleFileCount} of ${totalFileCount} files`}
+          </span>
+        </div>
+        <div style={{ display: 'flex', gap: '0.5rem', alignItems: 'center', flexWrap: 'wrap' }}>
+          <button
+            onClick={() => setShowSearch(!showSearch)}
+            className="file-tree-button"
+            title="Search files (Cmd+K)"
+            aria-label="Search"
+          >
+            🔍 Search
+          </button>
+          {hasDirectories && (
+            <>
+              <button
+                onClick={expandAll}
+                className="file-tree-button"
+                title="Expand all directories"
+                aria-label="Expand all"
+              >
+                Expand All
+              </button>
+              <button
+                onClick={collapseAll}
+                className="file-tree-button"
+                title="Collapse all directories"
+                aria-label="Collapse all"
+              >
+                Collapse All
+              </button>
+            </>
+          )}
+          <a
+            href={getHuggingFaceFileTreeUrl(modelId, 'main')}
+            target="_blank"
+            rel="noopener noreferrer"
+            className="file-tree-link"
+          >
+            View on HF →
+          </a>
+        </div>
       </div>
+
+      {/* Search and Filter Bar */}
+      {(showSearch || searchQuery || fileTypeFilter !== 'all') && (
+        <div className="file-tree-filters">
+          <div className="file-tree-search">
+            <input
+              ref={searchInputRef}
+              type="text"
+              placeholder="Search files... (Cmd+K)"
+              value={searchQuery}
+              onChange={(e) => setSearchQuery(e.target.value)}
+              className="file-tree-search-input"
+            />
+            {searchQuery && (
+              <button
+                onClick={() => setSearchQuery('')}
+                className="file-tree-clear"
+                aria-label="Clear search"
+              >
+                ✕
+              </button>
+            )}
+          </div>
+          {getAllFileExtensions.length > 0 && (
+            <select
+              value={fileTypeFilter}
+              onChange={(e) => setFileTypeFilter(e.target.value)}
+              className="file-tree-type-filter"
+            >
+              <option value="all">All file types</option>
+              {getAllFileExtensions.map(ext => (
+                <option key={ext} value={ext}>.{ext}</option>
+              ))}
+            </select>
+          )}
+        </div>
+      )}
+
       <div className="file-tree">
+        {filteredFiles.length === 0 ? (
+          <div className="file-tree-empty">
+            {searchQuery || fileTypeFilter !== 'all'
+              ? 'No files match your filters'
+              : 'No files found'}
+          </div>
+        ) : (
+          filteredFiles.map((node) => renderNode(node))
+        )}
       </div>
     </div>
   );
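
The heart of this component is `buildFileTree`, which folds the flat path list returned by the `/api/model/{id}/files` endpoint into nested `FileNode`s. Part of the loop body is elided from the diff, so the following is a minimal standalone sketch of the same idea rather than the component's exact code; the sample input is invented:

```typescript
interface FileNode {
  path: string;
  type: 'file' | 'directory';
  size?: number;
  children?: FileNode[];
}

// Walk each path segment, creating directory nodes on the way down.
// A Map from full path to node makes repeated-prefix lookups cheap.
function buildFileTree(fileList: { path: string; size?: number }[]): FileNode[] {
  const tree: FileNode[] = [];
  const pathMap = new Map<string, FileNode>();
  const sorted = [...fileList].sort((a, b) => a.path.localeCompare(b.path));

  for (const file of sorted) {
    const parts = file.path.split('/').filter((p) => p.length > 0);
    let currentPath = '';
    let parent: FileNode | null = null;

    parts.forEach((part, i) => {
      currentPath = currentPath ? `${currentPath}/${part}` : part;
      const isDirectory = i < parts.length - 1;
      let node = pathMap.get(currentPath);
      if (!node) {
        node = {
          path: currentPath,
          type: isDirectory ? 'directory' : 'file',
          size: isDirectory ? undefined : file.size,
          children: isDirectory ? [] : undefined,
        };
        pathMap.set(currentPath, node);
        (parent?.children ?? tree).push(node); // top-level nodes go straight into the tree
      }
      parent = node;
    });
  }
  return tree;
}

// Example: 'onnx/model.onnx' yields a directory node 'onnx' with one file
// child; 'config.json' stays at the top level.
console.log(buildFileTree([
  { path: 'config.json', size: 512 },
  { path: 'onnx/model.onnx', size: 1048576 },
]));
```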
frontend/src/components/{ModelModal.css β†’ modals/ModelModal.css} RENAMED
@@ -25,7 +25,7 @@
 .modal-content {
   background: #ffffff;
   border-radius: 8px;
-  max-width: 800px;
   width: 100%;
   max-height: 90vh;
   overflow-y: auto;
@@ -34,11 +34,16 @@
   box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
   border: 1px solid #d0d0d0;
   animation: slideUp 0.3s ease-out;
-  font-family: 'Vend Sans', sans-serif;
   display: flex;
   flex-direction: column;
 }
 
 @keyframes slideUp {
   from {
     transform: translateY(20px);
@@ -66,7 +71,7 @@
   justify-content: center;
   border-radius: 2px;
   transition: all 0.2s;
-  font-family: 'Vend Sans', sans-serif;
 }
 
 .modal-close:hover {
@@ -89,7 +94,7 @@
   font-size: 1.5rem;
   color: #1a1a1a;
   word-break: break-word;
-  font-family: 'Vend Sans', sans-serif;
   font-weight: 600;
   line-height: 1.3;
 }
@@ -117,7 +122,7 @@
   border-radius: 4px;
   cursor: pointer;
   font-size: 0.85rem;
-  font-family: 'Vend Sans', sans-serif;
   transition: all 0.2s;
   font-weight: 500;
 }
@@ -136,6 +141,12 @@
   gap: 0.5rem;
   margin-bottom: 1.5rem;
   border-bottom: 2px solid #e0e0e0;
 }
 
 .modal-tab {
@@ -145,11 +156,29 @@
   border-bottom: 2px solid transparent;
   cursor: pointer;
   font-size: 0.9rem;
-  font-family: 'Vend Sans', sans-serif;
   color: #666;
   font-weight: 500;
   margin-bottom: -2px;
   transition: all 0.2s;
 }
 
 .modal-tab:hover {
@@ -186,14 +215,14 @@
   text-transform: uppercase;
   letter-spacing: 0.5px;
   font-weight: 600;
-  font-family: 'Vend Sans', sans-serif;
 }
 
 .info-value {
   font-size: 1.1rem;
   color: #1a1a1a;
   font-weight: 500;
-  font-family: 'Vend Sans', sans-serif;
 }
 
 .info-value.highlight {
@@ -230,7 +259,7 @@
   letter-spacing: 0.5px;
   font-weight: 600;
   margin-bottom: 0.75rem;
-  font-family: 'Vend Sans', sans-serif;
 }
 
 .section-content {
@@ -287,7 +316,7 @@
   color: #4a4a4a;
   text-transform: uppercase;
   letter-spacing: 0.5px;
-  font-family: 'Vend Sans', sans-serif;
 }
 
 .modal-info-grid {
@@ -306,14 +335,14 @@
   font-size: 0.875rem;
   color: #6a6a6a;
   font-weight: 500;
-  font-family: 'Vend Sans', sans-serif;
 }
 
 .modal-info-item span {
   font-size: 1rem;
   color: #1a1a1a;
   font-weight: 500;
-  font-family: 'Vend Sans', sans-serif;
 }
 
 .modal-tags {
@@ -324,7 +353,7 @@
   color: #1a1a1a;
   font-size: 0.9rem;
   line-height: 1.5;
-  font-family: 'Vend Sans', sans-serif;
 }
 
 .modal-footer {
@@ -345,7 +374,7 @@
   text-decoration: none;
   border-radius: 4px;
   font-weight: 500;
-  font-family: 'Vend Sans', sans-serif;
   transition: all 0.2s;
   border: 1px solid #1a1a1a;
 }
 
 .modal-content {
   background: #ffffff;
   border-radius: 8px;
+  max-width: 900px;
   width: 100%;
   max-height: 90vh;
   overflow-y: auto;
 
   box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
   border: 1px solid #d0d0d0;
   animation: slideUp 0.3s ease-out;
+  font-family: 'Instrument Sans', sans-serif;
   display: flex;
   flex-direction: column;
 }
 
+.modal-content[data-tab="files"] {
+  max-width: 1000px;
+  max-height: 95vh;
+}
+
 @keyframes slideUp {
   from {
     transform: translateY(20px);
 
   justify-content: center;
   border-radius: 2px;
   transition: all 0.2s;
+  font-family: 'Instrument Sans', sans-serif;
 }
 
 .modal-close:hover {
 
   font-size: 1.5rem;
   color: #1a1a1a;
   word-break: break-word;
+  font-family: 'Instrument Sans', sans-serif;
   font-weight: 600;
   line-height: 1.3;
 }
 
   border-radius: 4px;
   cursor: pointer;
   font-size: 0.85rem;
+  font-family: 'Instrument Sans', sans-serif;
   transition: all 0.2s;
   font-weight: 500;
 }
 
   gap: 0.5rem;
   margin-bottom: 1.5rem;
   border-bottom: 2px solid #e0e0e0;
+  position: sticky;
+  top: 0;
+  background: #ffffff;
+  z-index: 10;
+  padding-top: 0.5rem;
+  margin-top: -0.5rem;
 }
 
 .modal-tab {
 
   border-bottom: 2px solid transparent;
   cursor: pointer;
   font-size: 0.9rem;
+  font-family: 'Instrument Sans', sans-serif;
   color: #666;
   font-weight: 500;
   margin-bottom: -2px;
   transition: all 0.2s;
+  display: flex;
+  align-items: center;
+  gap: 0.5rem;
+  position: relative;
+}
+
+.tab-icon {
+  font-size: 1rem;
+}
+
+.tab-badge {
+  background: #4a90e2;
+  color: white;
+  font-size: 0.7rem;
+  padding: 0.15rem 0.4rem;
+  border-radius: 10px;
+  font-weight: 600;
+  margin-left: 0.25rem;
 }
 
 .modal-tab:hover {
 
   text-transform: uppercase;
   letter-spacing: 0.5px;
   font-weight: 600;
+  font-family: 'Instrument Sans', sans-serif;
 }
 
 .info-value {
   font-size: 1.1rem;
   color: #1a1a1a;
   font-weight: 500;
+  font-family: 'Instrument Sans', sans-serif;
 }
 
 .info-value.highlight {
 
   letter-spacing: 0.5px;
   font-weight: 600;
   margin-bottom: 0.75rem;
+  font-family: 'Instrument Sans', sans-serif;
 }
 
 .section-content {
 
   color: #4a4a4a;
   text-transform: uppercase;
   letter-spacing: 0.5px;
+  font-family: 'Instrument Sans', sans-serif;
 }
 
 .modal-info-grid {
 
   font-size: 0.875rem;
   color: #6a6a6a;
   font-weight: 500;
+  font-family: 'Instrument Sans', sans-serif;
 }
 
 .modal-info-item span {
   font-size: 1rem;
   color: #1a1a1a;
   font-weight: 500;
+  font-family: 'Instrument Sans', sans-serif;
 }
 
 .modal-tags {
 
   color: #1a1a1a;
   font-size: 0.9rem;
   line-height: 1.5;
+  font-family: 'Instrument Sans', sans-serif;
 }
 
 .modal-footer {
 
   text-decoration: none;
   border-radius: 4px;
   font-weight: 500;
+  font-family: 'Instrument Sans', sans-serif;
   transition: all 0.2s;
   border: 1px solid #1a1a1a;
 }
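
The new `.modal-content[data-tab="files"]` rule pairs with the `data-tab={activeTab}` attribute added in ModelModal.tsx below: the modal widens for the Files tab purely in CSS, keyed off a data attribute instead of inline style computation. A minimal sketch of the pattern (the component and class names here are illustrative, not the app's):

```tsx
import React, { useState } from 'react';

// CSS can then target `.panel[data-tab="files"] { max-width: 1000px; }`
// without any JavaScript measuring or style mutation.
export function TabbedPanel() {
  const [activeTab, setActiveTab] = useState<'details' | 'files'>('details');
  return (
    <div className="panel" data-tab={activeTab}>
      <button onClick={() => setActiveTab('details')}>Details</button>
      <button onClick={() => setActiveTab('files')}>Files</button>
    </div>
  );
}
```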
frontend/src/components/{ModelModal.tsx β†’ modals/ModelModal.tsx} RENAMED
@@ -3,12 +3,12 @@
  * Enhanced with bookmark, comparison, similar models, and file tree features.
  */
 import React, { useState, useEffect } from 'react';
-import { ModelPoint } from '../types';
 import FileTree from './FileTree';
 import './ModelModal.css';
 
-const API_BASE = process.env.REACT_APP_API_URL || 'http://localhost:8000';
-
 interface ArxivPaper {
   arxiv_id: string;
   title: string;
@@ -71,7 +71,7 @@ export default function ModelModal({
 
   if (!isOpen || !model) return null;
 
-  const hfUrl = `https://huggingface.co/${model.model_id}`;
 
   // Parse tags if it's a string representation of an array
   const parseTags = (tags: string | null | undefined): string[] => {
@@ -156,7 +156,11 @@ export default function ModelModal({
 
   return (
     <div className="modal-overlay" onClick={onClose}>
-      <div className="modal-content" onClick={(e) => e.stopPropagation()}>
       <div className="modal-header">
         <h2>{model.model_id}</h2>
         <button className="modal-close" onClick={onClose}>Close</button>
@@ -197,20 +201,24 @@ export default function ModelModal({
           className={`modal-tab ${activeTab === 'details' ? 'active' : ''}`}
           onClick={() => setActiveTab('details')}
         >
-          Details
         </button>
         <button
           className={`modal-tab ${activeTab === 'files' ? 'active' : ''}`}
           onClick={() => setActiveTab('files')}
         >
-          Files
         </button>
         {(papers.length > 0 || papersLoading) && (
           <button
             className={`modal-tab ${activeTab === 'papers' ? 'active' : ''}`}
             onClick={() => setActiveTab('papers')}
           >
-            Papers {papers.length > 0 && `(${papers.length})`}
           </button>
         )}
       </div>
@@ -283,7 +291,7 @@ export default function ModelModal({
             <div className="section-title">Parent Model</div>
             <div className="section-content">
               <a
-                href={`https://huggingface.co/${model.parent_model}`}
                 target="_blank"
                 rel="noopener noreferrer"
                 className="model-link"
 
  * Enhanced with bookmark, comparison, similar models, and file tree features.
  */
 import React, { useState, useEffect } from 'react';
+import { ModelPoint } from '../../types';
 import FileTree from './FileTree';
+import { getHuggingFaceUrl } from '../../utils/api/hfUrl';
+import { API_BASE } from '../../config/api';
 import './ModelModal.css';
 
 interface ArxivPaper {
   arxiv_id: string;
   title: string;
 
   if (!isOpen || !model) return null;
 
+  const hfUrl = getHuggingFaceUrl(model.model_id);
 
   // Parse tags if it's a string representation of an array
   const parseTags = (tags: string | null | undefined): string[] => {
 
   return (
     <div className="modal-overlay" onClick={onClose}>
+      <div
+        className="modal-content"
+        onClick={(e) => e.stopPropagation()}
+        data-tab={activeTab}
+      >
       <div className="modal-header">
         <h2>{model.model_id}</h2>
         <button className="modal-close" onClick={onClose}>Close</button>
 
           className={`modal-tab ${activeTab === 'details' ? 'active' : ''}`}
           onClick={() => setActiveTab('details')}
         >
+          <span className="tab-icon">📋</span>
+          <span>Details</span>
         </button>
         <button
           className={`modal-tab ${activeTab === 'files' ? 'active' : ''}`}
           onClick={() => setActiveTab('files')}
         >
+          <span className="tab-icon">📁</span>
+          <span>Files</span>
         </button>
         {(papers.length > 0 || papersLoading) && (
           <button
             className={`modal-tab ${activeTab === 'papers' ? 'active' : ''}`}
             onClick={() => setActiveTab('papers')}
           >
+            <span className="tab-icon">📄</span>
+            <span>Papers</span>
+            {papers.length > 0 && <span className="tab-badge">{papers.length}</span>}
           </button>
         )}
       </div>
 
             <div className="section-title">Parent Model</div>
             <div className="section-content">
               <a
+                href={getHuggingFaceUrl(model.parent_model)}
                 target="_blank"
                 rel="noopener noreferrer"
                 className="model-link"
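
`getHuggingFaceUrl` and `getHuggingFaceFileTreeUrl` replace the inline `https://huggingface.co/...` template strings that previously lived in each component. The helper module (`utils/api/hfUrl`) is not part of this diff, so the following is only a plausible shape inferred from the call sites:

```typescript
// Hypothetical sketch of frontend/src/utils/api/hfUrl.ts;
// the real implementation is not shown in this commit.
export function getHuggingFaceUrl(modelId: string): string {
  return `https://huggingface.co/${modelId}`;
}

export function getHuggingFaceFileTreeUrl(modelId: string, branch = 'main'): string {
  return `https://huggingface.co/${modelId}/tree/${encodeURIComponent(branch)}`;
}
```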
frontend/src/components/{ColorLegend.css β†’ ui/ColorLegend.css} RENAMED
File without changes
frontend/src/components/{ColorLegend.tsx β†’ ui/ColorLegend.tsx} RENAMED
@@ -3,7 +3,7 @@
  * Shows color mappings for categorical and continuous data.
  */
 import React from 'react';
-import { getCategoricalColorMap, getContinuousColorScale } from '../utils/colors';
 import './ColorLegend.css';
 
 interface ColorLegendProps {
 
  * Shows color mappings for categorical and continuous data.
  */
 import React from 'react';
+import { getCategoricalColorMap, getContinuousColorScale } from '../../utils/rendering/colors';
 import './ColorLegend.css';
 
 interface ColorLegendProps {
frontend/src/components/{ErrorBoundary.tsx β†’ ui/ErrorBoundary.tsx} RENAMED
File without changes