mike23415 committed · Commit d12a697 (verified) · Parent: aed8107

Create app.py

Files changed (1):
  1. app.py +400 -0
app.py ADDED
@@ -0,0 +1,400 @@
from flask import Flask, request, jsonify
from flask_cors import CORS
import base64
import io
import os
from PIL import Image
import logging
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import easyocr
import numpy as np
import threading

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)

# Global variables for models
trocr_processor = None
trocr_model = None
easyocr_reader = None
models_loaded = False
loading_lock = threading.Lock()

def initialize_models():
    """Initialize OCR models"""
    global trocr_processor, trocr_model, easyocr_reader, models_loaded

    if models_loaded:
        return

    with loading_lock:
        if models_loaded:  # Double-check after acquiring lock
            return

        try:
            logger.info("Starting model initialization...")

            # Set cache directory
            cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/app/.cache/huggingface')
            os.makedirs(cache_dir, exist_ok=True)

            # Initialize TrOCR for handwritten text (Microsoft's model)
            logger.info("Loading TrOCR model for handwritten text...")
            trocr_processor = TrOCRProcessor.from_pretrained(
                "microsoft/trocr-base-handwritten",
                cache_dir=cache_dir
            )
            trocr_model = VisionEncoderDecoderModel.from_pretrained(
                "microsoft/trocr-base-handwritten",
                cache_dir=cache_dir
            )

            # Initialize EasyOCR for printed text
            logger.info("Loading EasyOCR for printed text...")
            easyocr_reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

            models_loaded = True
            logger.info("All models loaded successfully!")

        except Exception as e:
            logger.error(f"Error loading models: {str(e)}")
            models_loaded = False
            raise e

def ensure_models_loaded():
    """Ensure models are loaded before processing"""
    if not models_loaded:
        initialize_models()

def preprocess_image(image):
    """Preprocess image for better OCR results"""
    # Convert to RGB if needed
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Resize if image is too large
    max_size = 1024
    if max(image.size) > max_size:
        ratio = max_size / max(image.size)
        new_size = tuple(int(dim * ratio) for dim in image.size)
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    return image

def extract_text_trocr(image):
    """Extract text using TrOCR (good for handwritten text)"""
    try:
        ensure_models_loaded()
        if not trocr_processor or not trocr_model:
            return ""

        # Preprocess image
        image = preprocess_image(image)

        # Generate pixel values
        pixel_values = trocr_processor(image, return_tensors="pt").pixel_values

        # Generate text
        generated_ids = trocr_model.generate(pixel_values)
        generated_text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return generated_text.strip()
    except Exception as e:
        logger.error(f"TrOCR error: {str(e)}")
        return ""

def extract_text_easyocr(image):
    """Extract text using EasyOCR (good for printed text)"""
    try:
        ensure_models_loaded()
        if not easyocr_reader:
            return ""

        # Convert PIL image to numpy array
        image_np = np.array(preprocess_image(image))

        # Extract text
        results = easyocr_reader.readtext(image_np, detail=0)

        # Join all detected text
        extracted_text = ' '.join(results)
        return extracted_text.strip()
    except Exception as e:
        logger.error(f"EasyOCR error: {str(e)}")
        return ""

def process_image_ocr(image, ocr_type="auto"):
    """Process image with specified OCR method"""
    results = {}

    if ocr_type in ["auto", "handwritten", "trocr"]:
        trocr_text = extract_text_trocr(image)
        results["trocr"] = trocr_text

    if ocr_type in ["auto", "printed", "easyocr"]:
        easyocr_text = extract_text_easyocr(image)
        results["easyocr"] = easyocr_text

    # For auto mode, pick between the two engines' results
    if ocr_type == "auto":
        trocr_len = len(results.get("trocr", ""))
        easyocr_len = len(results.get("easyocr", ""))

        if trocr_len > 0 and easyocr_len > 0:
            # If both have results, choose based on relative length
            if abs(trocr_len - easyocr_len) / max(trocr_len, easyocr_len) < 0.3:
                # If lengths are similar, prefer EasyOCR for printed text
                results["final"] = results["easyocr"]
            else:
                # Use the longer result
                results["final"] = results["trocr"] if trocr_len > easyocr_len else results["easyocr"]
        elif trocr_len > 0:
            results["final"] = results["trocr"]
        elif easyocr_len > 0:
            results["final"] = results["easyocr"]
        else:
            results["final"] = ""
    else:
        # Return the specific model result
        results["final"] = results.get(ocr_type.replace("handwritten", "trocr").replace("printed", "easyocr"), "")

    return results

@app.route('/')
def home():
    """Root endpoint"""
    return jsonify({
        "service": "OCR Backend",
        "status": "running",
        "version": "1.0.0",
        "models_loaded": models_loaded,
        "endpoints": {
            "health": "/health",
            "ocr": "/ocr (POST)",
            "batch_ocr": "/ocr/batch (POST)",
            "models_info": "/models/info (GET)"
        },
        "supported_formats": ["PNG", "JPEG", "JPG", "BMP", "TIFF"],
        "ocr_types": ["auto", "handwritten", "printed"]
    })

@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    return jsonify({
        "status": "healthy",
        "models_loaded": models_loaded,
        "service": "OCR Backend"
    })

@app.route('/ocr', methods=['POST'])
def ocr_endpoint():
    """Main OCR endpoint"""
    try:
        # Ensure models are loaded
        ensure_models_loaded()

        # Check if an image is provided
        if 'image' not in request.files and not request.is_json:
            return jsonify({"error": "No image provided. Use 'image' field for file upload or JSON with 'image_base64'"}), 400

        if request.is_json and 'image_base64' not in request.json:
            return jsonify({"error": "No 'image_base64' field found in JSON"}), 400

        # Get OCR type preference
        if request.is_json:
            ocr_type = request.json.get('type', 'auto')
        else:
            ocr_type = request.form.get('type', 'auto')

        # Validate ocr_type
        if ocr_type not in ['auto', 'handwritten', 'printed', 'trocr', 'easyocr']:
            return jsonify({"error": "Invalid OCR type. Use: auto, handwritten, printed"}), 400

        # Load image
        if 'image' in request.files:
            # File upload
            image_file = request.files['image']
            if image_file.filename == '':
                return jsonify({"error": "No file selected"}), 400
            image = Image.open(image_file.stream)
        else:
            # Base64 image
            image_data = request.json['image_base64']
            if image_data.startswith('data:image'):
                # Remove data URL prefix
                image_data = image_data.split(',')[1]

            try:
                # Decode base64
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes))
            except Exception as e:
                return jsonify({"error": f"Invalid base64 image data: {str(e)}"}), 400

        # Process image
        results = process_image_ocr(image, ocr_type)

        response = {
            "success": True,
            "text": results["final"],
            "type_used": ocr_type,
            "character_count": len(results["final"]),
            "details": {
                "trocr_result": results.get("trocr", ""),
                "easyocr_result": results.get("easyocr", "")
            } if ocr_type == "auto" else {}
        }

        return jsonify(response)

    except Exception as e:
        logger.error(f"OCR processing error: {str(e)}")
        return jsonify({"error": str(e), "success": False}), 500

@app.route('/ocr/batch', methods=['POST'])
def batch_ocr_endpoint():
    """Batch OCR endpoint for multiple images"""
    try:
        # Ensure models are loaded
        ensure_models_loaded()

        if 'images' not in request.files:
            return jsonify({"error": "No images provided. Use 'images' field for multiple file upload"}), 400

        images = request.files.getlist('images')
        if not images or len(images) == 0:
            return jsonify({"error": "No images found in request"}), 400

        ocr_type = request.form.get('type', 'auto')

        # Validate ocr_type
        if ocr_type not in ['auto', 'handwritten', 'printed', 'trocr', 'easyocr']:
            return jsonify({"error": "Invalid OCR type. Use: auto, handwritten, printed"}), 400

        results = []
        for i, image_file in enumerate(images):
            try:
                if image_file.filename == '':
                    results.append({
                        "index": i,
                        "filename": "empty_file",
                        "error": "Empty filename",
                        "success": False
                    })
                    continue

                image = Image.open(image_file.stream)
                ocr_results = process_image_ocr(image, ocr_type)

                results.append({
                    "index": i,
                    "filename": image_file.filename,
                    "text": ocr_results["final"],
                    "character_count": len(ocr_results["final"]),
                    "success": True
                })
            except Exception as e:
                results.append({
                    "index": i,
                    "filename": image_file.filename if hasattr(image_file, 'filename') else f"image_{i}",
                    "error": str(e),
                    "success": False
                })

        successful_count = sum(1 for r in results if r["success"])

        return jsonify({
            "success": True,
            "results": results,
            "total_processed": len(results),
            "successful": successful_count,
            "failed": len(results) - successful_count,
            "type_used": ocr_type
        })

    except Exception as e:
        logger.error(f"Batch OCR error: {str(e)}")
        return jsonify({"error": str(e), "success": False}), 500

@app.route('/models/info', methods=['GET'])
def models_info():
    """Get information about loaded models"""
    return jsonify({
        "models": {
            "trocr": {
                "name": "microsoft/trocr-base-handwritten",
                "description": "Handwritten text recognition using Transformer-based OCR",
                "loaded": trocr_model is not None and trocr_processor is not None,
                "best_for": "Handwritten text, notes, forms"
            },
            "easyocr": {
                "name": "EasyOCR",
                "description": "Printed text recognition with CRAFT + CRNN",
                "loaded": easyocr_reader is not None,
                "best_for": "Printed text, documents, signs, books"
            }
        },
        "supported_types": ["auto", "handwritten", "printed"],
        "supported_formats": ["PNG", "JPEG", "JPG", "BMP", "TIFF"],
        "cache_directory": os.environ.get('TRANSFORMERS_CACHE', '/app/.cache/huggingface'),
        "gpu_available": torch.cuda.is_available(),
        "models_loaded": models_loaded
    })

@app.route('/models/load', methods=['POST'])
def load_models():
    """Manually trigger model loading"""
    try:
        if models_loaded:
            return jsonify({"message": "Models already loaded", "success": True})

        initialize_models()
        return jsonify({"message": "Models loaded successfully", "success": True})
    except Exception as e:
        return jsonify({"error": str(e), "success": False}), 500

@app.errorhandler(404)
def not_found(error):
    return jsonify({
        "error": "Endpoint not found",
        "available_endpoints": {
            "GET /": "Service information",
            "GET /health": "Health check",
            "POST /ocr": "Single image OCR",
            "POST /ocr/batch": "Batch image OCR",
            "GET /models/info": "Model information",
            "POST /models/load": "Load models manually"
        }
    }), 404

@app.errorhandler(500)
def internal_error(error):
    return jsonify({
        "error": "Internal server error",
        "message": "Please check the server logs for more details"
    }), 500

# When running under gunicorn the module is imported rather than executed directly
if __name__ != '__main__':
    logger.info("Starting OCR service with gunicorn...")
    # Don't initialize models here - let them load lazily on first request
    # This prevents startup failures due to model loading issues

if __name__ == '__main__':
    logger.info("Starting OCR service in development mode...")
    try:
        # Try to initialize models, but don't fail if it doesn't work
        initialize_models()
    except Exception as e:
        logger.warning(f"Could not initialize models on startup: {e}")
        logger.info("Models will be loaded on first request")

    # Run the app
    port = int(os.environ.get('PORT', 5000))
    app.run(host='0.0.0.0', port=port, debug=False)
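
Usage note: the sketch below is a minimal, hypothetical client for the endpoints defined in app.py above. It assumes the service is running locally on its default port 5000 and that the `requests` package is installed; the image path `sample.png` is a placeholder. None of this is part of the commit itself.

# client_example.py - hypothetical client for the OCR service above (not part of this commit)
import base64
import requests

BASE_URL = "http://localhost:5000"  # assumes the default PORT of 5000 used in app.py

# 1) Multipart file upload to /ocr, letting the server choose the engine ("auto")
with open("sample.png", "rb") as f:  # "sample.png" is a placeholder image path
    resp = requests.post(
        f"{BASE_URL}/ocr",
        files={"image": f},
        data={"type": "auto"},
    )
print(resp.json().get("text"))

# 2) JSON body with a base64-encoded image, forcing the handwritten (TrOCR) path
with open("sample.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")
resp = requests.post(
    f"{BASE_URL}/ocr",
    json={"image_base64": image_b64, "type": "handwritten"},
)
print(resp.json())

# 3) Health check
print(requests.get(f"{BASE_URL}/health").json())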