"""Download the embedding model and LLM from Hugging Face Hub into local cache directories."""
# scripts/download_embedding_model.py
import os
from huggingface_hub import snapshot_download
# --- Configuration ---
# Read the HF token from environment variables, essential for private models.
# Falls back to None so the public embedding model can still be fetched without auth.
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
# Central Hugging Face cache directory within the container.
# HF_HOME is the standard override honored by huggingface_hub itself.
HF_CACHE_DIR = os.getenv("HF_HOME", "/home/user/.cache/huggingface")
# Model repositories
EMBEDDING_MODEL_REPO = "abhinand/MedEmbed-large-v0.1"  # public embedding model
LLM_REPO = "MedAI-COS30018/medalpaca-merge"  # private LLM; requires a token
# Local directories where the final models will be stored for the app to use.
EMBEDDING_MODEL_DIR = "/app/embedding_model_cache"
LLM_DIR = "/app/llm_cache"
def download_model(repo_id: str, local_dir: str, token: str | None = None):
    """
    Downloads a model from Hugging Face Hub to a specified local directory.

    Files are downloaded through the shared HF cache (HF_CACHE_DIR), then the
    final model files are placed in ``local_dir`` for the application to use.

    Args:
        repo_id: The repository ID of the model on Hugging Face.
        local_dir: The application-specific directory to copy the model to.
        token: A Hugging Face token, required for private models.
    """
    print(f"Downloading model: {repo_id}")
    # exist_ok=True avoids the check-then-create race of the
    # `if not os.path.exists(...)` pattern and also creates parents.
    os.makedirs(local_dir, exist_ok=True)
    snapshot_download(
        repo_id=repo_id,
        cache_dir=HF_CACHE_DIR,  # Use the central HF cache for downloads
        local_dir=local_dir,  # Copy final model files here
        token=token,
    )
    print(f"Model '{repo_id}' downloaded to '{local_dir}'")
if __name__ == "__main__":
    # The embedding model is public, so it is fetched unconditionally.
    download_model(EMBEDDING_MODEL_REPO, EMBEDDING_MODEL_DIR)

    # The LLM is private: only attempt it when a token was provided.
    if not HUGGING_FACE_TOKEN:
        print("HUGGING_FACE_HUB_TOKEN environment variable NOT found.")
    else:
        print("HUGGING_FACE_HUB_TOKEN environment variable found.")
        # WARNING: Do NOT print the full token. Just print the first few chars to confirm it's loaded.
        print(f" Token starts with: '{HUGGING_FACE_TOKEN[:4]}...'")
        download_model(LLM_REPO, LLM_DIR, token=HUGGING_FACE_TOKEN)