feat: 以图搜图功能升级(跨表UNION检索 + 拍照识图入口 + 批量向量初始化脚本)
This commit is contained in:
@ -71,19 +71,45 @@ def image_search():
|
||||
print(f"⚠️ [ImageSearch] 临时文件删除失败: {e}")
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 5. pgvector 余弦相似度检索
|
||||
# 5. pgvector 余弦相似度检索(跨表联合检索)
|
||||
# ---------------------------------------------------------
|
||||
try:
|
||||
# 将 Python list 转为 PostgreSQL 向量格式: '[0.1, 0.2, ...]'
|
||||
query_vector_str = '[' + ','.join(str(v) for v in embedding) + ']'
|
||||
|
||||
sql = text("""
|
||||
SELECT id, name, spec_model, product_image,
|
||||
(1 - (img_embedding <=> :query_vector)) AS similarity
|
||||
FROM material_base
|
||||
WHERE img_embedding IS NOT NULL
|
||||
ORDER BY img_embedding <=> :query_vector
|
||||
LIMIT 5
|
||||
SELECT id, name, spec_model, image_url,
|
||||
(1 - (vec <=> :query_vector)) AS similarity
|
||||
FROM (
|
||||
SELECT id,
|
||||
COALESCE(name, '') AS name,
|
||||
COALESCE(spec, '') AS spec_model,
|
||||
COALESCE(product_image, '') AS image_url,
|
||||
img_embedding AS vec
|
||||
FROM material_base
|
||||
WHERE img_embedding IS NOT NULL
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT id,
|
||||
'采购入库' AS name,
|
||||
'到货照片' AS spec_model,
|
||||
COALESCE(arrival_photo, '') AS image_url,
|
||||
arrival_image_embedding AS vec
|
||||
FROM stock_buy
|
||||
WHERE arrival_image_embedding IS NOT NULL
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT id,
|
||||
'采购入库' AS name,
|
||||
'质检报告' AS spec_model,
|
||||
COALESCE(qc_report, '') AS image_url,
|
||||
qc_report_image_embedding AS vec
|
||||
FROM stock_buy
|
||||
WHERE qc_report_image_embedding IS NOT NULL
|
||||
) AS combined
|
||||
ORDER BY vec <=> :query_vector
|
||||
LIMIT 10
|
||||
""")
|
||||
|
||||
result = db.session.execute(sql, {"query_vector": query_vector_str})
|
||||
@ -91,30 +117,31 @@ def image_search():
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
product_id = row[0]
|
||||
product_name = row[1] or ""
|
||||
item_id = row[0]
|
||||
item_name = row[1] or ""
|
||||
spec_model = row[2] or ""
|
||||
product_image = row[3]
|
||||
raw_image = row[3]
|
||||
|
||||
# 解析图片 URL 列表,取第一张
|
||||
image_url = ""
|
||||
if product_image:
|
||||
if raw_image:
|
||||
try:
|
||||
image_list = json.loads(product_image)
|
||||
image_list = json.loads(raw_image)
|
||||
if image_list and len(image_list) > 0:
|
||||
image_url = image_list[0]
|
||||
except Exception:
|
||||
image_url = str(product_image)
|
||||
# 纯字符串直接使用
|
||||
image_url = str(raw_image)
|
||||
|
||||
results.append({
|
||||
"product_id": product_id,
|
||||
"product_name": product_name,
|
||||
"id": item_id,
|
||||
"name": item_name,
|
||||
"spec_model": spec_model,
|
||||
"image_url": image_url,
|
||||
"similarity": round(float(row[4]), 4)
|
||||
})
|
||||
|
||||
print(f"✅ [ImageSearch] 检索完成,命中 {len(results)} 条结果")
|
||||
print(f"✅ [ImageSearch] 跨表检索完成,命中 {len(results)} 条结果")
|
||||
return jsonify({
|
||||
"code": 200,
|
||||
"msg": "检索成功",
|
||||
|
||||
@ -100,7 +100,7 @@ def get_image_embedding(image_path: str) -> list:
|
||||
提取图像的 512 维 CLIP embedding 向量
|
||||
|
||||
参数:
|
||||
image_path: 图像文件路径(支持本地路径或 URL)
|
||||
image_path: 图像文件路径
|
||||
|
||||
返回:
|
||||
list: 512 维浮点向量
|
||||
@ -108,25 +108,25 @@ def get_image_embedding(image_path: str) -> list:
|
||||
if ort_session is None:
|
||||
load_clip_model()
|
||||
|
||||
# 加载图像
|
||||
try:
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
except Exception as e:
|
||||
raise ValueError(f"图像加载失败: {image_path}, 错误: {e}")
|
||||
|
||||
# 中心裁剪
|
||||
# 1. 图片预处理
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
image = _center_crop_and_resize(image)
|
||||
|
||||
# 归一化
|
||||
input_data = _normalize(np.array(image))
|
||||
input_data = np.expand_dims(input_data, axis=0) # [1, 3, 224, 224]
|
||||
|
||||
# 添加 batch 维度: (C, H, W) -> (1, C, H, W)
|
||||
input_data = np.expand_dims(input_data, axis=0)
|
||||
# 2. 构造占位符输入 (关键修复)
|
||||
dummy_ids = np.zeros((1, 77), dtype=np.int64)
|
||||
dummy_mask = np.zeros((1, 77), dtype=np.int64)
|
||||
|
||||
# 推理
|
||||
outputs = ort_session.run(None, {'images': input_data.astype(np.float32)})
|
||||
|
||||
# 输出通常是 (1, 512) 的向量,取第一项并展平为 list
|
||||
embedding = outputs[0][0].tolist()
|
||||
|
||||
return embedding
|
||||
# 3. 传入模型进行推理
|
||||
# 注意: 模型输入名在你的模型里必须叫 'pixel_values', 'input_ids', 'attention_mask'
|
||||
# 如果报错找不到输入名,请打印 ort_session.get_inputs()[0].name 确认
|
||||
outputs = ort_session.run(
|
||||
['image_embeds'],
|
||||
{
|
||||
'input_ids': dummy_ids,
|
||||
'pixel_values': input_data.astype(np.float32),
|
||||
'attention_mask': dummy_mask
|
||||
}
|
||||
)
|
||||
return outputs[0][0].tolist()
|
||||
220
inventory-backend/scripts/init_all_vectors.py
Normal file
220
inventory-backend/scripts/init_all_vectors.py
Normal file
@ -0,0 +1,220 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
全量历史图片向量初始化脚本
|
||||
|
||||
功能:遍历配置表中所有历史图片字段,批量提取 CLIP 512 维向量并存回数据库。
|
||||
用法:python scripts/init_all_vectors.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
# 将项目根目录加入 Python 路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from tqdm import tqdm
|
||||
from sqlalchemy import text
|
||||
|
||||
# Flask 应用环境
|
||||
from app import create_app
|
||||
from app.extensions import db
|
||||
from app.utils.ai_vision import get_image_embedding, load_clip_model
|
||||
|
||||
# ============================================================================
|
||||
# 业务配置:表 → 图片字段 → 向量字段 映射
|
||||
# ============================================================================
|
||||
# Each entry names a table, the column that stores its image reference(s),
# and the column that receives the embedding computed from those images.
TARGET_TABLES = [
    {"table": table, "img_col": img_col, "vec_col": vec_col}
    for table, img_col, vec_col in (
        # Base material catalogue
        ("material_base", "product_image", "img_embedding"),
        # Purchase receipts: arrival photo and QC report are embedded separately
        ("stock_buy", "arrival_photo", "arrival_image_embedding"),
        ("stock_buy", "qc_report", "qc_report_image_embedding"),
    )
]

# Image files live in an "uploads" directory that sits next to the "app"
# package, i.e. <this script's dir>/../app/../uploads, made absolute.
APP_DIR = os.path.join(os.path.dirname(__file__), '..', 'app')
UPLOADS_ROOT = os.path.abspath(os.path.join(APP_DIR, '..', 'uploads'))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 核心工具函数
|
||||
# ============================================================================
|
||||
|
||||
def parse_img_field(raw_value: str) -> List[str]:
    """
    Turn the raw value of an image column into a list of bare file names.

    Accepted shapes:
      - JSON array string:          '["a.jpg", "b.jpg"]'
      - plain single-image string:  'a.jpg'
      - entries carrying a path:    '["/api/v1/files/a.jpg"]'
      - an already-decoded list or tuple of such entries

    Args:
        raw_value: The column value. Normally a string; a list/tuple is also
            accepted so that a value handed back already decoded does not
            crash the parser.

    Returns:
        The file names in their original order, each reduced to its final
        path component. Empty, non-string and whitespace-only entries are
        dropped. An empty list is returned when nothing usable is found.
    """
    if not raw_value:
        return []

    if isinstance(raw_value, (list, tuple)):
        # Already decoded: use the entries as they are.
        items = list(raw_value)
    else:
        try:
            parsed = json.loads(raw_value)
        except (ValueError, TypeError):
            # Not JSON. A string is treated as one file reference; any other
            # type has no usable interpretation and is filtered out below.
            parsed = raw_value.strip() if isinstance(raw_value, str) else None
        items = parsed if isinstance(parsed, list) else [parsed]

    filenames = []
    for item in items:
        if not isinstance(item, str):
            continue
        item = item.strip()
        if not item:
            continue
        # Keep only the last path component, which drops prefixes such as
        # /api/v1/files/.
        filenames.append(os.path.basename(item))

    return filenames
|
||||
|
||||
|
||||
def build_local_path(filename: str) -> str:
    """
    Join ``filename`` onto the module-level ``UPLOADS_ROOT`` directory.

    Args:
        filename: File name as produced by ``parse_img_field``.

    Returns:
        The result of ``os.path.join(UPLOADS_ROOT, filename)``.
    """
    uploads_dir = UPLOADS_ROOT
    return os.path.join(uploads_dir, filename)
|
||||
|
||||
|
||||
def extract_first_valid_vector(raw_img_field: str, table_name: str, img_col: str) -> Optional[str]:
    """
    Compute an embedding from the first usable image referenced by a column value.

    The file names parsed from ``raw_img_field`` are tried in order. A file
    that is missing on disk, or whose embedding step raises, produces a
    warning on stdout and the next file is tried.

    Args:
        raw_img_field: Raw image column value, parsed with ``parse_img_field``.
        table_name: Table name, used only in warning messages.
        img_col: Image column name, used only in warning messages.

    Returns:
        The embedding serialised with ``json.dumps``, or ``None`` when no
        file yields an embedding.
    """
    for image_name in parse_img_field(raw_img_field):
        candidate = build_local_path(image_name)

        if os.path.exists(candidate):
            try:
                embedding = get_image_embedding(candidate)
                if embedding is not None:
                    return json.dumps(embedding)
            except Exception as exc:
                # Best effort: one unreadable image must not prevent trying the rest.
                print(f"\033[91m[WARN] {table_name}.{img_col} | 推理异常 [{image_name}]: {type(exc).__name__}: {exc}\033[0m")
        else:
            print(f"\033[91m[WARN] {table_name}.{img_col} | 文件不存在: {candidate}\033[0m")

    return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 主入口
|
||||
# ============================================================================
|
||||
|
||||
def _commit_pending(pending: int, label: str) -> int:
    """
    Commit the session's open transaction.

    Args:
        pending: Number of row updates waiting in the open transaction.
        label: "<table>/<column>" tag used in the failure message.

    Returns:
        ``pending`` when the commit succeeds, ``0`` when it fails (the
        transaction is rolled back and a warning is printed).
    """
    try:
        db.session.commit()
    except Exception as e:
        db.session.rollback()
        print(f"\n\033[91m[WARN] {label} | 提交失败,{pending} 条更新已回滚: {type(e).__name__}: {e}\033[0m")
        return 0
    return pending


def main():
    """
    Backfill the embedding column for every entry in ``TARGET_TABLES``.

    For each entry, rows whose image column is non-NULL and not '[]' and whose
    vector column is NULL are selected. An embedding is computed from the
    first usable image of each row and written back. Updates are committed in
    batches of 50, with a final commit per entry for the remainder.

    Each row's UPDATE runs inside its own savepoint, so a failing row is
    undone on its own and the earlier, still uncommitted updates of the same
    batch are kept. A row only counts as successful once the commit that
    contains it has succeeded.
    """
    start = datetime.now()
    total_success = 0
    total_skip = 0
    batch_size = 50

    print("=" * 60)
    print("📦 全量历史图片向量初始化")
    print("=" * 60)
    print(f"图片目录: {UPLOADS_ROOT}")
    print(f"待处理表数: {len(TARGET_TABLES)}")
    print()

    # 1. Create the Flask app and load the CLIP model inside its context.
    app = create_app()
    with app.app_context():
        load_clip_model()
        print("✅ CLIP 模型加载完成")
        print()

        # 2. Walk the configured (table, image column, vector column) entries.
        for config in TARGET_TABLES:
            table_name = config["table"]
            img_col = config["img_col"]
            vec_col = config["vec_col"]
            label = f"{table_name}/{img_col}"

            print(f"正在处理表: {table_name}, 字段: {img_col}")

            # 3. Select only rows that still lack a vector, so the script can
            #    be re-run safely. Identifiers come from TARGET_TABLES, not
            #    from user input.
            sql = text(f"""
                SELECT id, {img_col}
                FROM {table_name}
                WHERE {img_col} IS NOT NULL
                  AND {img_col} != '[]'
                  AND ({vec_col} IS NULL)
            """)
            rows = db.session.execute(sql).fetchall()

            if not rows:
                print(f"[{table_name}/{img_col}] ⏭ 无待处理记录")
                continue

            print(f"\n[{table_name}/{img_col}] 📋 待处理: {len(rows)} 条")

            update_sql = text(f"""
                UPDATE {table_name} SET {vec_col} = :vec_str WHERE id = :id
            """)

            # 4. Process row by row.
            success_count = 0  # rows whose update has been committed
            pending = 0        # rows updated in the open transaction

            for processed, row in enumerate(
                tqdm(rows, desc=f"{table_name}/{img_col}", unit="条"), start=1
            ):
                record_id = row[0]
                raw_img = row[1]

                try:
                    vec_json = extract_first_valid_vector(raw_img, table_name, img_col)
                    if vec_json is None:
                        total_skip += 1
                        continue

                    # Savepoint: a failure here undoes this row only.
                    with db.session.begin_nested():
                        db.session.execute(update_sql, {"vec_str": vec_json, "id": record_id})
                    pending += 1
                except Exception as e:
                    # A bad row must never abort the run; report it and move on.
                    print(f"\n\033[91m[WARN] {table_name}/{img_col} | ID={record_id} 处理异常: {type(e).__name__}: {e}\033[0m")
                    continue

                # Commit once a full batch of updates has accumulated.
                if pending >= batch_size:
                    committed = _commit_pending(pending, label)
                    success_count += committed
                    pending = 0
                    if committed:
                        print(f"\n ✅ 已提交 {processed} 条")

            # Final commit for the remainder of the last batch. It also ends
            # the transaction opened by the SELECT above.
            success_count += _commit_pending(pending, label)

            total_success += success_count
            print(f"[{table_name}/{img_col}] ✅ 完成,成功 {success_count} 条 / 跳过 {len(rows) - success_count} 条")

    # 5. Summary report.
    elapsed = (datetime.now() - start).total_seconds()
    print()
    print("=" * 60)
    print(f"🏁 全部完成!总计耗时 {elapsed:.1f} 秒")
    print(f" ✅ 成功写入向量: {total_success} 条")
    print(f" ⏭ 无有效图片(跳过): {total_skip} 条")
    print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -84,6 +84,9 @@
|
||||
|
||||
<el-button type="primary" plain @click="handleQuery">搜索</el-button>
|
||||
<el-button plain @click="resetQuery">重置</el-button>
|
||||
<el-button type="primary" plain @click="imageSearchVisible = true">
|
||||
<el-icon style="margin-right: 5px"><Picture /></el-icon>拍照识图
|
||||
</el-button>
|
||||
<el-popover
|
||||
v-model:visible="advancedFilterVisible"
|
||||
placement="bottom"
|
||||
@ -564,6 +567,12 @@
|
||||
/>
|
||||
</el-dialog>
|
||||
|
||||
<!-- 拍照识图弹窗 -->
|
||||
<ImageSearchDialog
|
||||
v-model="imageSearchVisible"
|
||||
@use="handleImageSearchUse"
|
||||
/>
|
||||
|
||||
<!-- 预警设置弹窗 -->
|
||||
<el-dialog v-model="warningDialog.visible" :title="warningDialog.title" width="500px" append-to-body destroy-on-close>
|
||||
<el-form ref="warningFormRef" :model="warningForm" :rules="warningRules" label-width="100px">
|
||||
@ -633,7 +642,7 @@
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, onMounted, nextTick, computed } from 'vue';
|
||||
import { Plus, Document, Refresh, Setting, Rank, Camera, Link, Download, Bell, CircleCheck, Files, ZoomIn, Delete } from '@element-plus/icons-vue';
|
||||
import { Plus, Document, Refresh, Setting, Rank, Camera, Link, Download, Bell, CircleCheck, Files, ZoomIn, Delete, Picture } from '@element-plus/icons-vue';
|
||||
import { ElMessage, ElMessageBox, ElLoading } from 'element-plus';
|
||||
import type { FormInstance, FormRules } from 'element-plus';
|
||||
import { useUserStore } from '@/stores/user';
|
||||
@ -655,6 +664,8 @@ import {
|
||||
import { uploadFile, deleteFile } from '@/api/common/upload';
|
||||
import { usePasteUpload } from '@/hooks/usePasteUpload';
|
||||
import WebRtcCamera from '@/components/Camera/WebRtcCamera.vue';
|
||||
import ImageSearchDialog from '@/components/ImageSearchDialog.vue';
|
||||
import { imageSearch as imageSearchApi, type ImageSearchItem } from '@/api/common/upload';
|
||||
|
||||
const userStore = useUserStore();
|
||||
|
||||
@ -716,6 +727,7 @@ const isUploading = ref(false);
|
||||
|
||||
const tableSize = ref<'large' | 'default' | 'small'>('large');
|
||||
const advancedFilterVisible = ref(false);
|
||||
const imageSearchVisible = ref(false);
|
||||
const advancedConditions = ref([{ field: '', operator: '', value: '' }]);
|
||||
const fieldOptions = computed(() => {
|
||||
const allFields = [
|
||||
@ -1585,15 +1597,8 @@ const customUpload = async (options: any, targetField: 'generalImage' | 'general
|
||||
if (res.code === 200) {
|
||||
const newUrl = res.data.url
|
||||
form.value[targetField].push(newUrl)
|
||||
// 同步更新 fileList,触发 el-upload UI 刷新
|
||||
const fileObj = { name: newUrl.split('/').pop(), url: getImageUrl(newUrl) }
|
||||
if (targetField === 'generalImage') {
|
||||
fileListImage.value.push(fileObj)
|
||||
} else {
|
||||
fileListManual.value.push(fileObj)
|
||||
}
|
||||
ElMessage.success('上传成功')
|
||||
onSuccess(res)
|
||||
onSuccess(res) // el-upload v-model 自动更新 fileList,无需手动 push
|
||||
} else {
|
||||
ElMessage.error(res.msg || '上传失败');
|
||||
onError(new Error(res.msg))
|
||||
@ -1693,6 +1698,13 @@ const handleCameraConfirm = async (file: File) => {
|
||||
}
|
||||
};
|
||||
|
||||
// 以图搜图 - 使用物料
|
||||
const handleImageSearchUse = (item: ImageSearchItem) => {
  // The cross-table search response carries the display name under `name`;
  // the earlier single-table response used `product_name`. Read both so the
  // toast never shows "undefined".
  const record = item as ImageSearchItem & { name?: string; product_name?: string };
  const displayName = record.name ?? record.product_name ?? '';

  // Jump to the material list, filtered by the matched item's spec/model.
  router.push({ path: '/material/list', query: { keyword: item.spec_model } });
  ElMessage.success(`已定位物料: ${displayName}`);
};
|
||||
|
||||
// Append one blank row to the advanced-filter condition list.
const addCondition = () => {
  const blankCondition = { field: '', operator: '', value: '' };
  advancedConditions.value.push(blankCondition);
};
|
||||
|
||||
Reference in New Issue
Block a user