238 lines
9.4 KiB
Python
238 lines
9.4 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
以图搜图 API - CLIP Vision Embedding + pgvector 余弦距离检索
|
||
数据源:image_embeddings 表(统一向量存储)
|
||
"""
|
||
|
||
import os
|
||
import uuid
|
||
import json
|
||
from flask import Blueprint, request, jsonify
|
||
from sqlalchemy import text
|
||
from app.extensions import db
|
||
from app.utils.ai_vision import load_clip_model, get_image_embedding
|
||
from app.models.inbound.buy import StockBuy
|
||
from app.models.inbound.semi import StockSemi
|
||
from app.models.inbound.product import StockProduct
|
||
from app.models.base import MaterialBase
|
||
|
||
# 注册蓝图
|
||
# Blueprint that carries the image-search route defined in this module.
# NOTE(review): the route header comment advertises /api/v1/common/image-search,
# so the '/api/v1/common' prefix is presumably applied where this blueprint is
# registered - confirm against the app factory.
image_search_bp = Blueprint('image_search', __name__)
|
||
|
||
|
||
# ============================================================================
# POST /api/v1/common/image-search
# Image search: uploaded image -> CLIP embedding -> pgvector cosine-similarity lookup
# ============================================================================
|
||
|
||
# File extensions accepted for the uploaded query image.
_ALLOWED_IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'})

# Maximum number of de-duplicated hits returned to the client.
_MAX_RESULTS = 10

# Attributes copied verbatim from every stock record into its payload.
_STOCK_FIELDS = (
    'sku',
    'barcode',
    'serial_number',
    'batch_number',
    'status',
    'warehouse_location',
    'stock_quantity',
)

# (module_name, model class, front-end route, extra attributes to copy).
_STOCK_MODULES = (
    ('stock_buy', StockBuy, '/inventory/buy', ()),
    ('stock_semi', StockSemi, '/inventory/semi', ()),
    ('stock_product', StockProduct, '/inventory/product', ('sale_price',)),
)


def _stock_payload(record, module_name, url, extra_fields=()):
    """Build the business-data dict for one stock record of any stock module."""
    base = record.base
    payload = {
        'record_id': record.id,
        'name': base.name if base else None,
        'spec_model': base.spec_model if base else None,
    }
    for field in _STOCK_FIELDS + tuple(extra_fields):
        payload[field] = getattr(record, field)
    payload['module_name'] = module_name
    payload['url'] = url
    return payload


def _load_business_data(ids_by_module):
    """Fetch business records for the hits, keyed by (module_name, record id)."""
    business_map = {}

    for module_name, model, url, extra_fields in _STOCK_MODULES:
        ids = ids_by_module.get(module_name)
        if not ids:
            continue
        records = (
            db.session.query(model)
            .filter(model.id.in_(ids))
            .outerjoin(MaterialBase, model.base_id == MaterialBase.id)
            .all()
        )
        for record in records:
            business_map[(module_name, record.id)] = _stock_payload(
                record, module_name, url, extra_fields
            )

    material_ids = ids_by_module.get('material_base')
    if material_ids:
        materials = MaterialBase.query.filter(MaterialBase.id.in_(material_ids)).all()
        for record in materials:
            business_map[('material_base', record.id)] = {
                'record_id': record.id,
                'name': record.name,
                'spec_model': record.spec_model,
                'common_name': record.common_name,
                'category': record.category,
                'material_type': record.material_type,
                'unit': record.unit,
                'module_name': 'material_base',
                'url': '/material/index',
            }

    return business_map


def _first_image_url(raw_url):
    """Return a single URL from a column holding a plain URL or a JSON list."""
    url = raw_url or ''
    if not url.startswith('['):
        return url
    try:
        url_list = json.loads(url)
    except ValueError:
        # Not valid JSON after all (json.JSONDecodeError subclasses ValueError);
        # fall back to the raw column value.
        return url
    return url_list[0] if url_list else ''


@image_search_bp.route('/image-search', methods=['POST'])
def image_search():
    """Search stored images by visual similarity to an uploaded image.

    Expects a multipart upload under the form field ``file``. The image is
    written to a temporary file, embedded with CLIP, and compared against the
    ``image_embeddings`` table using the pgvector ``<=>`` operator. Hits are
    de-duplicated per ``(module_name, target_id)``, enriched with data from
    the matching business table, and capped at ``_MAX_RESULTS`` entries.

    Returns:
        A ``(json, status)`` pair or a JSON response:
        * 400 with ``{"code": 400, "msg": ...}`` when the upload is missing,
          unnamed, or has an unsupported / missing extension.
        * 500 with ``{"code": 500, "msg": ...}`` when embedding extraction or
          the database lookup raises.
        * 200 with ``{"code": 200, "data": [...]}`` otherwise.
    """
    # ---------------------------------------------------------
    # 1. Validate the upload
    # ---------------------------------------------------------
    if 'file' not in request.files:
        return jsonify({"code": 400, "msg": "未找到图片文件"}), 400

    file = request.files['file']
    # A falsy check also covers a filename of None, which would otherwise
    # crash on the string operations below.
    if not file.filename:
        return jsonify({"code": 400, "msg": "未选择文件"}), 400

    # ---------------------------------------------------------
    # 2. Save to a uniquely named temporary file
    # ---------------------------------------------------------
    # Require an actual dot: without it a file literally named "png" would
    # be treated as having the extension "png".
    _, dot, ext = file.filename.rpartition('.')
    ext = ext.lower()
    if not dot or ext not in _ALLOWED_IMAGE_EXTENSIONS:
        return jsonify({"code": 400, "msg": "不支持的图片格式"}), 400

    # The client-supplied name is never used on disk; only the vetted
    # extension is kept.
    tmp_filename = f"{uuid.uuid4().hex}.{ext}"
    tmp_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'uploads')
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_path = os.path.join(tmp_dir, tmp_filename)

    try:
        file.save(tmp_path)
        print(f"💾 [ImageSearch] 临时文件已保存: {tmp_path}")

        # ---------------------------------------------------------
        # 3. Extract the CLIP embedding
        # ---------------------------------------------------------
        load_clip_model()
        embedding = get_image_embedding(tmp_path)
        print(f"✅ [ImageSearch] Embedding 提取成功,维度: {len(embedding)}")

    except Exception as e:
        print(f"❌ [ImageSearch] 图像处理失败: {e}")
        return jsonify({"code": 500, "msg": f"图像处理失败: {str(e)}"}), 500

    finally:
        # ---------------------------------------------------------
        # 4. Always remove the temporary file, success or not
        # ---------------------------------------------------------
        try:
            os.remove(tmp_path)
            print(f"🗑️ [ImageSearch] 临时文件已清理: {tmp_path}")
        except FileNotFoundError:
            # Saving failed before the file was created; nothing to clean.
            pass
        except OSError as e:
            print(f"⚠️ [ImageSearch] 临时文件删除失败: {e}")

    # ---------------------------------------------------------
    # 5. pgvector cosine-similarity lookup on image_embeddings
    # ---------------------------------------------------------
    try:
        query_vector_str = '[' + ','.join(str(v) for v in embedding) + ']'

        sql = text("""
            SELECT
                ie.id AS embedding_id,
                ie.module_name,
                ie.target_id,
                ie.image_url,
                (1 - (ie.embedding <=> :query_vector)) AS similarity
            FROM image_embeddings ie
            WHERE ie.embedding IS NOT NULL
            ORDER BY ie.embedding <=> :query_vector
            LIMIT 200
        """)

        raw_records = db.session.execute(sql, {"query_vector": query_vector_str}).fetchall()
        if not raw_records:
            return jsonify({"code": 200, "data": []})

        # Rows arrive ordered by distance, so the first row seen for a
        # (module_name, target_id) pair is that record's most similar image.
        # Stop once enough distinct records are collected so that only
        # records actually returned are loaded from the business tables.
        best_hits = {}
        for row in raw_records:
            key = (row.module_name, row.target_id)
            if key in best_hits:
                continue
            best_hits[key] = row
            if len(best_hits) >= _MAX_RESULTS:
                break

        # Group target ids per module for batched back-fill queries.
        target_ids_by_module = {}
        for module_name, target_id in best_hits:
            target_ids_by_module.setdefault(module_name, []).append(target_id)

        business_map = _load_business_data(target_ids_by_module)

        # Assemble the response payload.
        results = []
        for key, row in best_hits.items():
            biz = business_map.get(key, {})
            results.append({
                "module_name": row.module_name,
                "target_id": row.target_id,
                "image_url": _first_image_url(row.image_url),
                "similarity": round(float(row.similarity), 4),
                "product_name": biz.get('name') or biz.get('material_name') or '未命名物料',
                "product_id": row.target_id,
                "spec_model": biz.get('spec_model') or '',
                "business_data": biz,
            })

        return jsonify({"code": 200, "data": results})

    except Exception as e:
        print(f"❌ [ImageSearch] 数据库检索失败: {e}")
        return jsonify({"code": 500, "msg": f"检索失败: {str(e)}"}), 500