import os
import io
import logging
import tempfile
import re

from flask import Flask, request, jsonify
from flask_cors import CORS
from werkzeug.utils import secure_filename
from PyPDF2 import PdfReader
from docx import Document
from pptx import Presentation
from transformers import pipeline

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Flask app; CORS is wide open so any front-end origin may call us.
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})

# Point the Hugging Face cache at a freshly created temp directory so model
# downloads land somewhere guaranteed-writable (e.g. read-only containers).
cache_dir = tempfile.mkdtemp()
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir

# Load the summarization model once at startup.  On failure the service keeps
# running with summarizer=None and reports the problem per-request instead.
logger.info("Loading summarization model...")
try:
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn", cache_dir=cache_dir)
    logger.info("Summarization model loaded successfully.")
except Exception as e:
    logger.error("Failed to load model: %s", e)
    summarizer = None

# Store document content in memory (in production, use proper database)
document_store = {}

ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "txt"}


def allowed_file(filename):
    """Check if the uploaded file has an allowed extension."""
    return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS


def chunk_text(text, chunk_size=1000):
    """Split text into chunks of approximately chunk_size characters.

    Words are never split across chunks; a chunk is flushed as soon as adding
    the next word would exceed the limit.
    """
    words = text.split()
    chunks = []
    current_chunk = []
    current_size = 0
    for word in words:
        current_size += len(word) + 1  # +1 accounts for the joining space
        if current_size > chunk_size:
            # Fix: only flush non-empty chunks — a single word longer than
            # chunk_size used to produce a spurious empty "" chunk here.
            if current_chunk:
                chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_size = len(word)
        else:
            current_chunk.append(word)
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks


def summarize_text(text, max_length=150, min_length=50):
    """Summarize text using BART.

    Never raises: empty input, a missing model, or a summarization failure
    all return a human-readable message string instead.
    """
    try:
        if not text.strip():
            return "No text found in the document to summarize."
        if summarizer is None:
            return "Summarization model not available."

        # Split into chunks if text is too long; cap the work at 5 chunks.
        chunks = chunk_text(text, chunk_size=1000)
        summaries = []
        for i, chunk in enumerate(chunks[:5]):  # Process max 5 chunks
            if len(chunk.split()) < 40:  # Skip very short chunks
                continue
            try:
                summary = summarizer(chunk, max_length=max_length, min_length=min_length, do_sample=False)
                summaries.append(summary[0]['summary_text'])
            except Exception as e:
                logger.error("Error summarizing chunk %d: %s", i, e)
                continue
        return " ".join(summaries) if summaries else "Unable to generate summary."
    except Exception as e:
        logger.error("Error in summarization: %s", e)
        return f"Error summarizing text: {str(e)}"


def find_relevant_context(text, question, context_size=500):
    """Find the most relevant part of the text for the question.

    Simple keyword matching: the chunk sharing the most words with the
    question wins; falls back to the first chunk (or "") when nothing matches.
    """
    question_words = set(re.findall(r'\w+', question.lower()))
    chunks = chunk_text(text, chunk_size=context_size)

    best_chunk = ""
    best_score = 0
    for chunk in chunks:
        chunk_lower = chunk.lower()
        score = sum(1 for word in question_words if word in chunk_lower)
        if score > best_score:
            best_score = score
            best_chunk = chunk
    return best_chunk if best_chunk else chunks[0] if chunks else ""


@app.route("/", methods=["GET"])
def index():
    """Root endpoint: liveness check plus the list of available endpoints."""
    return jsonify({
        "status": "running",
        "endpoints": ["/summarize", "/ask"]
    })


@app.route("/summarize", methods=["POST"])
def summarize():
    """Summarize an uploaded document (multipart form field ``file``)."""
    logger.info("Summarize endpoint called.")

    if "file" not in request.files:
        logger.error("No file found in request")
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files["file"]
    if file.filename == "":
        logger.error("File has no filename")
        return jsonify({"error": "No selected file"}), 400

    if not allowed_file(file.filename):
        logger.error("Unsupported file format: %s", file.filename)
        return jsonify({"error": f"Unsupported file format. Allowed: {', '.join(ALLOWED_EXTENSIONS)}"}), 400

    filename = secure_filename(file.filename)
    file_content = file.read()
    file_ext = filename.rsplit(".", 1)[1].lower()

    try:
        # Extract text based on file type
        if file_ext == "pdf":
            text = extract_pdf(file_content)
        elif file_ext == "docx":
            text = extract_docx(file_content)
        elif file_ext == "pptx":
            text = extract_pptx(file_content)
        elif file_ext == "txt":
            text = extract_txt(file_content)
        else:
            return jsonify({"error": "Unsupported file format"}), 400

        if not text.strip():
            return jsonify({"error": "No text could be extracted from the document"}), 400

        # Store document for later querying via /ask
        doc_id = filename
        document_store[doc_id] = text

        # Generate summary.  Fix: log the actual filename — these messages
        # previously contained a garbled "(unknown)" placeholder.
        logger.info("Generating summary for %s", filename)
        summary = summarize_text(text)
        logger.info("File %s processed successfully.", filename)

        return jsonify({
            "filename": filename,
            "summary": summary,
            "textLength": len(text),
            "documentId": doc_id
        })
    except Exception as e:
        logger.error("Error processing file %s: %s", filename, e)
        return jsonify({"error": f"Error processing file: {str(e)}"}), 500


