161 lines
6.2 KiB
Python
161 lines
6.2 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
以图搜图 API - CLIP Vision Embedding + pgvector 余弦距离检索
|
||
"""
|
||
|
||
import os
|
||
import uuid
|
||
import json
|
||
from flask import Blueprint, request, jsonify
|
||
from sqlalchemy import text
|
||
from app.extensions import db
|
||
from app.utils.ai_vision import load_clip_model, get_image_embedding
|
||
|
||
# Blueprint that groups the image-search routes. The "/api/v1/common" part of
# the documented URL is presumably supplied as url_prefix where this blueprint
# is registered — confirm against the app factory.
image_search_bp = Blueprint(name='image_search', import_name=__name__)
|
||
|
||
|
||
# ============================================================================
# POST /api/v1/common/image-search
# Reverse image search: upload image -> CLIP embedding -> pgvector
# cosine-similarity lookup across every table that stores image embeddings.
# ============================================================================

# File extensions accepted by the endpoint (compared lower-cased).
_ALLOWED_IMAGE_EXTS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'})

# Maximum number of *distinct* materials returned per query.
_RESULT_LIMIT = 10

# Directory where an upload is staged while its embedding is computed.
_UPLOAD_TMP_DIR = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'uploads')

# Cross-table search. Each material can have several images (its base image
# plus receipt photos), so DISTINCT ON keeps only the closest image per
# material *before* the LIMIT. Otherwise duplicates would consume the top-N
# slots and fewer than _RESULT_LIMIT materials would come back.
# `<=>` is pgvector's cosine-distance operator; similarity = 1 - distance.
_IMAGE_SEARCH_SQL = text("""
    WITH combined AS (
        -- 1. Base material table
        SELECT id, name, spec_model, product_image AS image_url, img_embedding AS vec
        FROM material_base
        WHERE img_embedding IS NOT NULL

        UNION ALL

        -- 2. Purchase receipts (joined to the real material via base_id)
        SELECT mb.id, mb.name, mb.spec_model, sb.arrival_photo AS image_url, sb.arrival_image_embedding AS vec
        FROM stock_buy sb
        JOIN material_base mb ON sb.base_id = mb.id
        WHERE sb.arrival_image_embedding IS NOT NULL

        UNION ALL

        -- 3. Semi-finished goods receipts
        SELECT mb.id, mb.name, mb.spec_model, ss.arrival_photo AS image_url, ss.arrival_image_embedding AS vec
        FROM stock_semi ss
        JOIN material_base mb ON ss.base_id = mb.id
        WHERE ss.arrival_image_embedding IS NOT NULL

        UNION ALL

        -- 4. Finished goods receipts
        SELECT mb.id, mb.name, mb.spec_model, sp.product_photo AS image_url, sp.arrival_image_embedding AS vec
        FROM stock_product sp
        JOIN material_base mb ON sp.base_id = mb.id
        WHERE sp.arrival_image_embedding IS NOT NULL
    ),
    best_per_material AS (
        -- Closest image row for each material
        SELECT DISTINCT ON (id)
               id, name, spec_model, image_url,
               vec <=> CAST(:query_vector AS vector) AS distance
        FROM combined
        ORDER BY id, distance
    )
    SELECT id, name, spec_model, image_url, (1 - distance) AS similarity
    FROM best_per_material
    ORDER BY distance
    LIMIT :result_limit
