# -*- coding: utf-8 -*-
"""
Image-based search API: CLIP vision embedding + pgvector cosine-distance lookup.

Data source: the ``image_embeddings`` table (unified vector storage).
"""
import os
import uuid
import json

from flask import Blueprint, request, jsonify
from sqlalchemy import text

from app.extensions import db
from app.utils.ai_vision import load_clip_model, get_image_embedding
from app.models.inbound.buy import StockBuy
from app.models.inbound.semi import StockSemi
from app.models.inbound.product import StockProduct
from app.models.base import MaterialBase

# Register the blueprint
image_search_bp = Blueprint('image_search', __name__)

# Upload extensions accepted by the endpoint (lower-case, without the dot).
_ALLOWED_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'})

# Maximum number of de-duplicated hits returned to the client.
_MAX_RESULTS = 10

# Stock modules that share one payload shape:
# (module_name, model class, front-end url, whether to expose sale_price).
_STOCK_MODULES = (
    ('stock_buy', StockBuy, '/inventory/buy', False),
    ('stock_semi', StockSemi, '/inventory/semi', False),
    ('stock_product', StockProduct, '/inventory/product', True),
)


def _file_extension(filename):
    """Return the lower-cased extension of *filename* without the dot.

    Returns ``''`` when *filename* is empty/``None`` or contains no dot, so a
    file that is literally named ``png`` is not mistaken for a PNG upload.
    """
    if not filename or '.' not in filename:
        return ''
    return filename.rsplit('.', 1)[-1].lower()


def _first_image_url(raw_url):
    """Return a single URL from the stored ``image_url`` value.

    The value is either a plain URL or a JSON array of URLs. For an array the
    first element is returned (``''`` for an empty array). Text that starts
    with ``[`` but is not valid JSON is returned unchanged.
    """
    raw_url = raw_url or ''
    if not raw_url.startswith('['):
        return raw_url
    try:
        url_list = json.loads(raw_url)
    except ValueError:
        # json.JSONDecodeError subclasses ValueError; keep the raw text as a
        # best-effort fallback instead of failing the whole search.
        return raw_url
    return url_list[0] if url_list else ''


def _stock_business_data(record, module_name, url, include_sale_price=False):
    """Serialize one stock record (buy / semi / product) for the response."""
    base = record.base
    data = {
        'record_id': record.id,
        'name': base.name if base else None,
        'spec_model': base.spec_model if base else None,
        'sku': record.sku,
        'barcode': record.barcode,
        'serial_number': record.serial_number,
        'batch_number': record.batch_number,
        'status': record.status,
        'warehouse_location': record.warehouse_location,
        'stock_quantity': record.stock_quantity,
    }
    if include_sale_price:
        data['sale_price'] = record.sale_price
    data['module_name'] = module_name
    data['url'] = url
    return data


def _load_business_map(target_ids_by_module):
    """Fetch business rows for the matched ids, one query per module.

    Args:
        target_ids_by_module: mapping of module name -> list of target ids.

    Returns:
        Dict keyed by ``(module_name, record_id)`` holding the serialized
        business payload. Ids with no matching row are simply absent.
    """
    business_map = {}

    for module_name, model, url, include_sale_price in _STOCK_MODULES:
        ids = target_ids_by_module.get(module_name)
        if ids is None:
            continue
        # NOTE(review): the outer join alone may not populate ``record.base``;
        # if that relationship is lazy-loaded this costs one extra query per
        # record - confirm against the model definitions.
        records = (
            db.session.query(model)
            .filter(model.id.in_(ids))
            .outerjoin(MaterialBase, model.base_id == MaterialBase.id)
            .all()
        )
        for r in records:
            business_map[(module_name, r.id)] = _stock_business_data(
                r, module_name, url, include_sale_price
            )

    if 'material_base' in target_ids_by_module:
        ids = target_ids_by_module['material_base']
        records = MaterialBase.query.filter(MaterialBase.id.in_(ids)).all()
        for r in records:
            business_map[('material_base', r.id)] = {
                'record_id': r.id,
                'name': r.name,
                'spec_model': r.spec_model,
                'common_name': r.common_name,
                'category': r.category,
                'material_type': r.material_type,
                'unit': r.unit,
                'module_name': 'material_base',
                'url': '/material/index',
            }

    return business_map


