Files
KCGL/inventory-backend/app/api/v1/common/image_search.py

299 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Reverse image search API: CLIP vision embedding + pgvector cosine-distance lookup.

Data source: the ``image_embeddings`` table (the unified vector store).
"""
import json
import logging
import os
import uuid

from flask import Blueprint, request, jsonify
from sqlalchemy import text

from app.extensions import db
from app.utils.ai_vision import load_clip_model, get_image_embedding
from app.models.inbound.buy import StockBuy
from app.models.inbound.semi import StockSemi
from app.models.inbound.product import StockProduct
from app.models.base import MaterialBase
# Blueprint that carries the image-search route(s) of this module.
image_search_bp = Blueprint('image_search', __name__)
# ============================================================================
# Tunable parameters
# ============================================================================
# Reverse-image-search cutoff, expressed as a pgvector cosine *distance*
# (operator <=>; smaller distance = more similar). A candidate is kept only
# when distance < SIMILARITY_DISTANCE_THRESHOLD, i.e. when its cosine
# similarity (1 - distance) > 1 - SIMILARITY_DISTANCE_THRESHOLD.
# The current value 0.40 therefore requires cosine similarity > 0.60.
SIMILARITY_DISTANCE_THRESHOLD = 0.40
# ============================================================================
# POST /api/v1/common/image-search
# Reverse image search: uploaded image -> CLIP embedding -> pgvector cosine
# distance lookup over image_embeddings -> business-record backfill.
# ============================================================================

logger = logging.getLogger(__name__)

# Upload extensions accepted by the endpoint (lower-case, without the dot).
ALLOWED_IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'})
# Maximum candidate rows fetched from pgvector before de-duplication.
IMAGE_SEARCH_CANDIDATE_LIMIT = 200
# Maximum results returned to the client after de-duplication.
IMAGE_SEARCH_RESULT_LIMIT = 10

# Directory for short-lived copies of uploads: three levels above this
# module's directory, in a folder named "uploads".
_UPLOAD_TMP_DIR = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'uploads')

# Business modules whose rows reference a MaterialBase row through ``base_id``,
# mapped to (model class, frontend route for that module).
_STOCK_MODULES = {
    'stock_buy': (StockBuy, '/inventory/buy'),
    'stock_semi': (StockSemi, '/inventory/semi'),
    'stock_product': (StockProduct, '/inventory/product'),
}
# Module name under which MaterialBase rows themselves are embedded.
_MATERIAL_MODULE = 'material_base'

# Nearest-neighbour query over the unified vector table. Rows come back ordered
# by ascending cosine distance (most similar first), filtered by a distance
# cutoff and capped at :candidate_limit.
_SIMILAR_IMAGES_SQL = text("""
    SELECT
        ie.id AS embedding_id,
        ie.module_name,
        ie.target_id,
        ie.image_url,
        (1 - (ie.embedding <=> :query_vector)) AS similarity,
        (ie.embedding <=> :query_vector) AS distance
    FROM image_embeddings ie
    WHERE ie.embedding IS NOT NULL
      AND (ie.embedding <=> :query_vector) < :distance_threshold
    ORDER BY ie.embedding <=> :query_vector
    LIMIT :candidate_limit