""")


def _extract_upload_embedding(file, ext):
    """Stage an uploaded image on disk, return its CLIP embedding, always delete the file.

    Args:
        file: the uploaded file object from ``request.files``; its ``save(path)`` is used.
        ext: validated, lower-cased extension used for the temporary file name.

    Returns:
        Whatever ``get_image_embedding`` returns for the staged file.

    Raises:
        Any exception from creating the directory, saving the file, or
        computing the embedding. The caller turns these into a JSON 500.
    """
    os.makedirs(_UPLOAD_TMP_DIR, exist_ok=True)
    # Random name: never trust (or collide on) the client-supplied filename.
    tmp_path = os.path.join(_UPLOAD_TMP_DIR, f"{uuid.uuid4().hex}.{ext}")
    try:
        file.save(tmp_path)
        print(f"💾 [ImageSearch] 临时文件已保存: {tmp_path}")

        # NOTE(review): presumably load_clip_model() caches the model after the
        # first call — confirm, since it runs on every request.
        load_clip_model()
        embedding = get_image_embedding(tmp_path)
        print(f"✅ [ImageSearch] Embedding 提取成功,维度: {len(embedding)}")
        return embedding
    finally:
        # Remove the staged file whether or not embedding succeeded.
        try:
            os.remove(tmp_path)
            print(f"🗑️ [ImageSearch] 临时文件已清理: {tmp_path}")
        except FileNotFoundError:
            pass  # save() failed before the file was created; nothing to clean
        except OSError as e:
            print(f"⚠️ [ImageSearch] 临时文件删除失败: {e}")


def _to_vector_literal(embedding):
    """Format an embedding as a pgvector text literal, e.g. ``[0.1,0.2,0.3]``."""
    return '[' + ','.join(map(str, embedding)) + ']'


def _first_image_url(raw_url):
    """Return a single display URL from a stored image field.

    Some image columns hold a JSON array of URLs as text (e.g. '["a.jpg","b.jpg"]').
    For those the first entry is returned, and an empty array gives "".
    A list/tuple value is treated the same way. Plain strings are returned
    unchanged, and empty/None values become "".
    """
    if not raw_url:
        return ""
    if isinstance(raw_url, (list, tuple)):
        return raw_url[0]
    if raw_url.startswith('[') and raw_url.endswith(']'):
        try:
            url_list = json.loads(raw_url)
        except ValueError:
            # Bracketed but not valid JSON: fall back to the raw value.
            return raw_url
        return url_list[0] if url_list else ""
    return raw_url


@image_search_bp.route('/image-search', methods=['POST'])
def image_search():
    """Find the materials whose stored images look most like an uploaded image.

    Request:
        multipart/form-data with the image in the ``file`` field. The file
        extension must be in ``_ALLOWED_IMAGE_EXTS``.

    Returns:
        200 -> {"code": 200, "data": [{"product_id", "product_name",
               "spec_model", "image_url", "similarity"}, ...]} with at most
               ``_RESULT_LIMIT`` distinct materials, most similar first.
        400 -> file field missing, empty filename, or unsupported extension.
        500 -> embedding extraction or database search failed; ``msg``
               carries the exception text.
    """
    # 1. Validate the upload.
    file = request.files.get('file')
    if file is None:
        return jsonify({"code": 400, "msg": "未找到图片文件"}), 400
    # Werkzeug's FileStorage.filename may be None as well as ''.
    if not file.filename:
        return jsonify({"code": 400, "msg": "未选择文件"}), 400

    # rpartition + dot check: a dotless name such as "png" must not be
    # mistaken for an extension.
    _, dot, ext = file.filename.rpartition('.')
    ext = ext.lower()
    if not dot or ext not in _ALLOWED_IMAGE_EXTS:
        return jsonify({"code": 400, "msg": "不支持的图片格式"}), 400

    # 2-4. Stage the file, extract the CLIP embedding, always clean up.
    try:
        embedding = _extract_upload_embedding(file, ext)
    except Exception as e:  # request boundary: report any failure as JSON 500
        print(f"❌ [ImageSearch] 图像处理失败: {e}")
        return jsonify({"code": 500, "msg": f"图像处理失败: {str(e)}"}), 500

    # 5. pgvector cosine-similarity search across all image sources.
    #    Dedup per material happens in SQL, so every row is a distinct material.
    try:
        records = db.session.execute(
            _IMAGE_SEARCH_SQL,
            {
                "query_vector": _to_vector_literal(embedding),
                "result_limit": _RESULT_LIMIT,
            },
        ).fetchall()

        results = [
            {
                "product_id": row.id,
                "product_name": row.name,
                "spec_model": row.spec_model,
                "image_url": _first_image_url(row.image_url),
                "similarity": round(float(row.similarity), 4),
            }
            for row in records
        ]
        return jsonify({"code": 200, "data": results})

    except Exception as e:
        # A failed statement aborts the transaction; roll back so the session
        # stays usable for anything else in this request/app context.
        db.session.rollback()
        print(f"❌ [ImageSearch] 数据库检索失败: {e}")
        return jsonify({"code": 500, "msg": f"检索失败: {str(e)}"}), 500