Spaces:
Running
Running
| from fastapi import FastAPI, HTTPException, Query | |
| import uvicorn | |
| from pydantic import BaseModel | |
| import requests | |
| from bs4 import BeautifulSoup as bs | |
| import mysql.connector | |
| import os | |
| import google.generativeai as genai | |
| import json | |
| from util.keywordExtract import * | |
| from typing import Optional,List, Dict, Any | |
| import pandas as pd | |
| import torch | |
| import pandas as pd | |
| from io import StringIO # pandas.read_html에 문자열을 전달할 때 필요 | |
| import logging # 로깅을 위해 추가 | |
| import time # 요청 간 지연을 위해 추가 (선택 사항이지만 권장) | |
| from embedding_module import embed_keywords | |
| from keyword_module import summarize_kobart as summarize, extract_keywords | |
| from pykrx import stock | |
| from functools import lru_cache | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import traceback | |
| from datetime import datetime, timedelta | |
| from googletrans import Translator | |
| from starlette.concurrency import run_in_threadpool | |
| import FinanceDataReader as fdr | |
| app = FastAPI() | |
| # 로깅 설정 | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| API_KEY = os.getenv("GEMINI_API_KEY") | |
| if not API_KEY: | |
| # API 키가 없으면 에러를 발생시키거나 경고 | |
| print("❌ GEMINI_API_KEY 환경 변수가 설정되지 않았습니다.") | |
| else: | |
| genai.configure(api_key=API_KEY) | |
| logger.info("✅ Gemini API 설정 완료 (환경 변수 사용)") | |
| class NewsRequest(BaseModel): | |
| url: str | |
| id: Optional[str] = None | |
| # 🧠 학습 모델 구조 정의 | |
| class SimpleClassifier(torch.nn.Module): | |
| def __init__(self, input_dim): | |
| super().__init__() | |
| self.net = torch.nn.Sequential( | |
| torch.nn.Linear(input_dim, 64), | |
| torch.nn.ReLU(), | |
| torch.nn.Linear(64, 1), | |
| torch.nn.Sigmoid() | |
| ) | |
| def forward(self, x): | |
| return self.net(x) | |
| def fetch_html(url): | |
| headers = {"User-Agent": "Mozilla/5.0"} | |
| response = requests.get(url, headers=headers, timeout=5) | |
| response.raise_for_status() | |
| return bs(response.text, "html.parser") | |
| def parse_naver(soup): | |
| title = soup.select_one("h2.media_end_head_headline") or soup.title | |
| title_text = title.get_text(strip=True) if title else "제목 없음" | |
| time_tag = soup.select_one("span.media_end_head_info_datestamp_time") | |
| time_text = time_tag.get_text(strip=True) if time_tag else "시간 없음" | |
| content_area = soup.find("div", {"id": "newsct_article"}) or soup.find("div", {"id": "dic_area"}) | |
| if content_area: | |
| paragraphs = content_area.find_all("p") | |
| content = '\n'.join([p.get_text(strip=True) for p in paragraphs]) if paragraphs else content_area.get_text(strip=True) | |
| else: | |
| content = "본문 없음" | |
| return title_text, time_text, content | |
| def parse_daum(soup): | |
| title = soup.select_one("h3.tit_view") or soup.title | |
| title_text = title.get_text(strip=True) if title else "제목 없음" | |
| time_tag = soup.select_one("span.num_date") | |
| time_text = time_tag.get_text(strip=True) if time_tag else "시간 없음" | |
| content_area = soup.find("div", {"class": "article_view"}) | |
| if content_area: | |
| paragraphs = content_area.find_all("p") | |
| content = '\n'.join([p.get_text(strip=True) for p in paragraphs]) if paragraphs else content_area.get_text(strip=True) | |
| else: | |
| content = "본문 없음" | |
| return title_text, time_text, content | |
| def extract_thumbnail(soup): | |
| tag = soup.find("meta", property="og:image") | |
| return tag["content"] if tag and "content" in tag.attrs else None | |
| def gemini_use(resultK): | |
| generation_config = genai.GenerationConfig( | |
| temperature=1, | |
| response_mime_type=None # 그냥 문자열로 응답받기 | |
| ) | |
| model = genai.GenerativeModel('gemini-2.0-flash', generation_config=generation_config) | |
| prompt = f""" | |
| 아래 내용을 참고해서 가장 연관성이 높은 주식 상장 회사 이름 하나만 말해줘. | |
| 다른 설명 없이 회사 이름만 대답해. | |
| "{resultK}" | |
| """ | |
| response = model.generate_content(prompt) | |
| try: | |
| result_text = response.text.strip() | |
| except AttributeError: | |
| result_text = response.candidates[0].content.parts[0].text.strip() | |
| return result_text | |
| def parse_news(req: NewsRequest): | |
| url = req.url.strip() | |
| username = req.id.strip() if req.id else None | |
| try: | |
| soup = fetch_html(url) | |
| if "naver.com" in url: | |
| title, time, content = parse_naver(soup) | |
| elif "daum.net" in url: | |
| title, time, content = parse_daum(soup) | |
| else: | |
| raise HTTPException(status_code=400, detail="지원하지 않는 뉴스 사이트입니다.") | |
| thumbnail_url = extract_thumbnail(soup) | |
| resultK = resultKeyword(content) | |
| sumce = classify_emotion(content) | |
| targetCompany = gemini_use(resultK) | |
| sentiment = analyze_sentiment(content) | |
| pos_percent = int(sentiment["positive"] * 100) | |
| neg_percent = int(sentiment["negative"] * 100) | |
| sentiment_result = { | |
| "positive": pos_percent, | |
| "negative": neg_percent | |
| } | |
| summary = summarize(content) | |
| print(summary) | |
| _, keywords_2nd = extract_keywords(summary) | |
| clean_keywords = [kw for kw, _ in keywords_2nd] | |
| keyword_vec = embed_keywords(clean_keywords) | |
| input_vec = torch.tensor(keyword_vec, dtype=torch.float32).unsqueeze(0) # (1, D) | |
| input_dim = input_vec.shape[1] | |
| model = SimpleClassifier(input_dim) | |
| model.load_state_dict(torch.load("news_model.pt", map_location="cpu")) | |
| model.eval() | |
| with torch.no_grad(): | |
| prob = model(input_vec).item() | |
| prediction = int(prob >= 0.5) | |
| prediction = '📈 상승 (1)' if prediction == 1 else '📉 하락 (0)' | |
| print(type(prob)) | |
| print(type(prediction)) | |
| return { | |
| "message": "뉴스 파싱 및 저장 완료", | |
| "title": title, | |
| "time": time, | |
| "content": content, | |
| "thumbnail_url": thumbnail_url, | |
| "url": url, | |
| "summary": resultK["summary"], | |
| "keyword": resultK["keyword"], | |
| "company": targetCompany, | |
| "sentiment": sumce, | |
| "sentiment_value": sentiment_result, | |
| "prediction": prediction, | |
| "prob": prob, | |
| } | |
| except requests.exceptions.RequestException as e: | |
| traceback.print_exc() # 전체 스택트레이스 콘솔에 출력 | |
| raise HTTPException(status_code=500, detail=f"서버 오류: {e}") | |
| except Exception as e: | |
| traceback.print_exc() # 전체 스택트레이스 콘솔에 출력 | |
| raise HTTPException(status_code=500, detail=f"서버 오류: {e}") | |
| from fastapi.concurrency import run_in_threadpool # 동기 함수를 비동기처럼 실행하기 위해 | |
| from typing import List, Dict, Any # 반환 타입 명시를 위해 (선택 사항) | |
| # --- 전역 변수 (서버 시작 시 초기화) --- | |
| krx_listings: pd.DataFrame = None | |
| us_listings: pd.DataFrame = None | |
| translator: Translator = None | |
| # --- 서버 시작 시 실행될 로직 --- | |
| async def load_initial_data(): | |
| """ | |
| 서버가 시작될 때 주식 목록과 번역기를 미리 로드하여 | |
| API 요청마다 반복적으로 로드하는 것을 방지합니다. | |
| """ | |
| global krx_listings, us_listings, translator | |
| logger.info("✅ 서버 시작: 초기 데이터 로딩을 시작합니다...") | |
| try: | |
| krx_listings = await run_in_threadpool(fdr.StockListing, 'KRX') | |
| logger.info("📊 한국 상장 기업 목록 로딩 완료.") | |
| nasdaq = await run_in_threadpool(fdr.StockListing, 'NASDAQ') | |
| nyse = await run_in_threadpool(fdr.StockListing, 'NYSE') | |
| amex = await run_in_threadpool(fdr.StockListing, 'AMEX') | |
| us_listings = pd.concat([nasdaq, nyse, amex], ignore_index=True) | |
| logger.info("📊 미국 상장 기업 목록 로딩 완료.") | |
| translator = Translator() | |
| logger.info("🌐 번역기 초기화 완료.") | |
| logger.info("✅ 모든 초기 데이터 로딩이 성공적으로 완료되었습니다.") | |
| except Exception as e: | |
| logger.error(f"🚨 초기 데이터 로딩 중 심각한 오류 발생: {e}", exc_info=True) | |
| # 필요하다면 여기서 서버 실행을 중단시킬 수도 있습니다. | |
| # raise RuntimeError("Failed to load initial stock listings.") from e | |
| # --- 핵심 로직 함수 --- | |
| def get_stock_info(company_name: str) -> Dict[str, str] | None: | |
| """ | |
| 회사명을 받아 한국 또는 미국 시장에서 종목 정보를 찾아 반환합니다. | |
| (정상 동작하는 스크립트의 로직을 그대로 적용) | |
| """ | |
| # 1. 한국 주식에서 먼저 검색 | |
| kr_match = krx_listings[krx_listings['Name'].str.contains(company_name, case=False, na=False)] | |
| if not kr_match.empty: | |
| stock = kr_match.iloc[0] | |
| logger.info(f"KRX에서 '{company_name}' 발견: {stock['Name']} ({stock['Code']})") | |
| return {"market": "KRX", "symbol": stock['Code'], "name": stock['Name']} | |
| # 2. 한국에 없으면 미국 주식에서 검색 (번역기 사용) | |
| try: | |
| # 번역은 I/O 작업이므로 스레드풀에서 실행하는 것이 더 안전할 수 있으나, | |
| # googletrans의 내부 구현에 따라 여기서 직접 호출해도 큰 문제가 없을 수 있습니다. | |
| company_name_eng = translator.translate(company_name, src='ko', dest='en').text | |
| logger.info(f"'{company_name}' -> 영어로 번역: '{company_name_eng}'") | |
| # 이름 또는 심볼에서 검색 | |
| us_match = us_listings[ | |
| us_listings['Name'].str.contains(company_name_eng, case=False, na=False) | | |
| us_listings['Symbol'].str.fullmatch(company_name_eng, case=False) | |
| ] | |
| if not us_match.empty: | |
| stock = us_match.iloc[0] | |
| logger.info(f"US에서 '{company_name}' 발견: {stock['Name']} ({stock['Symbol']})") | |
| return {"market": "US", "symbol": stock['Symbol'], "name": stock['Name']} | |
| except Exception as e: | |
| logger.error(f"'{company_name}' 번역 또는 미국 주식 검색 중 오류: {e}") | |
| # 3. 최종적으로 찾지 못한 경우 | |
| logger.warning(f"'{company_name}'에 해당하는 종목을 찾지 못했습니다.") | |
| return None | |
| def fetch_stock_prices_sync(symbol: str, days: int = 365) -> pd.DataFrame: | |
| """ | |
| 지정된 기간 동안의 주가 데이터를 가져옵니다 (동기 함수). | |
| """ | |
| end_date = datetime.today() | |
| start_date = end_date - timedelta(days=days) | |
| logger.info(f"FinanceDataReader로 '{symbol}'의 주가 데이터 조회를 시작합니다 ({start_date.date()} ~ {end_date.date()}).") | |
| try: | |
| df = fdr.DataReader(symbol, start=start_date, end=end_date) | |
| if df.empty: | |
| logger.warning(f"'{symbol}'에 대한 데이터가 없습니다.") | |
| return None | |
| return df | |
| except Exception as e: | |
| logger.error(f"'{symbol}' 데이터 조회 중 오류 발생: {e}", exc_info=True) | |
| return None | |
| # --- API 엔드포인트 --- | |
| async def get_stock_data_by_name( | |
| company_name: str = Query(..., description="조회할 회사명") | |
| ) -> List[Dict[str, Any]]: | |
| if not company_name or not company_name.strip(): | |
| raise HTTPException(status_code=400, detail="회사명을 입력해주세요.") | |
| stock_info = await run_in_threadpool(get_stock_info, company_name.strip()) | |
| if not stock_info: | |
| raise HTTPException(status_code=404, detail=f"'{company_name}'에 해당하는 종목을 찾을 수 없습니다.") | |
| prices_df = await run_in_threadpool(fetch_stock_prices_sync, stock_info['symbol'], 365) | |
| if prices_df is None or prices_df.empty: | |
| raise HTTPException(status_code=404, detail=f"'{stock_info['name']}'의 시세 데이터를 찾을 수 없습니다.") | |
| prices_df.index.name = 'Date' # 👈 이 줄을 추가하여 인덱스 이름을 명시적으로 설정 | |
| prices_df.reset_index(inplace=True) | |
| prices_df['Date'] = prices_df['Date'].dt.strftime('%Y-%m-%d') | |
| return prices_df.to_dict(orient='records') | |
| if __name__ == "__main__": | |
| uvicorn.run(app, host="0.0.0.0", port=8000) |