feat: 以图搜图功能升级(跨表UNION检索 + 拍照识图入口 + 批量向量初始化脚本)
This commit is contained in:
@ -71,19 +71,45 @@ def image_search():
|
||||
print(f"⚠️ [ImageSearch] 临时文件删除失败: {e}")
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 5. pgvector 余弦相似度检索
|
||||
# 5. pgvector 余弦相似度检索(跨表联合检索)
|
||||
# ---------------------------------------------------------
|
||||
try:
|
||||
# 将 Python list 转为 PostgreSQL 向量格式: '[0.1, 0.2, ...]'
|
||||
query_vector_str = '[' + ','.join(str(v) for v in embedding) + ']'
|
||||
|
||||
sql = text("""
|
||||
SELECT id, name, spec_model, product_image,
|
||||
(1 - (img_embedding <=> :query_vector)) AS similarity
|
||||
FROM material_base
|
||||
WHERE img_embedding IS NOT NULL
|
||||
ORDER BY img_embedding <=> :query_vector
|
||||
LIMIT 5
|
||||
SELECT id, name, spec_model, image_url,
|
||||
(1 - (vec <=> :query_vector)) AS similarity
|
||||
FROM (
|
||||
SELECT id,
|
||||
COALESCE(name, '') AS name,
|
||||
COALESCE(spec, '') AS spec_model,
|
||||
COALESCE(product_image, '') AS image_url,
|
||||
img_embedding AS vec
|
||||
FROM material_base
|
||||
WHERE img_embedding IS NOT NULL
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT id,
|
||||
'采购入库' AS name,
|
||||
'到货照片' AS spec_model,
|
||||
COALESCE(arrival_photo, '') AS image_url,
|
||||
arrival_image_embedding AS vec
|
||||
FROM stock_buy
|
||||
WHERE arrival_image_embedding IS NOT NULL
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT id,
|
||||
'采购入库' AS name,
|
||||
'质检报告' AS spec_model,
|
||||
COALESCE(qc_report, '') AS image_url,
|
||||
qc_report_image_embedding AS vec
|
||||
FROM stock_buy
|
||||
WHERE qc_report_image_embedding IS NOT NULL
|
||||
) AS combined
|
||||
ORDER BY vec <=> :query_vector
|
||||
LIMIT 10
|
||||
""")
|
||||
|
||||
result = db.session.execute(sql, {"query_vector": query_vector_str})
|
||||
@ -91,30 +117,31 @@ def image_search():
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
product_id = row[0]
|
||||
product_name = row[1] or ""
|
||||
item_id = row[0]
|
||||
item_name = row[1] or ""
|
||||
spec_model = row[2] or ""
|
||||
product_image = row[3]
|
||||
raw_image = row[3]
|
||||
|
||||
# 解析图片 URL 列表,取第一张
|
||||
image_url = ""
|
||||
if product_image:
|
||||
if raw_image:
|
||||
try:
|
||||
image_list = json.loads(product_image)
|
||||
image_list = json.loads(raw_image)
|
||||
if image_list and len(image_list) > 0:
|
||||
image_url = image_list[0]
|
||||
except Exception:
|
||||
image_url = str(product_image)
|
||||
# 纯字符串直接使用
|
||||
image_url = str(raw_image)
|
||||
|
||||
results.append({
|
||||
"product_id": product_id,
|
||||
"product_name": product_name,
|
||||
"id": item_id,
|
||||
"name": item_name,
|
||||
"spec_model": spec_model,
|
||||
"image_url": image_url,
|
||||
"similarity": round(float(row[4]), 4)
|
||||
})
|
||||
|
||||
print(f"✅ [ImageSearch] 检索完成,命中 {len(results)} 条结果")
|
||||
print(f"✅ [ImageSearch] 跨表检索完成,命中 {len(results)} 条结果")
|
||||
return jsonify({
|
||||
"code": 200,
|
||||
"msg": "检索成功",
|
||||
|
||||
@ -100,7 +100,7 @@ def get_image_embedding(image_path: str) -> list:
|
||||
提取图像的 512 维 CLIP embedding 向量
|
||||
|
||||
参数:
|
||||
image_path: 图像文件路径(支持本地路径或 URL)
|
||||
image_path: 图像文件路径
|
||||
|
||||
返回:
|
||||
list: 512 维浮点向量
|
||||
@ -108,25 +108,25 @@ def get_image_embedding(image_path: str) -> list:
|
||||
if ort_session is None:
|
||||
load_clip_model()
|
||||
|
||||
# 加载图像
|
||||
try:
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
except Exception as e:
|
||||
raise ValueError(f"图像加载失败: {image_path}, 错误: {e}")
|
||||
|
||||
# 中心裁剪
|
||||
# 1. 图片预处理
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
image = _center_crop_and_resize(image)
|
||||
|
||||
# 归一化
|
||||
input_data = _normalize(np.array(image))
|
||||
input_data = np.expand_dims(input_data, axis=0) # [1, 3, 224, 224]
|
||||
|
||||
# 添加 batch 维度: (C, H, W) -> (1, C, H, W)
|
||||
input_data = np.expand_dims(input_data, axis=0)
|
||||
# 2. 构造占位符输入 (关键修复)
|
||||
dummy_ids = np.zeros((1, 77), dtype=np.int64)
|
||||
dummy_mask = np.zeros((1, 77), dtype=np.int64)
|
||||
|
||||
# 推理
|
||||
outputs = ort_session.run(None, {'images': input_data.astype(np.float32)})
|
||||
|
||||
# 输出通常是 (1, 512) 的向量,取第一项并展平为 list
|
||||
embedding = outputs[0][0].tolist()
|
||||
|
||||
return embedding
|
||||
# 3. 传入模型进行推理
|
||||
# 注意: 模型输入名在你的模型里必须叫 'pixel_values', 'input_ids', 'attention_mask'
|
||||
# 如果报错找不到输入名,请打印 ort_session.get_inputs()[0].name 确认
|
||||
outputs = ort_session.run(
|
||||
['image_embeds'],
|
||||
{
|
||||
'input_ids': dummy_ids,
|
||||
'pixel_values': input_data.astype(np.float32),
|
||||
'attention_mask': dummy_mask
|
||||
}
|
||||
)
|
||||
return outputs[0][0].tolist()
|
||||
Reference in New Issue
Block a user