# -*- coding: utf-8 -*-
"""
Search-by-image API: CLIP vision embedding + pgvector cosine-distance retrieval.
"""
import os
import uuid
import json

from flask import Blueprint, request, jsonify
from sqlalchemy import text

from app.extensions import db
from app.utils.ai_vision import load_clip_model, get_image_embedding

# Blueprint registration
image_search_bp = Blueprint('image_search', __name__)

# Accepted upload extensions (lower-case, without the leading dot).
ALLOWED_IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'})

# Maximum number of *distinct* materials returned by one search.
MAX_RESULTS = 10

# Candidate images come from the material master table and the three stock-in
# tables (the stock tables are joined to material_base via base_id). Each
# material is reduced to its single closest image, then materials are ranked
# by pgvector cosine distance (`<=>`). Deduplicating in SQL, before LIMIT,
# guarantees up to MAX_RESULTS distinct materials even when one material owns
# several photos that all rank near the top.
_IMAGE_SEARCH_SQL = text("""
    SELECT id, name, spec_model, image_url, (1 - distance) AS similarity
    FROM (
        SELECT DISTINCT ON (id)
               id, name, spec_model, image_url,
               (vec <=> :query_vector) AS distance
        FROM (
            -- 1. Material master table
            SELECT id, name, spec_model, product_image AS image_url, img_embedding AS vec
            FROM material_base
            WHERE img_embedding IS NOT NULL

            UNION ALL

            -- 2. Purchase stock-in table (base_id -> real material)
            SELECT mb.id, mb.name, mb.spec_model, sb.arrival_photo AS image_url,
                   sb.arrival_image_embedding AS vec
            FROM stock_buy sb
            JOIN material_base mb ON sb.base_id = mb.id
            WHERE sb.arrival_image_embedding IS NOT NULL

            UNION ALL

            -- 3. Semi-finished stock-in table
            SELECT mb.id, mb.name, mb.spec_model, ss.arrival_photo AS image_url,
                   ss.arrival_image_embedding AS vec
            FROM stock_semi ss
            JOIN material_base mb ON ss.base_id = mb.id
            WHERE ss.arrival_image_embedding IS NOT NULL

            UNION ALL

            -- 4. Finished-product stock-in table
            SELECT mb.id, mb.name, mb.spec_model, sp.product_photo AS image_url,
                   sp.arrival_image_embedding AS vec
            FROM stock_product sp
            JOIN material_base mb ON sp.base_id = mb.id
            WHERE sp.arrival_image_embedding IS NOT NULL
        ) AS combined
        -- DISTINCT ON keeps the first row per id under this ordering,
        -- i.e. the closest image of each material.
        ORDER BY id, distance
    ) AS best_per_material
    ORDER BY distance
    LIMIT :max_results
""")


def _get_extension(filename):
    """Return the lower-case extension of *filename* if it is an allowed image type, else None."""
    _, dot, ext = filename.rpartition('.')
    if not dot:
        # A name without any dot has no extension; do not treat the whole
        # name (e.g. "png") as one.
        return None
    ext = ext.lower()
    return ext if ext in ALLOWED_IMAGE_EXTENSIONS else None


def _to_vector_literal(embedding):
    """Serialize an embedding into pgvector's text input format, e.g. ``'[0.1,0.2]'``.

    Assumes *embedding* is an iterable of numbers (TODO confirm against
    get_image_embedding); ``float()`` turns numpy scalars, if any, into plain
    Python floats before formatting.
    """
    return '[' + ','.join(str(float(v)) for v in embedding) + ']'


def _first_image_url(raw_url):
    """Return a single URL from an image column value.

    A value shaped like a JSON array (``'["a.jpg", "b.jpg"]'``) yields its
    first element ("" for an empty array); one that looks like an array but
    fails to parse is returned unchanged. Falsy values yield "".
    """
    if not raw_url:
        return ""
    if not (raw_url.startswith('[') and raw_url.endswith(']')):
        return raw_url
    try:
        url_list = json.loads(raw_url)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return raw_url
    return url_list[0] if url_list else ""


def _embed_upload(file, ext):
    """Save *file* under a random name, return its CLIP embedding, and always delete the copy."""
    tmp_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'uploads')
    os.makedirs(tmp_dir, exist_ok=True)
    # Random on-disk name: the client-supplied filename is never used as a path.
    tmp_path = os.path.join(tmp_dir, f"{uuid.uuid4().hex}.{ext}")
    try:
        file.save(tmp_path)
        print(f"💾 [ImageSearch] 临时文件已保存: {tmp_path}")
        load_clip_model()
        embedding = get_image_embedding(tmp_path)
        print(f"✅ [ImageSearch] Embedding 提取成功,维度: {len(embedding)}")
        return embedding
    finally:
        # Remove the temp file whether or not embedding succeeded.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
                print(f"🗑️ [ImageSearch] 临时文件已清理: {tmp_path}")
            except OSError as e:
                print(f"⚠️ [ImageSearch] 临时文件删除失败: {e}")


# ============================================================================
# POST /api/v1/common/image-search
# Search by image: upload -> CLIP embedding -> pgvector cosine-similarity search
# ============================================================================
@image_search_bp.route('/image-search', methods=['POST'])
def image_search():
    """Find the materials whose stored images are most similar to an uploaded image.

    Expects a multipart form field ``file`` with a png/jpg/jpeg/gif/bmp/webp
    extension. On success returns ``{"code": 200, "data": [...]}`` with up to
    MAX_RESULTS distinct materials, most similar first. Each item carries
    ``product_id``, ``product_name``, ``spec_model``, ``image_url`` and
    ``similarity`` (1 - cosine distance, rounded to 4 places).

    Errors return ``{"code": <status>, "msg": ...}``:
      * HTTP 400 for a missing, empty or unsupported upload;
      * HTTP 500 if embedding extraction or the database search fails.
    """
    # 1. Validate the upload
    if 'file' not in request.files:
        return jsonify({"code": 400, "msg": "未找到图片文件"}), 400

    file = request.files['file']
    if not file.filename:
        return jsonify({"code": 400, "msg": "未选择文件"}), 400

    ext = _get_extension(file.filename)
    if ext is None:
        return jsonify({"code": 400, "msg": "不支持的图片格式"}), 400

    # 2. Extract the CLIP embedding (the temp file is cleaned up either way)
    try:
        embedding = _embed_upload(file, ext)
    except Exception as e:  # request boundary: report any failure as JSON
        print(f"❌ [ImageSearch] 图像处理失败: {e}")
        return jsonify({"code": 500, "msg": f"图像处理失败: {str(e)}"}), 500

    # 3. pgvector cosine-similarity search across all image sources
    try:
        records = db.session.execute(
            _IMAGE_SEARCH_SQL,
            {
                "query_vector": _to_vector_literal(embedding),
                "max_results": MAX_RESULTS,
            },
        ).fetchall()

        results = [
            {
                "product_id": row.id,
                "product_name": row.name,
                "spec_model": row.spec_model,
                "image_url": _first_image_url(row.image_url),
                "similarity": round(float(row.similarity), 4),
            }
            for row in records
        ]
        return jsonify({"code": 200, "data": results})
    except Exception as e:  # request boundary: report any failure as JSON
        # Reset a failed transaction so the session remains usable.
        db.session.rollback()
        print(f"❌ [ImageSearch] 数据库检索失败: {e}")
        return jsonify({"code": 500, "msg": f"检索失败: {str(e)}"}), 500