Files
KCGL/inventory-backend/app/api/v1/common/image_search.py

238 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
以图搜图 API - CLIP Vision Embedding + pgvector 余弦距离检索
数据源: image_embeddings 表(统一向量存储)
"""
import os
import uuid
import json
from flask import Blueprint, request, jsonify
from sqlalchemy import text
from app.extensions import db
from app.utils.ai_vision import load_clip_model, get_image_embedding
from app.models.inbound.buy import StockBuy
from app.models.inbound.semi import StockSemi
from app.models.inbound.product import StockProduct
from app.models.base import MaterialBase
# Create the blueprint that the image-search route is attached to.
image_search_bp = Blueprint('image_search', __name__)
# Upload extensions accepted by the endpoint (compared lower-cased).
_ALLOWED_IMAGE_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}

# Maximum number of distinct business records returned per search.
_MAX_RESULTS = 10

# (module_name, model, front-end url, extra model attributes to expose)
# for the stock tables that share the same back-fill shape.
_STOCK_MODULE_SPECS = (
    ('stock_buy', StockBuy, '/inventory/buy', ()),
    ('stock_semi', StockSemi, '/inventory/semi', ()),
    ('stock_product', StockProduct, '/inventory/product', ('sale_price',)),
)


def _backfill_stock_records(model, module_name, url, ids, extra_fields=()):
    """Load stock rows of ``model`` by id and map them to response dicts.

    Returns a dict keyed by ``(module_name, record.id)``. Name and spec
    model come from the related ``base`` record when one exists.
    """
    records = (
        db.session.query(model)
        .filter(model.id.in_(ids))
        .outerjoin(MaterialBase, model.base_id == MaterialBase.id)
        .all()
    )
    mapped = {}
    for r in records:
        data = {
            'record_id': r.id,
            'name': r.base.name if r.base else None,
            'spec_model': r.base.spec_model if r.base else None,
            'sku': r.sku,
            'barcode': r.barcode,
            'serial_number': r.serial_number,
            'batch_number': r.batch_number,
            'status': r.status,
            'warehouse_location': r.warehouse_location,
            'stock_quantity': r.stock_quantity,
        }
        for field in extra_fields:
            data[field] = getattr(r, field)
        data['module_name'] = module_name
        data['url'] = url
        mapped[(module_name, r.id)] = data
    return mapped


def _backfill_material_base(ids):
    """Load ``MaterialBase`` rows by id, keyed by ``('material_base', id)``."""
    records = MaterialBase.query.filter(MaterialBase.id.in_(ids)).all()
    return {
        ('material_base', r.id): {
            'record_id': r.id,
            'name': r.name,
            'spec_model': r.spec_model,
            'common_name': r.common_name,
            'category': r.category,
            'material_type': r.material_type,
            'unit': r.unit,
            'module_name': 'material_base',
            'url': '/material/index',
        }
        for r in records
    }


def _first_image_url(raw_url):
    """Return a single URL from a stored ``image_url`` value.

    A value starting with ``[`` is parsed as a JSON array and its first
    element is returned ('' for an empty array). If parsing fails, or the
    value is not array-like, it is returned unchanged.
    """
    if not raw_url.startswith('['):
        return raw_url
    try:
        url_list = json.loads(raw_url)
    except ValueError:
        # json.JSONDecodeError is a ValueError: keep the raw value as a
        # best-effort fallback instead of failing the whole search.
        return raw_url
    return url_list[0] if url_list else ''


# ============================================================================
# POST /api/v1/common/image-search
# Search by image: uploaded image -> CLIP embedding -> pgvector cosine
# similarity lookup
# ============================================================================
@image_search_bp.route('/image-search', methods=['POST'])
def image_search():
    """Find business records whose stored images resemble an uploaded image.

    Expects a multipart upload under the ``file`` field. The image is saved
    to a temporary file, embedded with the CLIP helper, and compared against
    the ``image_embeddings`` table using pgvector's ``<=>`` operator.

    Returns:
        A JSON response ``{"code": 200, "data": [...]}`` with at most
        ``_MAX_RESULTS`` entries, one per ``(module_name, target_id)``,
        ordered from most to least similar. Responds with code 400 when the
        upload is missing, unnamed, or has an unsupported extension, and
        with code 500 when embedding extraction or the database lookup
        raises.
    """
    # ---------------------------------------------------------
    # 1. Validate the upload
    # ---------------------------------------------------------
    if 'file' not in request.files:
        return jsonify({"code": 400, "msg": "未找到图片文件"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"code": 400, "msg": "未选择文件"}), 400

    # ---------------------------------------------------------
    # 2. Save to a temporary file under a random name; the client's file
    #    name is never used on disk, only its extension.
    # ---------------------------------------------------------
    ext = file.filename.rsplit('.', 1)[-1].lower()
    if ext not in _ALLOWED_IMAGE_EXTENSIONS:
        return jsonify({"code": 400, "msg": "不支持的图片格式"}), 400
    tmp_filename = f"{uuid.uuid4().hex}.{ext}"
    tmp_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'uploads')
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_path = os.path.join(tmp_dir, tmp_filename)

    try:
        file.save(tmp_path)
        print(f"💾 [ImageSearch] 临时文件已保存: {tmp_path}")
        # ---------------------------------------------------------
        # 3. Extract the CLIP embedding
        # ---------------------------------------------------------
        load_clip_model()
        embedding = get_image_embedding(tmp_path)
        print(f"✅ [ImageSearch] Embedding 提取成功,维度: {len(embedding)}")
    except Exception as e:
        print(f"❌ [ImageSearch] 图像处理失败: {e}")
        return jsonify({"code": 500, "msg": f"图像处理失败: {str(e)}"}), 500
    finally:
        # ---------------------------------------------------------
        # 4. Always remove the temporary file, success or failure
        # ---------------------------------------------------------
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
                print(f"🗑️ [ImageSearch] 临时文件已清理: {tmp_path}")
            except Exception as e:
                print(f"⚠️ [ImageSearch] 临时文件删除失败: {e}")

    # ---------------------------------------------------------
    # 5. pgvector similarity lookup against image_embeddings
    # ---------------------------------------------------------
    try:
        # The vector is passed as a bound parameter in '[v1,v2,...]' form.
        query_vector_str = '[' + ','.join(str(v) for v in embedding) + ']'
        sql = text("""
            SELECT
                ie.id AS embedding_id,
                ie.module_name,
                ie.target_id,
                ie.image_url,
                (1 - (ie.embedding <=> :query_vector)) AS similarity
            FROM image_embeddings ie
            WHERE ie.embedding IS NOT NULL
            ORDER BY ie.embedding <=> :query_vector
            LIMIT 200
        """)
        raw_records = db.session.execute(sql, {"query_vector": query_vector_str}).fetchall()
        if not raw_records:
            return jsonify({"code": 200, "data": []})

        # Deduplicate by (module_name, target_id): rows arrive ordered by
        # distance, so the first row seen for a key is its best match.
        # Stop once enough distinct records are collected so the back-fill
        # below only loads rows that are actually returned.
        seen = {}
        for row in raw_records:
            key = (row.module_name, row.target_id)
            if key not in seen:
                seen[key] = row
                if len(seen) >= _MAX_RESULTS:
                    break

        # Group the target ids per module so each table is queried once.
        target_ids_by_module = {}
        for row in seen.values():
            target_ids_by_module.setdefault(row.module_name, []).append(row.target_id)

        # Back-fill business data for every module that produced a hit.
        business_map = {}
        for module_name, model, url, extra_fields in _STOCK_MODULE_SPECS:
            if module_name in target_ids_by_module:
                business_map.update(
                    _backfill_stock_records(
                        model,
                        module_name,
                        url,
                        target_ids_by_module[module_name],
                        extra_fields,
                    )
                )
        if 'material_base' in target_ids_by_module:
            business_map.update(
                _backfill_material_base(target_ids_by_module['material_base'])
            )

        # Assemble the response; a hit without a matching business record
        # is still returned, with an empty business_data dict.
        results = []
        for key, row in seen.items():
            biz = business_map.get(key, {})
            results.append({
                "module_name": row.module_name,
                "target_id": row.target_id,
                "image_url": _first_image_url(row.image_url or ''),
                "similarity": round(float(row.similarity), 4),
                "product_name": biz.get('name') or biz.get('material_name') or '未命名物料',
                "product_id": row.target_id,
                "spec_model": biz.get('spec_model') or '',
                "business_data": biz,
            })
        return jsonify({"code": 200, "data": results})
    except Exception as e:
        print(f"❌ [ImageSearch] 数据库检索失败: {e}")
        return jsonify({"code": 500, "msg": f"检索失败: {str(e)}"}), 500