|
|
import base64
import gc
import io
import logging
import os
import threading
import warnings

import cv2
import numpy as np
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image, ImageEnhance
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
|
# Suppress every Python warning for the lifetime of the process.
warnings.filterwarnings('ignore')

# Root logging at INFO; all messages in this module go through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask application with Flask-CORS applied using its default settings.
app = Flask(__name__)
CORS(app)

# Lazily populated by initialize_trocr(); `models_loaded` flips to True once
# both the processor and the model are ready for use.
processor = None
model = None
models_loaded = False

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
# Serializes model loading. The __main__ entry point starts Flask with
# threaded=True, so several first requests can reach initialize_trocr()
# at the same time.
_trocr_init_lock = threading.Lock()


def initialize_trocr():
    """Load the TrOCR processor and model into the module-level globals.

    The call is a no-op once ``models_loaded`` is True. Loading runs under a
    lock so concurrent callers do not each instantiate the model, and the
    globals ``processor`` / ``model`` are only assigned after every loading
    step has succeeded.

    Raises:
        Exception: whatever the loading calls raised. The error is logged,
            ``models_loaded`` is set to False and the original exception is
            re-raised with its traceback intact.
    """
    global processor, model, models_loaded

    if models_loaded:
        return

    with _trocr_init_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if models_loaded:
            return

        try:
            logger.info("Loading TrOCR model...")

            model_name = "microsoft/trocr-base-printed"

            loaded_processor = TrOCRProcessor.from_pretrained(model_name)
            loaded_model = VisionEncoderDecoderModel.from_pretrained(model_name)
            loaded_model = loaded_model.to(device)
            loaded_model.eval()

            # Publish only fully initialized objects.
            processor = loaded_processor
            model = loaded_model
            models_loaded = True
            logger.info(f"TrOCR model loaded successfully on {device}")

        except Exception as e:
            logger.error(f"Error loading TrOCR: {str(e)}")
            models_loaded = False
            raise
|
|
|
|
|
def preprocess_image_simple(image):
    """Normalize an image for TrOCR: RGB mode, bounded size, mild enhancement.

    Args:
        image: a PIL image or a NumPy array. Three-dimensional arrays are run
            through ``cv2.COLOR_BGR2RGB`` before being wrapped in a PIL image
            (assumes OpenCV-style BGR channel order — confirm with callers).

    Returns:
        The processed PIL image. If any step raises, the error is logged and
        the image is returned in whatever state it had reached at that point;
        this function never raises.
    """
    try:
        if isinstance(image, np.ndarray):
            if image.ndim == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Downscale so the longest side is at most max_size pixels.
        max_size = 1024
        longest_side = max(image.size)
        if longest_side > max_size:
            ratio = max_size / longest_side
            # Clamp to 1px: int() truncation would otherwise produce a zero
            # dimension for very elongated images and make resize() fail,
            # leaving the image at its original (oversized) resolution.
            new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
            image = image.resize(new_size, Image.Resampling.LANCZOS)

        # Slight contrast and sharpness boost.
        image = ImageEnhance.Contrast(image).enhance(1.2)
        image = ImageEnhance.Sharpness(image).enhance(1.1)

        return image

    except Exception as e:
        logger.error(f"Preprocessing error: {e}")
        return image
|
|
|
|
|
def extract_text_trocr(image):
    """Run TrOCR on a single image and return the recognized text.

    The model is initialized on demand, the image goes through
    ``preprocess_image_simple`` and generation uses beam search
    (4 beams, max length 512).

    Returns:
        dict with keys ``'text'`` (stripped output), ``'confidence'`` and
        ``'word_count'``. ``'confidence'`` is a length-based heuristic
        (``len(text) / 100`` capped at 0.9), not a probability reported by
        the model. On any exception the error is logged and an empty result
        (``''``, ``0.0``, ``0``) is returned instead of raising.
    """
    try:
        if not models_loaded:
            initialize_trocr()

        prepared = preprocess_image_simple(image)

        encoded = processor(prepared, return_tensors="pt")
        pixel_values = encoded.pixel_values.to(device)

        with torch.no_grad():
            generated_ids = model.generate(
                pixel_values,
                max_length=512,
                num_beams=4,
                early_stopping=True,
            )

        decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
        text = decoded[0].strip()

        if text:
            confidence = min(0.9, len(text) / 100)
            word_count = len(text.split())
        else:
            confidence = 0.0
            word_count = 0

        return {
            'text': text,
            'confidence': confidence,
            'word_count': word_count,
        }

    except Exception as e:
        logger.error(f"TrOCR error: {e}")
        return {'text': '', 'confidence': 0.0, 'word_count': 0}
