Files
KCGL/inventory-backend/app/api/v1/common/image_search.py

161 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
以图搜图 API - CLIP Vision Embedding + pgvector 余弦距离检索
"""
import os
import uuid
import json
from flask import Blueprint, request, jsonify
from sqlalchemy import text
from app.extensions import db
from app.utils.ai_vision import load_clip_model, get_image_embedding
# Blueprint that carries the image-search route defined in this module.
image_search_bp = Blueprint('image_search', __name__)
# ============================================================================
# POST /api/v1/common/image-search
# Search by image: uploaded picture -> CLIP embedding -> pgvector cosine-similarity lookup
# ============================================================================
@image_search_bp.route('/image-search', methods=['POST'])
def image_search():
    """Find materials whose stored photos look like the uploaded image.

    Expects a multipart form upload with the image under the ``file`` field.
    The image is written to a temporary file, embedded with
    ``get_image_embedding`` and compared (pgvector ``<=>`` operator) against
    the embeddings stored in ``material_base``, ``stock_buy``, ``stock_semi``
    and ``stock_product``.

    Returns:
        A ``(json, status)`` tuple for errors, or a JSON response
        ``{"code": 200, "data": [...]}`` where each item carries
        ``product_id``, ``product_name``, ``spec_model``, ``image_url`` and
        ``similarity`` (``1 - distance``, rounded to 4 decimals). At most 10
        items are returned, one per material id, best match first.

        * 400 - missing file, empty filename or unsupported extension.
        * 500 - embedding extraction or database query failed.
    """
    max_results = 10
    # Fetch more rows than we return: one material can match through several
    # photos (base image plus stock-in photos) and duplicates are dropped
    # below, so limiting the SQL to ``max_results`` rows would yield fewer
    # distinct materials than requested. The factor is a heuristic.
    candidate_limit = max_results * 5

    # ---------------------------------------------------------
    # 1. Validate the upload
    # ---------------------------------------------------------
    if 'file' not in request.files:
        return jsonify({"code": 400, "msg": "未找到图片文件"}), 400
    file = request.files['file']
    # ``not`` also covers a missing (None) filename, which would otherwise
    # crash on ``.rsplit`` below instead of producing a 400.
    if not file.filename:
        return jsonify({"code": 400, "msg": "未选择文件"}), 400

    # ---------------------------------------------------------
    # 2. Save to a temporary file under a server-generated name
    # ---------------------------------------------------------
    ext = file.filename.rsplit('.', 1)[-1].lower()
    if ext not in {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}:
        return jsonify({"code": 400, "msg": "不支持的图片格式"}), 400
    # The client-supplied name is never used on disk; only the whitelisted
    # extension is kept.
    tmp_filename = f"{uuid.uuid4().hex}.{ext}"
    tmp_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'uploads')
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_path = os.path.join(tmp_dir, tmp_filename)

    try:
        file.save(tmp_path)
        print(f"💾 [ImageSearch] 临时文件已保存: {tmp_path}")
        # ---------------------------------------------------------
        # 3. Extract the CLIP embedding
        # ---------------------------------------------------------
        load_clip_model()
        embedding = get_image_embedding(tmp_path)
        print(f"✅ [ImageSearch] Embedding 提取成功,维度: {len(embedding)}")
    except Exception as e:
        # NOTE(review): the raw exception text is returned to the client;
        # kept for API compatibility, but it may expose internal details.
        print(f"❌ [ImageSearch] 图像处理失败: {e}")
        return jsonify({"code": 500, "msg": f"图像处理失败: {str(e)}"}), 500
    finally:
        # ---------------------------------------------------------
        # 4. Always remove the temporary file, success or failure
        # ---------------------------------------------------------
        try:
            os.remove(tmp_path)
            print(f"🗑️ [ImageSearch] 临时文件已清理: {tmp_path}")
        except FileNotFoundError:
            # Saving failed before the file was created; nothing to clean.
            pass
        except OSError as e:
            print(f"⚠️ [ImageSearch] 临时文件删除失败: {e}")

    # ---------------------------------------------------------
    # 5. pgvector similarity search across all image-bearing tables
    # ---------------------------------------------------------
    try:
        # pgvector accepts the textual form "[v1,v2,...]" for a vector value.
        query_vector_str = '[' + ','.join(str(v) for v in embedding) + ']'
        sql = text("""
            SELECT id, name, spec_model, image_url,
            (1 - (vec <=> :query_vector)) AS similarity
            FROM (
            -- 1. 基础物料表
            SELECT id, name, spec_model, product_image AS image_url, img_embedding AS vec
            FROM material_base
            WHERE img_embedding IS NOT NULL
            UNION ALL
            -- 2. 采购入库表 (通过 base_id 关联拿真实物料)
            SELECT mb.id, mb.name, mb.spec_model, sb.arrival_photo AS image_url, sb.arrival_image_embedding AS vec
            FROM stock_buy sb
            JOIN material_base mb ON sb.base_id = mb.id
            WHERE sb.arrival_image_embedding IS NOT NULL
            UNION ALL
            -- 3. 半成品入库表
            SELECT mb.id, mb.name, mb.spec_model, ss.arrival_photo AS image_url, ss.arrival_image_embedding AS vec
            FROM stock_semi ss
            JOIN material_base mb ON ss.base_id = mb.id
            WHERE ss.arrival_image_embedding IS NOT NULL
            UNION ALL
            -- 4. 成品入库表
            SELECT mb.id, mb.name, mb.spec_model, sp.product_photo AS image_url, sp.arrival_image_embedding AS vec
            FROM stock_product sp
            JOIN material_base mb ON sp.base_id = mb.id
            WHERE sp.arrival_image_embedding IS NOT NULL
            ) AS combined
            ORDER BY vec <=> :query_vector LIMIT :candidate_limit
        """)
        records = db.session.execute(
            sql,
            {"query_vector": query_vector_str, "candidate_limit": candidate_limit},
        ).fetchall()

        results = []
        seen_product_ids = set()
        for row in records:
            # Rows arrive best-match-first, so the first hit per material is
            # its best one; later hits for the same material are skipped.
            if row.id in seen_product_ids:
                continue
            seen_product_ids.add(row.id)
            results.append({
                "product_id": row.id,
                "product_name": row.name,
                "spec_model": row.spec_model,
                "image_url": _first_image_url(row.image_url),
                "similarity": round(float(row.similarity), 4)
            })
            if len(results) >= max_results:
                break
        return jsonify({"code": 200, "data": results})
    except Exception as e:
        # NOTE(review): the raw database error text is returned to the
        # client; kept for API compatibility, but it may expose internals.
        print(f"❌ [ImageSearch] 数据库检索失败: {e}")
        return jsonify({"code": 500, "msg": f"检索失败: {str(e)}"}), 500


def _first_image_url(raw_url):
    """Reduce a stored image field to a single URL string.

    The field holds either a plain URL or a JSON array of URLs serialised
    as text. For an array the first element is returned (``""`` when the
    array is empty). Text that looks like an array but is not valid JSON is
    returned unchanged. Empty or ``None`` input yields ``""``.
    """
    if not raw_url:
        return ""
    if raw_url.startswith('[') and raw_url.endswith(']'):
        try:
            url_list = json.loads(raw_url)
        except ValueError:
            # json.JSONDecodeError is a ValueError: not JSON after all.
            return raw_url
        return url_list[0] if url_list else ""
    return raw_url