@app.route("/ask", methods=["POST"])
def ask_question():
    """Answer questions about an uploaded document.

    Expects JSON: ``{"question": str, "documentId": str}`` where documentId
    is the value returned by /summarize.
    """
    logger.info("Ask endpoint called.")
    try:
        data = request.get_json()
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400

        question = data.get("question", "").strip()
        doc_id = data.get("documentId", "")

        if not question:
            return jsonify({"error": "No question provided"}), 400
        if not doc_id or doc_id not in document_store:
            return jsonify({"error": "Document not found. Please upload a document first."}), 400

        # Get document text and narrow it to the most relevant context
        doc_text = document_store[doc_id]
        context = find_relevant_context(doc_text, question)

        # Generate answer based on context
        answer = generate_answer(context, question)

        return jsonify({
            "answer": answer,
            "context": context[:200] + "..." if len(context) > 200 else context
        })
    except Exception as e:
        logger.error("Error in ask endpoint: %s", e)
        return jsonify({"error": f"Error processing question: {str(e)}"}), 500


def generate_answer(context, question):
    """Generate an answer based on context and question.

    Simple extractive approach: score each sentence by keyword overlap with
    the question and join the top three.
    """
    question_lower = question.lower()

    # Look for sentences containing question keywords
    sentences = re.split(r'[.!?]+', context)
    question_words = set(re.findall(r'\w+', question_lower))

    relevant_sentences = []
    for sentence in sentences:
        sentence_lower = sentence.lower()
        score = sum(1 for word in question_words if word in sentence_lower)
        if score > 0:
            relevant_sentences.append((score, sentence.strip()))

    # Sort by relevance and return top sentences
    relevant_sentences.sort(reverse=True, key=lambda x: x[0])
    if relevant_sentences:
        answer = ". ".join([s[1] for s in relevant_sentences[:3]])
        return answer if answer else "I couldn't find a specific answer in the document."
    else:
        return "I couldn't find relevant information to answer this question in the document."
def extract_pdf(file_content):
    """Extract text from PDF bytes.

    Raises:
        RuntimeError: if the PDF cannot be parsed (chained to the cause;
            still caught by the broad ``except Exception`` in /summarize).
    """
    try:
        reader = PdfReader(io.BytesIO(file_content))
        # extract_text() may return None for image-only pages — coalesce to ""
        text = "\n".join(page.extract_text() or "" for page in reader.pages)
        return text.strip()
    except Exception as e:
        logger.error("Error extracting text from PDF: %s", e)
        raise RuntimeError(f"Failed to extract text from PDF: {str(e)}") from e


def extract_docx(file_content):
    """Extract text from DOCX bytes (non-empty paragraphs only).

    Raises:
        RuntimeError: if the document cannot be parsed (chained to the cause).
    """
    try:
        doc = Document(io.BytesIO(file_content))
        text = "\n".join(para.text for para in doc.paragraphs if para.text.strip())
        return text.strip()
    except Exception as e:
        logger.error("Error extracting text from DOCX: %s", e)
        raise RuntimeError(f"Failed to extract text from DOCX: {str(e)}") from e


def extract_pptx(file_content):
    """Extract text from PPTX bytes (every text-bearing shape on every slide).

    Raises:
        RuntimeError: if the presentation cannot be parsed (chained to the cause).
    """
    try:
        ppt = Presentation(io.BytesIO(file_content))
        text = []
        for slide in ppt.slides:
            for shape in slide.shapes:
                # Not every shape carries text (pictures, connectors, ...)
                if hasattr(shape, "text") and shape.text.strip():
                    text.append(shape.text.strip())
        return "\n".join(text).strip()
    except Exception as e:
        logger.error("Error extracting text from PPTX: %s", e)
        raise RuntimeError(f"Failed to extract text from PPTX: {str(e)}") from e


def extract_txt(file_content):
    """Decode plain-text bytes, trying UTF-8 first and Latin-1 as fallback.

    Raises:
        RuntimeError: if neither codec can decode the bytes (chained to the cause).
    """
    try:
        return file_content.decode("utf-8").strip()
    except UnicodeDecodeError:
        try:
            return file_content.decode("latin-1").strip()
        except Exception as e:
            logger.error("Error decoding text file: %s", e)
            raise RuntimeError(f"Failed to decode text file: {str(e)}") from e


if __name__ == "__main__":
    # NOTE(review): debug=True enables the Werkzeug debugger and code reloader —
    # do not expose this on 0.0.0.0 in a production deployment.
    app.run(host="0.0.0.0", port=7860, debug=True)