126 lines
4.7 KiB
Python
126 lines
4.7 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
以图搜图 API - CLIP Vision Embedding + pgvector 余弦距离检索
|
|
"""
|
|
|
|
import os
|
|
import uuid
|
|
import json
|
|
from flask import Blueprint, request, jsonify
|
|
from sqlalchemy import text
|
|
from app.extensions import db
|
|
from app.utils.ai_vision import load_clip_model, get_image_embedding
|
|
|
|
# Blueprint that carries the image-search route defined below. It is attached to
# the Flask app elsewhere (not shown in this module). The route's header comment
# gives the full path as /api/v1/common/image-search while the route itself is
# '/image-search', so presumably the blueprint is registered with a
# '/api/v1/common' URL prefix — confirm at the registration site.
|
|
image_search_bp = Blueprint('image_search', __name__)
|
|
|
|
|
|
# ============================================================================
# POST /api/v1/common/image-search
# Image search: upload an image -> CLIP embedding -> pgvector cosine-similarity lookup
# ============================================================================

# Extensions accepted for the uploaded query image (compared case-insensitively).
_ALLOWED_IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'})

# Nearest-neighbour query over material_base. `<=>` is pgvector's distance
# operator (cosine distance per this module's docstring), so `1 - distance` is
# reported as the similarity. The query vector is bound as a text literal and
# left for PostgreSQL to coerce to the column's vector type.
_SEARCH_SQL = text("""
    SELECT id, name, spec_model, product_image,
           (1 - (img_embedding <=> :query_vector)) AS similarity
    FROM material_base
    WHERE img_embedding IS NOT NULL
    ORDER BY img_embedding <=> :query_vector
    LIMIT 5
