# -*- coding: utf-8 -*- from __future__ import annotations """ 全量历史图片向量初始化脚本 功能:遍历配置表中所有历史图片字段,批量提取 CLIP 512 维向量并存回数据库。 用法:python scripts/init_all_vectors.py """ import os import json import sys from datetime import datetime from typing import List, Optional # 将项目根目录加入 Python 路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) from tqdm import tqdm from sqlalchemy import text # Flask 应用环境 from app import create_app from app.extensions import db from app.utils.ai_vision import get_image_embedding, load_clip_model # ============================================================================ # 业务配置:表 → 图片字段 → 向量字段 映射 (已全面修复) # ============================================================================ TARGET_TABLES = [ # 1. 基础物料 {"table": "material_base", "img_col": "product_image", "vec_col": "img_embedding"}, # 2. 采购入库 {"table": "stock_buy", "img_col": "arrival_photo", "vec_col": "arrival_image_embedding"}, {"table": "stock_buy", "img_col": "inspection_report", "vec_col": "qc_report_image_embedding"}, # 已修复: qc_report -> inspection_report # 3. 半成品入库 (新增) {"table": "stock_semi", "img_col": "arrival_photo", "vec_col": "arrival_image_embedding"}, {"table": "stock_semi", "img_col": "quality_report_link", "vec_col": "qc_report_image_embedding"}, # 4. 成品入库 (新增) {"table": "stock_product", "img_col": "product_photo", "vec_col": "arrival_image_embedding"}, {"table": "stock_product", "img_col": "quality_report_link", "vec_col": "qc_report_image_embedding"} # 注意:成品入库表还有一个 inspection_report_link,但由于数据库中成品表目前只加了两个向量字段, # 暂不将该字段加入遍历,以免覆盖 quality_report_link 的特征。 ] # 物理图片根目录(相对于 app 目录的相对路径 ../uploads/) APP_DIR = os.path.join(os.path.dirname(__file__), '..', 'app') UPLOADS_ROOT = os.path.abspath(os.path.join(APP_DIR, '..', 'uploads')) # ============================================================================ # 核心工具函数 # ============================================================================ def parse_img_field(raw_value: str) -> List[str]: """ 健壮解析图片字段,支持以下格式: - JSON 数组字符串: ["a.jpg", "b.jpg"] - 纯字符串单图片: "a.jpg" - 带 /api/v1/files/ 前缀: ["/api/v1/files/a.jpg"] 返回: 提取出的文件名列表 """ if not raw_value or (isinstance(raw_value, str) and not raw_value.strip()): return [] try: # 先尝试按 JSON 解析(处理 JSON 数组字符串) parsed = json.loads(raw_value) if isinstance(parsed, list): items = parsed else: items = [parsed] except (json.JSONDecodeError, TypeError): # JSON 解析失败,说明是纯字符串,直接按单图片处理 items = [raw_value.strip()] filenames = [] for item in items: if not item or not isinstance(item, str): continue item = item.strip() if not item: continue # 去掉可能的 /api/v1/files/ 前缀 filename = os.path.basename(item) filenames.append(filename) return filenames def build_local_path(filename: str) -> str: """ 将文件名拼装成本地绝对路径 """ return os.path.join(UPLOADS_ROOT, filename) def extract_first_valid_vector(raw_img_field: str, table_name: str, img_col: str) -> Optional[str]: """ 读取图片字段,从第一条有效图片提取向量,返回写入 DB 的 JSON 字符串。 如果所有图片均失败,返回 None。 """ filenames = parse_img_field(raw_img_field) if not filenames: return None for filename in filenames: local_path = build_local_path(filename) if not os.path.exists(local_path): print(f"\033[91m[WARN] {table_name}.{img_col} | 文件不存在: {local_path}\033[0m") continue try: vec = get_image_embedding(local_path) if vec is not None: return json.dumps(vec) except Exception as e: print(f"\033[91m[WARN] {table_name}.{img_col} | 推理异常 [{filename}]: {type(e).__name__}: {e}\033[0m") continue return None # ============================================================================ # 主入口 # ============================================================================ def main(): start = datetime.now() total_success = 0 total_skip = 0 print("=" * 60) print("📦 全量历史图片向量初始化") print("=" * 60) print(f"图片目录: {UPLOADS_ROOT}") print(f"待处理表数: {len(TARGET_TABLES)}") print() # 1. 初始化 Flask 应用上下文(加载 CLIP 模型) app = create_app() with app.app_context(): load_clip_model() print("✅ CLIP 模型加载完成") print() # 2. 遍历目标表 for config in TARGET_TABLES: table_name = config["table"] img_col = config["img_col"] vec_col = config["vec_col"] print(f"正在处理表: {table_name}, 字段: {img_col}") # 3. 查询待清洗记录(只选未处理过的) sql = text(f""" SELECT id, {img_col} FROM {table_name} WHERE {img_col} IS NOT NULL AND {img_col} != '[]' AND ({vec_col} IS NULL) """) rows = db.session.execute(sql).fetchall() if not rows: print(f"[{table_name}/{img_col}] ⏭ 无待处理记录") continue print(f"\n[{table_name}/{img_col}] 📋 待处理: {len(rows)} 条") # 4. 逐条处理 processed = 0 success_count = 0 for row in tqdm(rows, desc=f"{table_name}/{img_col}", unit="条"): record_id = row[0] raw_img = row[1] try: vec_json = extract_first_valid_vector(raw_img, table_name, img_col) if vec_json is None: total_skip += 1 continue # 更新向量字段 update_sql = text(f""" UPDATE {table_name} SET {vec_col} = :vec_str WHERE id = :id """) db.session.execute(update_sql, {"vec_str": vec_json, "id": record_id}) success_count += 1 # 每 50 条提交一次 if processed > 0 and processed % 50 == 0: db.session.commit() print(f"\n ✅ 已提交 {processed} 条") except Exception as e: print(f"\n\033[91m[WARN] {table_name}/{img_col} | ID={record_id} 处理异常: {type(e).__name__}: {e}\033[0m") # 关键:任何异常都不中断,只 continue 下一条 db.session.rollback() continue finally: processed += 1 # 循环结束后补一次 commit(处理未凑满50条的剩余数据) try: db.session.commit() except Exception: db.session.rollback() total_success += success_count print(f"[{table_name}/{img_col}] ✅ 完成,成功 {success_count} 条 / 跳过 {len(rows) - success_count} 条") # 5. 汇总报告 elapsed = (datetime.now() - start).total_seconds() print() print("=" * 60) print(f"🏁 全部完成!总计耗时 {elapsed:.1f} 秒") print(f" ✅ 成功写入向量: {total_success} 条") print(f" ⏭ 无有效图片(跳过): {total_skip} 条") print("=" * 60) if __name__ == "__main__": main()