|
|
|
|
|
def process_image_with_enhancement(image, enhancement_type="default"):
    """Apply the requested enhancement to an image, then run TrOCR on it.

    Args:
        image: a PIL image or a NumPy array. Three-dimensional arrays are run
            through ``cv2.COLOR_BGR2RGB`` before being wrapped in a PIL image
            (assumes OpenCV-style BGR channel order — confirm with callers).
        enhancement_type: ``"enhance"`` boosts contrast (1.5), brightness
            (1.1) and sharpness (1.3); ``"binary"`` thresholds the grayscale
            image at 128; any other value leaves the image untouched.

    Returns:
        The dict produced by ``extract_text_trocr`` with an extra
        ``'enhancement'`` key holding ``enhancement_type``. On any exception
        the error is logged and an empty result (``''``, ``0.0``, ``0``) with
        the same ``'enhancement'`` key is returned.
    """
    try:
        if isinstance(image, np.ndarray):
            if image.ndim == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)

        # Normalize the mode before enhancing: ImageEnhance cannot blend
        # palette ('P') or 1-bit images, and the resulting error would be
        # swallowed below and reported to the caller as empty text.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if enhancement_type == "enhance":
            image = ImageEnhance.Contrast(image).enhance(1.5)
            image = ImageEnhance.Brightness(image).enhance(1.1)
            image = ImageEnhance.Sharpness(image).enhance(1.3)

        elif enhancement_type == "binary":
            gray = image.convert('L')
            threshold = 128
            binary = gray.point(lambda x: 255 if x > threshold else 0, mode='1')
            # Back to RGB for the downstream preprocessing step.
            image = binary.convert('RGB')

        result = extract_text_trocr(image)
        result['enhancement'] = enhancement_type

        return result

    except Exception as e:
        logger.error(f"Enhancement processing error: {e}")
        return {'text': '', 'confidence': 0.0, 'word_count': 0, 'enhancement': enhancement_type}
|
|
|
|
|
@app.route('/')
def home():
    """Return static service metadata (name, model, endpoints, features) as JSON."""
    endpoints = {
        "health": "/health",
        "ocr": "/ocr (POST)",
        "batch_ocr": "/ocr/batch (POST)",
    }
    features = [
        "No system dependencies required",
        "Transformer-based OCR",
        "Works on Hugging Face Spaces",
        "GPU acceleration when available",
        "Memory efficient",
    ]
    service_info = {
        "service": "TrOCR OCR Service",
        "status": "running",
        "version": "1.0.0",
        "engine": "TrOCR (Transformers)",
        "model": "microsoft/trocr-base-printed",
        "device": device,
        "description": "Hugging Face compatible OCR service using TrOCR",
        "endpoints": endpoints,
        "supported_formats": ["PNG", "JPEG", "JPG", "BMP", "TIFF"],
        "enhancement_types": ["default", "enhance", "binary"],
        "features": features,
    }
    return jsonify(service_info)
|
|
|
|
|
@app.route('/health', methods=['GET'])
def health_check():
    """Report service health, model-load state, device and torch version.

    Responds with HTTP 500 and the error text if building the payload fails.
    """
    try:
        status = {
            "status": "healthy",
            "models_loaded": models_loaded,
            "device": device,
            "torch_version": torch.__version__,
            "service": "TrOCR OCR Service",
        }
        return jsonify(status)
    except Exception as e:
        failure = {
            "status": "error",
            "error": str(e),
        }
        return jsonify(failure), 500