# ============================================================================
# POST /api/v1/common/image-search
# Image search: uploaded image -> CLIP embedding -> pgvector cosine similarity
# ============================================================================
@image_search_bp.route('/image-search', methods=['POST'])
def image_search():
    """Search stored images by similarity to an uploaded picture.

    Expects a multipart upload under the form field ``file``. The image is
    written to a temporary file, embedded with CLIP, and compared against the
    ``image_embeddings`` table using the pgvector ``<=>`` operator.

    Returns:
        A JSON response (plus HTTP status on errors):
        * 400 - missing file, empty filename, or unsupported extension.
        * 500 - embedding extraction or database lookup failed.
        * 200 - ``{"code": 200, "data": [...]}`` with at most ``_MAX_RESULTS``
          hits, one per ``(module_name, target_id)``, most similar first.
    """
    # ---------------------------------------------------------
    # 1. Validate the upload
    # ---------------------------------------------------------
    if 'file' not in request.files:
        return jsonify({"code": 400, "msg": "未找到图片文件"}), 400

    file = request.files['file']
    # ``not`` also covers a ``None`` filename, which would otherwise crash
    # on ``.rsplit`` below.
    if not file.filename:
        return jsonify({"code": 400, "msg": "未选择文件"}), 400

    # ---------------------------------------------------------
    # 2. Save to a temporary file under a server-generated name
    # ---------------------------------------------------------
    ext = _file_extension(file.filename)
    if ext not in _ALLOWED_EXTENSIONS:
        return jsonify({"code": 400, "msg": "不支持的图片格式"}), 400

    # The client-supplied name is never used on disk; only the validated
    # extension is kept.
    tmp_filename = f"{uuid.uuid4().hex}.{ext}"
    tmp_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'uploads')
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_path = os.path.join(tmp_dir, tmp_filename)

    try:
        file.save(tmp_path)
        print(f"💾 [ImageSearch] 临时文件已保存: {tmp_path}")

        # ---------------------------------------------------------
        # 3. Extract the CLIP embedding
        # ---------------------------------------------------------
        load_clip_model()
        embedding = get_image_embedding(tmp_path)
        print(f"✅ [ImageSearch] Embedding 提取成功,维度: {len(embedding)}")
    except Exception as e:
        print(f"❌ [ImageSearch] 图像处理失败: {e}")
        return jsonify({"code": 500, "msg": f"图像处理失败: {str(e)}"}), 500
    finally:
        # ---------------------------------------------------------
        # 4. Always remove the temporary file, success or failure
        # ---------------------------------------------------------
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
                print(f"🗑️ [ImageSearch] 临时文件已清理: {tmp_path}")
            except Exception as e:
                # Best-effort cleanup: a leftover file must not fail the request.
                print(f"⚠️ [ImageSearch] 临时文件删除失败: {e}")

    # ---------------------------------------------------------
    # 5. pgvector cosine-similarity lookup on image_embeddings
    # ---------------------------------------------------------
    try:
        # Text form of a vector literal: "[v1,v2,...]".
        query_vector_str = '[' + ','.join(str(v) for v in embedding) + ']'

        # Over-fetch (200 rows) because several images may belong to the same
        # business record and are collapsed below.
        sql = text("""
            SELECT
                ie.id AS embedding_id,
                ie.module_name,
                ie.target_id,
                ie.image_url,
                (1 - (ie.embedding <=> :query_vector)) AS similarity
            FROM image_embeddings ie
            WHERE ie.embedding IS NOT NULL
            ORDER BY ie.embedding <=> :query_vector
            LIMIT 200
        """)
        raw_records = db.session.execute(
            sql, {"query_vector": query_vector_str}
        ).fetchall()

        if not raw_records:
            return jsonify({"code": 200, "data": []})

        # De-duplicate by (module_name, target_id). Rows arrive ordered by
        # distance, so the first row seen for a key is its most similar image.
        seen = {}
        for row in raw_records:
            seen.setdefault((row.module_name, row.target_id), row)

        # Group ids per module so business data is fetched in bulk.
        target_ids_by_module = {}
        for module_name, target_id in seen:
            target_ids_by_module.setdefault(module_name, []).append(target_id)

        business_map = _load_business_map(target_ids_by_module)

        # Assemble the final payload (dict order == similarity order).
        results = []
        for key, row in seen.items():
            biz = business_map.get(key, {})
            results.append({
                "module_name": row.module_name,
                "target_id": row.target_id,
                "image_url": _first_image_url(row.image_url),
                "similarity": round(float(row.similarity), 4),
                "product_name": biz.get('name') or biz.get('material_name') or '未命名物料',
                "product_id": row.target_id,
                "spec_model": biz.get('spec_model') or '',
                "business_data": biz,
            })
            if len(results) >= _MAX_RESULTS:
                break

        return jsonify({"code": 200, "data": results})

    except Exception as e:
        print(f"❌ [ImageSearch] 数据库检索失败: {e}")
        return jsonify({"code": 500, "msg": f"检索失败: {str(e)}"}), 500