feat: 以图搜图功能升级(跨表UNION检索 + 拍照识图入口 + 批量向量初始化脚本)

This commit is contained in:
DXC
2026-05-21 15:43:45 +08:00
parent 1a7c06f197
commit c273f5a9d9
4 changed files with 304 additions and 45 deletions

View File

@ -0,0 +1,220 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
"""
全量历史图片向量初始化脚本
功能:遍历配置表中所有历史图片字段,批量提取 CLIP 512 维向量并存回数据库。
用法python scripts/init_all_vectors.py
"""
import os
import json
import sys
from datetime import datetime
from typing import List, Optional
# 将项目根目录加入 Python 路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from tqdm import tqdm
from sqlalchemy import text
# Flask 应用环境
from app import create_app
from app.extensions import db
from app.utils.ai_vision import get_image_embedding, load_clip_model
# ============================================================================
# 业务配置:表 → 图片字段 → 向量字段 映射
# ============================================================================
TARGET_TABLES = [
# 基础物料
{"table": "material_base", "img_col": "product_image", "vec_col": "img_embedding"},
# 采购入库
{"table": "stock_buy", "img_col": "arrival_photo", "vec_col": "arrival_image_embedding"},
{"table": "stock_buy", "img_col": "qc_report", "vec_col": "qc_report_image_embedding"},
]
# 物理图片根目录(相对于 app 目录的相对路径 ../uploads/
APP_DIR = os.path.join(os.path.dirname(__file__), '..', 'app')
UPLOADS_ROOT = os.path.abspath(os.path.join(APP_DIR, '..', 'uploads'))
# ============================================================================
# 核心工具函数
# ============================================================================
def parse_img_field(raw_value: str) -> List[str]:
"""
健壮解析图片字段,支持以下格式:
- JSON 数组字符串: ["a.jpg", "b.jpg"]
- 纯字符串单图片: "a.jpg"
- 带 /api/v1/files/ 前缀: ["/api/v1/files/a.jpg"]
返回: 提取出的文件名列表
"""
if not raw_value or (isinstance(raw_value, str) and not raw_value.strip()):
return []
try:
# 先尝试按 JSON 解析(处理 JSON 数组字符串)
parsed = json.loads(raw_value)
if isinstance(parsed, list):
items = parsed
else:
items = [parsed]
except (json.JSONDecodeError, TypeError):
# JSON 解析失败,说明是纯字符串,直接按单图片处理
items = [raw_value.strip()]
filenames = []
for item in items:
if not item or not isinstance(item, str):
continue
item = item.strip()
if not item:
continue
# 去掉可能的 /api/v1/files/ 前缀
filename = os.path.basename(item)
filenames.append(filename)
return filenames
def build_local_path(filename: str) -> str:
"""
将文件名拼装成本地绝对路径
"""
return os.path.join(UPLOADS_ROOT, filename)
def extract_first_valid_vector(raw_img_field: str, table_name: str, img_col: str) -> Optional[str]:
"""
读取图片字段,从第一条有效图片提取向量,返回写入 DB 的 JSON 字符串。
如果所有图片均失败,返回 None。
"""
filenames = parse_img_field(raw_img_field)
if not filenames:
return None
for filename in filenames:
local_path = build_local_path(filename)
if not os.path.exists(local_path):
print(f"\033[91m[WARN] {table_name}.{img_col} | 文件不存在: {local_path}\033[0m")
continue
try:
vec = get_image_embedding(local_path)
if vec is not None:
return json.dumps(vec)
except Exception as e:
print(f"\033[91m[WARN] {table_name}.{img_col} | 推理异常 [{filename}]: {type(e).__name__}: {e}\033[0m")
continue
return None
# ============================================================================
# 主入口
# ============================================================================
def main():
start = datetime.now()
total_success = 0
total_skip = 0
print("=" * 60)
print("📦 全量历史图片向量初始化")
print("=" * 60)
print(f"图片目录: {UPLOADS_ROOT}")
print(f"待处理表数: {len(TARGET_TABLES)}")
print()
# 1. 初始化 Flask 应用上下文(加载 CLIP 模型)
app = create_app()
with app.app_context():
load_clip_model()
print("✅ CLIP 模型加载完成")
print()
# 2. 遍历目标表
for config in TARGET_TABLES:
table_name = config["table"]
img_col = config["img_col"]
vec_col = config["vec_col"]
print(f"正在处理表: {table_name}, 字段: {img_col}")
# 3. 查询待清洗记录(只选未处理过的)
sql = text(f"""
SELECT id, {img_col}
FROM {table_name}
WHERE {img_col} IS NOT NULL
AND {img_col} != '[]'
AND ({vec_col} IS NULL)
""")
rows = db.session.execute(sql).fetchall()
if not rows:
print(f"[{table_name}/{img_col}] ⏭ 无待处理记录")
continue
print(f"\n[{table_name}/{img_col}] 📋 待处理: {len(rows)}")
# 4. 逐条处理
processed = 0
success_count = 0
for row in tqdm(rows, desc=f"{table_name}/{img_col}", unit=""):
record_id = row[0]
raw_img = row[1]
try:
vec_json = extract_first_valid_vector(raw_img, table_name, img_col)
if vec_json is None:
total_skip += 1
continue
# 更新向量字段
update_sql = text(f"""
UPDATE {table_name} SET {vec_col} = :vec_str WHERE id = :id
""")
db.session.execute(update_sql, {"vec_str": vec_json, "id": record_id})
success_count += 1
# 每 50 条提交一次
if processed > 0 and processed % 50 == 0:
db.session.commit()
print(f"\n ✅ 已提交 {processed}")
except Exception as e:
print(f"\n\033[91m[WARN] {table_name}/{img_col} | ID={record_id} 处理异常: {type(e).__name__}: {e}\033[0m")
# 关键:任何异常都不中断,只 continue 下一条
db.session.rollback()
continue
finally:
processed += 1
# 循环结束后补一次 commit处理未凑满50条的剩余数据
try:
db.session.commit()
except Exception:
db.session.rollback()
total_success += success_count
print(f"[{table_name}/{img_col}] ✅ 完成,成功 {success_count} 条 / 跳过 {len(rows) - success_count}")
# 5. 汇总报告
elapsed = (datetime.now() - start).total_seconds()
print()
print("=" * 60)
print(f"🏁 全部完成!总计耗时 {elapsed:.1f}")
print(f" ✅ 成功写入向量: {total_success}")
print(f" ⏭ 无有效图片(跳过): {total_skip}")
print("=" * 60)
if __name__ == "__main__":
main()