""")


def _image_extension(filename):
    """Return the lower-cased extension of *filename* if it is an allowed image type, else None.

    A name without any dot (e.g. ``"jpg"``) has no extension and is rejected.
    """
    _, dot, ext = filename.rpartition('.')
    ext = ext.lower()
    if not dot or ext not in _ALLOWED_IMAGE_EXTENSIONS:
        return None
    return ext


def _to_vector_literal(embedding):
    """Format a sequence of numbers as a pgvector text literal, e.g. ``'[0.1,0.2]'``.

    ``float(v)`` normalises numpy/torch scalar elements to plain Python floats
    so their ``str`` is a bare number rather than a type-decorated repr.
    """
    return '[' + ','.join(str(float(v)) for v in embedding) + ']'


def _first_image_url(product_image):
    """Return the first image URL held in a ``material_base.product_image`` value.

    The column's storage format is not visible from this module, so several
    shapes are accepted:

    * text/bytes holding a JSON array -> its first element;
    * text/bytes holding a JSON string -> that string;
    * text that is not valid JSON -> the raw text, treated as a single URL;
    * a list/tuple already decoded by the DB driver (e.g. a json/jsonb or
      array column) -> its first element.

    Empty or JSON-falsy values (``NULL``, ``"[]"``, ``"null"``) give ``""``;
    any other JSON value falls back to ``str(product_image)``.
    """
    if not product_image:
        return ""
    if isinstance(product_image, (str, bytes, bytearray)):
        try:
            parsed = json.loads(product_image)
        except ValueError:
            # Not JSON at all: treat the stored text as a single URL.
            return str(product_image)
    else:
        # The driver has already decoded the value into a Python object.
        parsed = product_image

    if not parsed:
        return ""
    if isinstance(parsed, str):
        return parsed
    if isinstance(parsed, (list, tuple)):
        return parsed[0]
    return str(product_image)


@image_search_bp.route('/image-search', methods=['POST'])
def image_search():
    """Find the catalogue products whose images are most similar to an uploaded image.

    Expects a multipart upload in the ``file`` field. The image is written to a
    uniquely named temporary file, embedded with ``get_image_embedding`` (CLIP),
    and the temporary file is always deleted afterwards. The embedding is then
    compared against ``material_base.img_embedding`` with pgvector's ``<=>``
    operator, and the 5 closest rows with a non-NULL embedding are returned,
    most similar first.

    Returns:
        A Flask JSON response ``{"code", "msg", "data"}``. On success ``data`` is
        a list of ``{"product_id", "product_name", "spec_model", "image_url",
        "similarity"}``; NULL names/spec models become ``""``, ``image_url`` is
        the first entry of ``product_image``, and ``similarity`` is
        ``1 - distance`` rounded to 4 decimal places.
        A missing file, empty filename or unsupported extension gives HTTP 400.
        A failure while saving/embedding the image or querying the database
        gives HTTP 500 with a generic message; the exception text is only
        printed server-side.
    """
    # ---------------------------------------------------------
    # 1. Validate the upload
    # ---------------------------------------------------------
    upload = request.files.get('file')
    if upload is None:
        return jsonify({"code": 400, "msg": "未找到图片文件"}), 400
    # `not` also covers a None filename, which would crash the parsing below.
    if not upload.filename:
        return jsonify({"code": 400, "msg": "未选择文件"}), 400

    ext = _image_extension(upload.filename)
    if ext is None:
        return jsonify({"code": 400, "msg": "不支持的图片格式"}), 400

    # ---------------------------------------------------------
    # 2. Save to a temp file. The name is a random UUID, so the
    #    client-supplied filename never reaches the filesystem path.
    # ---------------------------------------------------------
    tmp_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'uploads')
    tmp_path = os.path.join(tmp_dir, f"{uuid.uuid4().hex}.{ext}")

    try:
        # makedirs is inside the try so a filesystem failure produces the
        # JSON 500 below instead of an unhandled exception.
        os.makedirs(tmp_dir, exist_ok=True)
        upload.save(tmp_path)
        print(f"💾 [ImageSearch] 临时文件已保存: {tmp_path}")

        # ---------------------------------------------------------
        # 3. Extract the CLIP embedding
        # ---------------------------------------------------------
        load_clip_model()
        embedding = get_image_embedding(tmp_path)
        print(f"✅ [ImageSearch] Embedding 提取成功,维度: {len(embedding)}")

    except Exception as e:
        print(f"❌ [ImageSearch] 图像处理失败: {e}")
        # Do not echo str(e) to the client: image-decoding errors can contain
        # the server-side temp file path.
        return jsonify({"code": 500, "msg": "图像处理失败"}), 500

    finally:
        # ---------------------------------------------------------
        # 4. Always delete the temp file, whether or not embedding succeeded
        # ---------------------------------------------------------
        try:
            os.remove(tmp_path)
            print(f"🗑️ [ImageSearch] 临时文件已清理: {tmp_path}")
        except FileNotFoundError:
            pass  # Saving failed before the file was created; nothing to clean.
        except OSError as e:
            print(f"⚠️ [ImageSearch] 临时文件删除失败: {e}")

    # ---------------------------------------------------------
    # 5. pgvector cosine-similarity search
    # ---------------------------------------------------------
    try:
        query_vector = _to_vector_literal(embedding)
        rows = db.session.execute(_SEARCH_SQL, {"query_vector": query_vector}).fetchall()
    except Exception as e:
        print(f"❌ [ImageSearch] 数据库检索失败: {e}")
        # In PostgreSQL a failed statement aborts the current transaction;
        # roll back so the session is usable again.
        db.session.rollback()
        # SQLAlchemy errors embed the SQL text and bound parameters (here the
        # whole embedding), so keep them out of the response.
        return jsonify({"code": 500, "msg": "检索失败"}), 500

    results = [
        {
            "product_id": product_id,
            "product_name": product_name or "",
            "spec_model": spec_model or "",
            "image_url": _first_image_url(product_image),
            "similarity": round(float(similarity), 4),
        }
        for product_id, product_name, spec_model, product_image, similarity in rows
    ]

    print(f"✅ [ImageSearch] 检索完成,命中 {len(results)} 条结果")
    return jsonify({
        "code": 200,
        "msg": "检索成功",
        "data": results,
    })