import os
import gc
import numpy as np
import cv2
from PIL import Image, ImageEnhance
import logging
import base64
import io
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from flask import Flask, request, jsonify
from flask_cors import CORS
import warnings

warnings.filterwarnings('ignore')

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)

# Global variables for TrOCR
processor = None
model = None
models_loaded = False
device = "cuda" if torch.cuda.is_available() else "cpu"

# Enhancement modes accepted by both OCR endpoints
VALID_ENHANCEMENTS = ['default', 'enhance', 'binary']


def initialize_trocr():
    """Initialize TrOCR model - works on Hugging Face without system dependencies"""
    global processor, model, models_loaded

    if models_loaded:
        return

    try:
        logger.info("Loading TrOCR model...")

        # Base-size printed-text checkpoint, chosen to fit the free tier
        model_name = "microsoft/trocr-base-printed"

        # Initialize processor and model
        processor = TrOCRProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Move to device
        model = model.to(device)
        model.eval()  # Set to evaluation mode

        models_loaded = True
        logger.info(f"TrOCR model loaded successfully on {device}")

    except Exception as e:
        logger.error(f"Error loading TrOCR: {str(e)}")
        models_loaded = False
        raise  # Re-raise with the original traceback


def preprocess_image_simple(image):
    """Simple image preprocessing for TrOCR"""
    try:
        # Convert to PIL Image if needed (OpenCV arrays are BGR)
        if isinstance(image, np.ndarray):
            if len(image.shape) == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)

        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Downscale so the longest side is at most max_size pixels
        max_size = 1024
        if max(image.size) > max_size:
            ratio = max_size / max(image.size)
            # Clamp to 1 so very thin images never get a zero dimension
            new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
            image = image.resize(new_size, Image.Resampling.LANCZOS)

        # Enhance image quality
        # Increase contrast slightly
        enhancer = ImageEnhance.Contrast(image)
        image = enhancer.enhance(1.2)

        # Increase sharpness slightly
        enhancer = ImageEnhance.Sharpness(image)
        image = enhancer.enhance(1.1)

        return image

    except Exception as e:
        # Fall back to the image as it was when the error occurred
        logger.error(f"Preprocessing error: {e}")
        return image


def extract_text_trocr(image):
    """Extract text using TrOCR"""
    try:
        if not models_loaded:
            initialize_trocr()

        # Preprocess image
        processed_image = preprocess_image_simple(image)

        # Prepare inputs
        pixel_values = processor(processed_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device)

        # Generate text
        with torch.no_grad():
            generated_ids = model.generate(
                pixel_values,
                max_length=512,
                num_beams=4,
                early_stopping=True
            )

        # Decode the generated text
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Clean up text
        cleaned_text = generated_text.strip()

        # Heuristic score derived from text length only (capped at 0.9);
        # this is not a probability produced by the model
        confidence = min(0.9, len(cleaned_text) / 100) if cleaned_text else 0.0

        return {
            'text': cleaned_text,
            'confidence': confidence,
            'word_count': len(cleaned_text.split()) if cleaned_text else 0
        }

    except Exception as e:
        logger.error(f"TrOCR error: {e}")
        return {'text': '', 'confidence': 0.0, 'word_count': 0}


def process_image_with_enhancement(image, enhancement_type="default"):
    """Process image with different enhancement levels"""
    try:
        # Convert to PIL if needed (OpenCV arrays are BGR)
        if isinstance(image, np.ndarray):
            if len(image.shape) == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)

        if enhancement_type == "enhance":
            # More aggressive enhancement for poor quality images
            # Increase contrast more
            enhancer = ImageEnhance.Contrast(image)
            image = enhancer.enhance(1.5)

            # Increase brightness slightly
            enhancer = ImageEnhance.Brightness(image)
            image = enhancer.enhance(1.1)

            # Increase sharpness more
            enhancer = ImageEnhance.Sharpness(image)
            image = enhancer.enhance(1.3)

        elif enhancement_type == "binary":
            # Convert to grayscale and apply threshold
            gray = image.convert('L')
            # Simple fixed threshold
            threshold = 128
            binary = gray.point(lambda x: 255 if x > threshold else 0, mode='1')
            image = binary.convert('RGB')

        # Extract text using TrOCR
        result = extract_text_trocr(image)
        result['enhancement'] = enhancement_type

        return result

    except Exception as e:
        logger.error(f"Enhancement processing error: {e}")
        return {'text': '', 'confidence': 0.0, 'word_count': 0, 'enhancement': enhancement_type}


@app.route('/')
def home():
    """Root endpoint"""
    return jsonify({
        "service": "TrOCR OCR Service",
        "status": "running",
        "version": "1.0.0",
        "engine": "TrOCR (Transformers)",
        "model": "microsoft/trocr-base-printed",
        "device": device,
        "description": "Hugging Face compatible OCR service using TrOCR",
        "endpoints": {
            "health": "/health",
            "ocr": "/ocr (POST)",
            "batch_ocr": "/ocr/batch (POST)"
        },
        "supported_formats": ["PNG", "JPEG", "JPG", "BMP", "TIFF"],
        "enhancement_types": VALID_ENHANCEMENTS,
        "features": [
            "No system dependencies required",
            "Transformer-based OCR",
            "Works on Hugging Face Spaces",
            "GPU acceleration when available",
            "Memory efficient"
        ]
    })


@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    try:
        return jsonify({
            "status": "healthy",
            "models_loaded": models_loaded,
            "device": device,
            "torch_version": torch.__version__,
            "service": "TrOCR OCR Service"
        })
    except Exception as e:
        return jsonify({
            "status": "error",
            "error": str(e)
        }), 500