""")


def _extract_embedding(file, ext):
    """Save *file* to a uniquely named temp file and return its CLIP embedding.

    The temp file is deleted afterwards whether or not embedding succeeded.
    Any exception raised while creating the directory, saving the upload or
    computing the embedding propagates to the caller.
    """
    os.makedirs(_UPLOAD_TMP_DIR, exist_ok=True)
    # The name is a random hex token plus a whitelisted extension, so nothing
    # from the client-supplied filename reaches the filesystem path.
    tmp_path = os.path.join(_UPLOAD_TMP_DIR, f"{uuid.uuid4().hex}.{ext}")
    try:
        file.save(tmp_path)
        logger.info("💾 [ImageSearch] 临时文件已保存: %s", tmp_path)
        load_clip_model()
        embedding = get_image_embedding(tmp_path)
        logger.info("✅ [ImageSearch] Embedding 提取成功,维度: %s", len(embedding))
        return embedding
    finally:
        try:
            os.remove(tmp_path)
            logger.info("🗑️ [ImageSearch] 临时文件已清理: %s", tmp_path)
        except FileNotFoundError:
            pass  # save() failed before the file was created
        except OSError as exc:
            logger.warning("⚠️ [ImageSearch] 临时文件删除失败: %s", exc)


def _first_image_url(raw_url):
    """Return one display URL from an ``image_url`` column value.

    Text starting with ``[`` is decoded as a JSON array and its first element
    returned ('' for an empty array); undecodable text and any other string
    are returned unchanged; ``None``/empty becomes ''. A value that is already
    a list/tuple is treated like a decoded array.
    """
    if isinstance(raw_url, (list, tuple)):
        return raw_url[0] if raw_url else ''
    raw_url = raw_url or ''
    if not raw_url.startswith('['):
        return raw_url
    try:
        url_list = json.loads(raw_url)
    except (ValueError, RecursionError):  # ValueError covers JSONDecodeError
        return raw_url
    return url_list[0] if url_list else ''


def _load_records(rows):
    """Bulk-load the business records referenced by search hits *rows*.

    Returns ``{(module_name, record_id): instance}`` for every referenced
    record that still exists; hits with an unrecognised ``module_name`` are
    skipped. Issues one query per referenced module.
    """
    ids_by_module = {}
    for row in rows:
        # A dict keeps first-seen order and drops duplicate ids.
        ids_by_module.setdefault(row.module_name, {})[row.target_id] = None

    records = {}
    for module, ids in ids_by_module.items():
        if module == _MATERIAL_MODULE:
            model_cls = MaterialBase
        elif module in _STOCK_MODULES:
            model_cls = _STOCK_MODULES[module][0]
        else:
            continue
        found = db.session.query(model_cls).filter(model_cls.id.in_(list(ids))).all()
        for rec in found:
            records[(module, rec.id)] = rec
    return records


def _material_id(row, records):
    """Return the MaterialBase id a search hit belongs to, or None if unknown."""
    if row.module_name == _MATERIAL_MODULE:
        return row.target_id
    rec = records.get((row.module_name, row.target_id))
    return rec.base_id if rec is not None else None


def _dedupe(rows, records):
    """Yield hits in input order, keeping the first per record and per material.

    *rows* come from _SIMILAR_IMAGES_SQL ordered closest-first, so the hit kept
    for each record / material is its most similar one. Hits whose material
    cannot be resolved are de-duplicated by record only.
    """
    seen_targets = set()
    seen_materials = set()
    for row in rows:
        target_key = (row.module_name, row.target_id)
        if target_key in seen_targets:
            continue
        seen_targets.add(target_key)
        material_id = _material_id(row, records)
        if material_id is not None:
            if material_id in seen_materials:
                continue
            seen_materials.add(material_id)
        yield row


def _business_payload(module, rec):
    """Build the ``business_data`` dict returned to the client for one hit."""
    if module == _MATERIAL_MODULE:
        return {
            'record_id': rec.id,
            'name': rec.name,
            'spec_model': rec.spec_model,
            'common_name': rec.common_name,
            'category': rec.category,
            'material_type': rec.material_type,
            'unit': rec.unit,
            'module_name': _MATERIAL_MODULE,
            'url': '/material/index',
        }
    base = rec.base  # related MaterialBase row; may be missing
    payload = {
        'record_id': rec.id,
        'name': base.name if base else None,
        'spec_model': base.spec_model if base else None,
        'sku': rec.sku,
        'barcode': rec.barcode,
        'serial_number': rec.serial_number,
        'batch_number': rec.batch_number,
        'status': rec.status,
        'warehouse_location': rec.warehouse_location,
        'stock_quantity': rec.stock_quantity,
    }
    if module == 'stock_product':
        payload['sale_price'] = rec.sale_price
    payload['module_name'] = module
    payload['url'] = _STOCK_MODULES[module][1]
    return payload


@image_search_bp.route('/image-search', methods=['POST'])
def image_search():
    """Find stored images similar to the uploaded one.

    Expects a multipart upload in form field ``file`` with one of the
    ALLOWED_IMAGE_EXTENSIONS. On success responds with
    ``{"code": 200, "data": [...]}``: at most IMAGE_SEARCH_RESULT_LIMIT hits,
    most similar first, at most one per business record and one per material.
    On failure responds with a ``{"code", "msg"}`` body and HTTP 400 (bad
    upload) or 500 (embedding or database failure).
    """
    # 1. Validate the upload.
    if 'file' not in request.files:
        return jsonify({"code": 400, "msg": "未找到图片文件"}), 400
    file = request.files['file']
    # filename can be None when the multipart part carries no filename.
    filename = file.filename or ''
    if filename == '':
        return jsonify({"code": 400, "msg": "未选择文件"}), 400
    # Require a real extension; a bare name such as "png" is rejected.
    ext = filename.rsplit('.', 1)[1].lower() if '.' in filename else ''
    if ext not in ALLOWED_IMAGE_EXTENSIONS:
        return jsonify({"code": 400, "msg": "不支持的图片格式"}), 400

    # 2. Compute the CLIP embedding (the temp file is cleaned up inside).
    try:
        embedding = _extract_embedding(file, ext)
    except Exception as e:  # request boundary: any failure becomes a JSON 500
        logger.exception("❌ [ImageSearch] 图像处理失败: %s", e)
        # NOTE(review): the client receives the raw exception text.
        return jsonify({"code": 500, "msg": f"图像处理失败: {str(e)}"}), 500

    # 3. Nearest-neighbour search, de-duplication and business backfill.
    try:
        query_vector = '[' + ','.join(str(v) for v in embedding) + ']'
        candidates = db.session.execute(_SIMILAR_IMAGES_SQL, {
            "query_vector": query_vector,
            "distance_threshold": SIMILARITY_DISTANCE_THRESHOLD,
            "candidate_limit": IMAGE_SEARCH_CANDIDATE_LIMIT,
        }).fetchall()
        if not candidates:
            return jsonify({"code": 200, "data": [], "msg": "未找到相似图片(阈值过滤后)"})

        records = _load_records(candidates)
        results = []
        for row in _dedupe(candidates, records):
            if len(results) >= IMAGE_SEARCH_RESULT_LIMIT:
                break
            rec = records.get((row.module_name, row.target_id))
            biz = _business_payload(row.module_name, rec) if rec is not None else {}
            results.append({
                "module_name": row.module_name,
                "target_id": row.target_id,
                "image_url": _first_image_url(row.image_url),
                "similarity": round(float(row.similarity), 4),
                "product_name": biz.get('name') or '未命名物料',
                "product_id": row.target_id,
                "spec_model": biz.get('spec_model') or '',
                "business_data": biz,
            })
        return jsonify({"code": 200, "data": results})
    except Exception as e:  # request boundary: any failure becomes a JSON 500
        logger.exception("❌ [ImageSearch] 数据库检索失败: %s", e)
        # Leave the session usable after a failed statement.
        db.session.rollback()
        return jsonify({"code": 500, "msg": f"检索失败: {str(e)}"}), 500