|
|
|
|
|
@app.route('/ocr', methods=['POST'])
def ocr_endpoint():
    """Run OCR on one image supplied as a file upload or as base64 JSON.

    Accepted inputs:
        * multipart form with an ``image`` file and an optional
          ``enhancement`` form field, or
        * a JSON object with ``image_base64`` (raw base64 or a
          ``data:image...`` URL) and an optional ``enhancement`` key.

    Responses:
        * 200 with text, confidence, counts and engine metadata,
        * 400 for a missing/invalid image, an invalid JSON body or an
          unknown enhancement,
        * 500 with ``success: False`` for any other failure.
    """
    try:
        logger.info("OCR request received")

        if not models_loaded:
            initialize_trocr()

        has_upload = 'image' in request.files
        payload = None
        if request.is_json:
            # silent=True: a malformed body yields None instead of raising,
            # so it is answered with 400 rather than falling into the
            # generic 500 handler below.
            payload = request.get_json(silent=True)
            if not isinstance(payload, dict):
                return jsonify({"error": "Request body must be a JSON object"}), 400
        elif not has_upload:
            return jsonify({"error": "No image provided"}), 400

        if payload is not None:
            enhancement = payload.get('enhancement', 'default')
        else:
            enhancement = request.form.get('enhancement', 'default')

        valid_enhancements = ['default', 'enhance', 'binary']
        if enhancement not in valid_enhancements:
            return jsonify({"error": f"Invalid enhancement. Use: {', '.join(valid_enhancements)}"}), 400

        try:
            if has_upload:
                image_file = request.files['image']
                if image_file.filename == '':
                    return jsonify({"error": "No file selected"}), 400
                image_bytes = image_file.read()
            else:
                # Reaching this branch implies a JSON request, so payload
                # is a dict.
                encoded = payload.get('image_base64')
                if not isinstance(encoded, str) or not encoded:
                    return jsonify({"error": "Missing 'image_base64' field"}), 400
                if encoded.startswith('data:image'):
                    # Drop the "data:image/...;base64," prefix of a data URL.
                    encoded = encoded.partition(',')[2]
                image_bytes = base64.b64decode(encoded)

            image = Image.open(io.BytesIO(image_bytes))
            # Image.open only parses the header; force a full decode here so
            # corrupt or truncated data is rejected with 400 instead of
            # silently producing an empty OCR result later.
            image.load()

        except Exception as e:
            return jsonify({"error": f"Invalid image: {str(e)}"}), 400

        logger.info("Starting TrOCR processing")
        result = process_image_with_enhancement(image, enhancement)

        # Release the image and any cached GPU memory before responding.
        del image
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        logger.info(f"OCR completed. Text length: {len(result['text'])}, Confidence: {result['confidence']:.2f}")

        response = {
            "success": True,
            "text": result['text'],
            "confidence": round(result['confidence'], 3),
            "character_count": len(result['text']),
            "word_count": result.get('word_count', 0),
            "enhancement_used": result.get('enhancement', 'unknown'),
            "engine": "TrOCR",
            "model": "microsoft/trocr-base-printed",
            "device": device
        }

        return jsonify(response)

    except Exception as e:
        logger.error(f"OCR processing error: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return jsonify({"error": str(e), "success": False}), 500
|
|
|
|
|
@app.route('/ocr/batch', methods=['POST'])
def batch_ocr_endpoint():
    """Run OCR on up to three uploaded images in one request.

    Expects a multipart form with one or more ``images`` files and an
    optional ``enhancement`` form field (``default``, ``enhance`` or
    ``binary``).

    Responses:
        * 200 with a per-image ``results`` list (each entry carries
          ``success`` and either the OCR fields or an ``error``) plus
          aggregate counts,
        * 400 when no images are supplied, the batch exceeds the limit or
          the enhancement is unknown,
        * 500 with ``success: False`` for any other failure.
    """
    try:
        logger.info("Batch OCR request received")

        if not models_loaded:
            initialize_trocr()

        if 'images' not in request.files:
            return jsonify({"error": "No images provided"}), 400

        images = request.files.getlist('images')
        if not images:
            return jsonify({"error": "No images found"}), 400

        max_batch_size = 3
        if len(images) > max_batch_size:
            return jsonify({"error": f"Maximum {max_batch_size} images allowed"}), 400

        enhancement = request.form.get('enhancement', 'default')

        # Same validation as the single-image /ocr endpoint; unknown values
        # used to be silently treated as "default".
        valid_enhancements = ['default', 'enhance', 'binary']
        if enhancement not in valid_enhancements:
            return jsonify({"error": f"Invalid enhancement. Use: {', '.join(valid_enhancements)}"}), 400

        results = []
        for i, image_file in enumerate(images):
            try:
                logger.info(f"Processing image {i+1}/{len(images)}")

                if image_file.filename == '':
                    results.append({
                        "index": i,
                        "filename": "empty_file",
                        "error": "Empty filename",
                        "success": False
                    })
                    continue

                image_data = image_file.read()
                image = Image.open(io.BytesIO(image_data))
                # Image.open only parses the header; force a full decode so
                # a corrupt file is reported as a failed entry instead of a
                # "successful" one with empty text.
                image.load()

                result = process_image_with_enhancement(image, enhancement)

                results.append({
                    "index": i,
                    "filename": image_file.filename,
                    "text": result['text'],
                    "confidence": round(result['confidence'], 3),
                    "character_count": len(result['text']),
                    "word_count": result.get('word_count', 0),
                    "success": True
                })

                # Free memory between images to keep peak usage low.
                del image
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()

            except Exception as e:
                # One bad image must not abort the rest of the batch.
                logger.error(f"Error processing image {i}: {str(e)}")
                results.append({
                    "index": i,
                    "filename": image_file.filename if hasattr(image_file, 'filename') else f"image_{i}",
                    "error": str(e),
                    "success": False
                })

        successful_count = sum(1 for r in results if r["success"])

        return jsonify({
            "success": True,
            "results": results,
            "total_processed": len(results),
            "successful": successful_count,
            "failed": len(results) - successful_count,
            "enhancement_used": enhancement,
            "engine": "TrOCR",
            "device": device
        })

    except Exception as e:
        logger.error(f"Batch OCR error: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return jsonify({"error": str(e), "success": False}), 500
|
|
|
|
|
@app.route('/models/load', methods=['POST'])
def load_models():
    """Load the TrOCR model on request.

    Returns a success message (noting when the model was already loaded),
    or HTTP 500 with the error text if loading fails.
    """
    try:
        if not models_loaded:
            initialize_trocr()
            return jsonify({"message": "TrOCR loaded successfully", "success": True, "device": device})

        return jsonify({"message": "TrOCR already loaded", "success": True})
    except Exception as e:
        return jsonify({"error": str(e), "success": False}), 500
|
|
|
|
|
@app.errorhandler(404)
def not_found(error):
    """Answer unknown routes with a JSON 404 that lists the available endpoints."""
    available_endpoints = {
        "GET /": "Service information",
        "GET /health": "Health check",
        "POST /ocr": "Single image OCR",
        "POST /ocr/batch": "Batch image OCR",
        "POST /models/load": "Load models manually",
    }
    body = {
        "error": "Endpoint not found",
        "available_endpoints": available_endpoints,
    }
    return jsonify(body), 404
|
|
|
|
|
@app.errorhandler(500)
def internal_error(error):
    """Free cached GPU memory, run the garbage collector and return a generic JSON 500."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    body = {
        "error": "Internal server error",
        "message": "Please check server logs",
    }
    return jsonify(body), 500
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: serve on all interfaces, on the port given by the
    # PORT environment variable (7860 when unset), with threading enabled.
    logger.info("Starting TrOCR OCR service...")
    listen_port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=listen_port, debug=False, threaded=True)