@app.route('/ocr', methods=['POST'])
def ocr_endpoint():
    """Main OCR endpoint using TrOCR"""
    try:
        logger.info("OCR request received")

        # Ensure models are loaded
        if not models_loaded:
            initialize_trocr()

        # Check if image is provided
        if 'image' not in request.files and not request.is_json:
            return jsonify({"error": "No image provided"}), 400

        # Get parameters
        if request.is_json:
            enhancement = request.json.get('enhancement', 'default')
        else:
            enhancement = request.form.get('enhancement', 'default')

        # Validate enhancement type
        if enhancement not in VALID_ENHANCEMENTS:
            return jsonify({"error": f"Invalid enhancement. Use: {', '.join(VALID_ENHANCEMENTS)}"}), 400

        # Load image
        try:
            if 'image' in request.files:
                image_file = request.files['image']
                if image_file.filename == '':
                    return jsonify({"error": "No file selected"}), 400

                image_data = image_file.read()
                image = Image.open(io.BytesIO(image_data))
            else:
                image_data = request.json.get('image_base64')
                if not image_data:
                    return jsonify({"error": "No image provided"}), 400
                # Strip a data-URL prefix such as "data:image/png;base64,"
                if image_data.startswith('data:image'):
                    image_data = image_data.split(',')[1]

                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes))

        except Exception as e:
            return jsonify({"error": f"Invalid image: {str(e)}"}), 400

        # Process image
        logger.info("Starting TrOCR processing")
        result = process_image_with_enhancement(image, enhancement)

        # Clean up
        del image
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        logger.info(f"OCR completed. Text length: {len(result['text'])}, Confidence: {result['confidence']:.2f}")

        response = {
            "success": True,
            "text": result['text'],
            "confidence": round(result['confidence'], 3),
            "character_count": len(result['text']),
            "word_count": result.get('word_count', 0),
            "enhancement_used": result.get('enhancement', 'unknown'),
            "engine": "TrOCR",
            "model": "microsoft/trocr-base-printed",
            "device": device
        }

        return jsonify(response)

    except Exception as e:
        logger.error(f"OCR processing error: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return jsonify({"error": str(e), "success": False}), 500


@app.route('/ocr/batch', methods=['POST'])
def batch_ocr_endpoint():
    """Batch OCR endpoint"""
    try:
        logger.info("Batch OCR request received")

        if not models_loaded:
            initialize_trocr()

        if 'images' not in request.files:
            return jsonify({"error": "No images provided"}), 400

        images = request.files.getlist('images')
        if not images:
            return jsonify({"error": "No images found"}), 400

        # Limit batch size for free tier
        max_batch_size = 3
        if len(images) > max_batch_size:
            return jsonify({"error": f"Maximum {max_batch_size} images allowed"}), 400

        enhancement = request.form.get('enhancement', 'default')

        # Validate enhancement type, matching the single-image endpoint
        if enhancement not in VALID_ENHANCEMENTS:
            return jsonify({"error": f"Invalid enhancement. Use: {', '.join(VALID_ENHANCEMENTS)}"}), 400

        results = []

        for i, image_file in enumerate(images):
            try:
                logger.info(f"Processing image {i+1}/{len(images)}")

                if image_file.filename == '':
                    results.append({
                        "index": i,
                        "filename": "empty_file",
                        "error": "Empty filename",
                        "success": False
                    })
                    continue

                image_data = image_file.read()
                image = Image.open(io.BytesIO(image_data))

                # Process with TrOCR
                result = process_image_with_enhancement(image, enhancement)

                results.append({
                    "index": i,
                    "filename": image_file.filename,
                    "text": result['text'],
                    "confidence": round(result['confidence'], 3),
                    "character_count": len(result['text']),
                    "word_count": result.get('word_count', 0),
                    "success": True
                })

                # Clean up
                del image
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()

            except Exception as e:
                logger.error(f"Error processing image {i}: {str(e)}")
                results.append({
                    "index": i,
                    "filename": image_file.filename if hasattr(image_file, 'filename') else f"image_{i}",
                    "error": str(e),
                    "success": False
                })

        successful_count = sum(1 for r in results if r["success"])

        return jsonify({
            "success": True,
            "results": results,
            "total_processed": len(results),
            "successful": successful_count,
            "failed": len(results) - successful_count,
            "enhancement_used": enhancement,
            "engine": "TrOCR",
            "device": device
        })

    except Exception as e:
        logger.error(f"Batch OCR error: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return jsonify({"error": str(e), "success": False}), 500


@app.route('/models/load', methods=['POST'])
def load_models():
    """Manually load TrOCR models"""
    try:
        if models_loaded:
            return jsonify({"message": "TrOCR already loaded", "success": True})

        initialize_trocr()
        return jsonify({"message": "TrOCR loaded successfully", "success": True, "device": device})

    except Exception as e:
        return jsonify({"error": str(e), "success": False}), 500


@app.errorhandler(404)
def not_found(error):
    return jsonify({
        "error": "Endpoint not found",
        "available_endpoints": {
            "GET /": "Service information",
            "GET /health": "Health check",
            "POST /ocr": "Single image OCR",
            "POST /ocr/batch": "Batch image OCR",
            "POST /models/load": "Load models manually"
        }
    }), 404


@app.errorhandler(500)
def internal_error(error):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    return jsonify({
        "error": "Internal server error",
        "message": "Please check server logs"
    }), 500


if __name__ == '__main__':
    logger.info("Starting TrOCR OCR service...")
    port = int(os.environ.get('PORT', 7860))  # Hugging Face Spaces uses port 7860
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)