From c273f5a9d9f9d8a567affa047e6bb5ce10291c78 Mon Sep 17 00:00:00 2001 From: DXC Date: Thu, 21 May 2026 15:43:45 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BB=A5=E5=9B=BE=E6=90=9C=E5=9B=BE?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E5=8D=87=E7=BA=A7=EF=BC=88=E8=B7=A8=E8=A1=A8?= =?UTF-8?q?UNION=E6=A3=80=E7=B4=A2=20+=20=E6=8B=8D=E7=85=A7=E8=AF=86?= =?UTF-8?q?=E5=9B=BE=E5=85=A5=E5=8F=A3=20+=20=E6=89=B9=E9=87=8F=E5=90=91?= =?UTF-8?q?=E9=87=8F=E5=88=9D=E5=A7=8B=E5=8C=96=E8=84=9A=E6=9C=AC=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../app/api/v1/common/image_search.py | 61 +++-- inventory-backend/app/utils/ai_vision.py | 38 +-- inventory-backend/scripts/init_all_vectors.py | 220 ++++++++++++++++++ inventory-web/src/views/material/list.vue | 30 ++- 4 files changed, 304 insertions(+), 45 deletions(-) create mode 100644 inventory-backend/scripts/init_all_vectors.py diff --git a/inventory-backend/app/api/v1/common/image_search.py b/inventory-backend/app/api/v1/common/image_search.py index aad1355..d1a33ae 100644 --- a/inventory-backend/app/api/v1/common/image_search.py +++ b/inventory-backend/app/api/v1/common/image_search.py @@ -71,19 +71,45 @@ def image_search(): print(f"⚠️ [ImageSearch] 临时文件删除失败: {e}") # --------------------------------------------------------- - # 5. pgvector 余弦相似度检索 + # 5. 
pgvector 余弦相似度检索(跨表联合检索) # --------------------------------------------------------- try: - # 将 Python list 转为 PostgreSQL 向量格式: '[0.1, 0.2, ...]' query_vector_str = '[' + ','.join(str(v) for v in embedding) + ']' sql = text(""" - SELECT id, name, spec_model, product_image, - (1 - (img_embedding <=> :query_vector)) AS similarity - FROM material_base - WHERE img_embedding IS NOT NULL - ORDER BY img_embedding <=> :query_vector - LIMIT 5 + SELECT id, name, spec_model, image_url, + (1 - (vec <=> :query_vector)) AS similarity + FROM ( + SELECT id, + COALESCE(name, '') AS name, + COALESCE(spec, '') AS spec_model, + COALESCE(product_image, '') AS image_url, + img_embedding AS vec + FROM material_base + WHERE img_embedding IS NOT NULL + + UNION ALL + + SELECT id, + '采购入库' AS name, + '到货照片' AS spec_model, + COALESCE(arrival_photo, '') AS image_url, + arrival_image_embedding AS vec + FROM stock_buy + WHERE arrival_image_embedding IS NOT NULL + + UNION ALL + + SELECT id, + '采购入库' AS name, + '质检报告' AS spec_model, + COALESCE(qc_report, '') AS image_url, + qc_report_image_embedding AS vec + FROM stock_buy + WHERE qc_report_image_embedding IS NOT NULL + ) AS combined + ORDER BY vec <=> :query_vector + LIMIT 10 """) result = db.session.execute(sql, {"query_vector": query_vector_str}) @@ -91,30 +117,31 @@ def image_search(): results = [] for row in rows: - product_id = row[0] - product_name = row[1] or "" + item_id = row[0] + item_name = row[1] or "" spec_model = row[2] or "" - product_image = row[3] + raw_image = row[3] # 解析图片 URL 列表,取第一张 image_url = "" - if product_image: + if raw_image: try: - image_list = json.loads(product_image) + image_list = json.loads(raw_image) if image_list and len(image_list) > 0: image_url = image_list[0] except Exception: - image_url = str(product_image) + # 纯字符串直接使用 + image_url = str(raw_image) results.append({ - "product_id": product_id, - "product_name": product_name, + "id": item_id, + "name": item_name, "spec_model": spec_model, "image_url": 
image_url, "similarity": round(float(row[4]), 4) }) - print(f"✅ [ImageSearch] 检索完成,命中 {len(results)} 条结果") + print(f"✅ [ImageSearch] 跨表检索完成,命中 {len(results)} 条结果") return jsonify({ "code": 200, "msg": "检索成功", diff --git a/inventory-backend/app/utils/ai_vision.py b/inventory-backend/app/utils/ai_vision.py index cf1854b..353f1da 100644 --- a/inventory-backend/app/utils/ai_vision.py +++ b/inventory-backend/app/utils/ai_vision.py @@ -100,7 +100,7 @@ def get_image_embedding(image_path: str) -> list: 提取图像的 512 维 CLIP embedding 向量 参数: - image_path: 图像文件路径(支持本地路径或 URL) + image_path: 图像文件路径 返回: list: 512 维浮点向量 @@ -108,25 +108,25 @@ def get_image_embedding(image_path: str) -> list: if ort_session is None: load_clip_model() - # 加载图像 - try: - image = Image.open(image_path).convert('RGB') - except Exception as e: - raise ValueError(f"图像加载失败: {image_path}, 错误: {e}") - - # 中心裁剪 + # 1. 图片预处理 + image = Image.open(image_path).convert('RGB') image = _center_crop_and_resize(image) - - # 归一化 input_data = _normalize(np.array(image)) + input_data = np.expand_dims(input_data, axis=0) # [1, 3, 224, 224] - # 添加 batch 维度: (C, H, W) -> (1, C, H, W) - input_data = np.expand_dims(input_data, axis=0) + # 2. 构造占位符输入 (关键修复) + dummy_ids = np.zeros((1, 77), dtype=np.int64) + dummy_mask = np.zeros((1, 77), dtype=np.int64) - # 推理 - outputs = ort_session.run(None, {'images': input_data.astype(np.float32)}) - - # 输出通常是 (1, 512) 的向量,取第一项并展平为 list - embedding = outputs[0][0].tolist() - - return embedding \ No newline at end of file + # 3. 
传入模型进行推理 + # 注意: 模型输入名在你的模型里必须叫 'pixel_values', 'input_ids', 'attention_mask' + # 如果报错找不到输入名,请打印 ort_session.get_inputs()[0].name 确认 + outputs = ort_session.run( + ['image_embeds'], + { + 'input_ids': dummy_ids, + 'pixel_values': input_data.astype(np.float32), + 'attention_mask': dummy_mask + } + ) + return outputs[0][0].tolist() \ No newline at end of file diff --git a/inventory-backend/scripts/init_all_vectors.py b/inventory-backend/scripts/init_all_vectors.py new file mode 100644 index 0000000..12d1460 --- /dev/null +++ b/inventory-backend/scripts/init_all_vectors.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +""" +全量历史图片向量初始化脚本 + +功能:遍历配置表中所有历史图片字段,批量提取 CLIP 512 维向量并存回数据库。 +用法:python scripts/init_all_vectors.py +""" + +import os +import json +import sys +from datetime import datetime +from typing import List, Optional + +# 将项目根目录加入 Python 路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from tqdm import tqdm +from sqlalchemy import text + +# Flask 应用环境 +from app import create_app +from app.extensions import db +from app.utils.ai_vision import get_image_embedding, load_clip_model + +# ============================================================================ +# 业务配置:表 → 图片字段 → 向量字段 映射 +# ============================================================================ +TARGET_TABLES = [ + # 基础物料 + {"table": "material_base", "img_col": "product_image", "vec_col": "img_embedding"}, + + # 采购入库 + {"table": "stock_buy", "img_col": "arrival_photo", "vec_col": "arrival_image_embedding"}, + {"table": "stock_buy", "img_col": "qc_report", "vec_col": "qc_report_image_embedding"}, +] + +# 物理图片根目录(相对于 app 目录的相对路径 ../uploads/) +APP_DIR = os.path.join(os.path.dirname(__file__), '..', 'app') +UPLOADS_ROOT = os.path.abspath(os.path.join(APP_DIR, '..', 'uploads')) + + +# ============================================================================ +# 核心工具函数 +# 
# ============================================================================
# Core helpers and entry point of the bulk image-vector initialisation script.
# ============================================================================

# Number of successful UPDATEs accumulated before the transaction is committed.
_COMMIT_BATCH_SIZE = 50


def parse_img_field(raw_value: str) -> List[str]:
    """Extract bare file names from an image column value.

    Supported inputs:
      - a JSON array string:              '["a.jpg", "b.jpg"]'
      - a plain single-image string:      'a.jpg'
      - entries carrying a path prefix:   '["/api/v1/files/a.jpg"]'
      - an already-decoded list/tuple (tolerated defensively, in case the
        DB driver decodes the column before handing it over)

    Args:
        raw_value: The raw column value. ``None``, empty/blank strings and
            unsupported types yield an empty list instead of raising.

    Returns:
        File names (directory components removed via ``os.path.basename``),
        in their original order. Non-string entries and entries that reduce
        to an empty name (e.g. a path ending in ``/``) are dropped.
    """
    if raw_value is None:
        return []

    if isinstance(raw_value, (list, tuple)):
        items = list(raw_value)
    elif isinstance(raw_value, (str, bytes)):
        if isinstance(raw_value, bytes):
            raw_value = raw_value.decode("utf-8", errors="replace")
        stripped = raw_value.strip()
        if not stripped:
            return []
        try:
            parsed = json.loads(stripped)
        except ValueError:
            # Not JSON (json.JSONDecodeError is a ValueError): treat the
            # whole value as a single image reference.
            items = [stripped]
        else:
            items = parsed if isinstance(parsed, list) else [parsed]
    else:
        return []

    filenames = []
    for item in items:
        if not isinstance(item, str):
            continue
        # basename() drops any prefix such as /api/v1/files/.
        filename = os.path.basename(item.strip())
        if filename:
            filenames.append(filename)
    return filenames


def build_local_path(filename: str) -> str:
    """Return the path of *filename* inside the ``UPLOADS_ROOT`` directory."""
    return os.path.join(UPLOADS_ROOT, filename)


def extract_first_valid_vector(raw_img_field: str, table_name: str, img_col: str) -> Optional[str]:
    """Embed the first usable image referenced by an image column value.

    Each referenced file is tried in order; files that are missing or whose
    inference raises are reported with a warning and skipped.

    Args:
        raw_img_field: Raw image column value (see ``parse_img_field``).
        table_name: Table name, used only in warning messages.
        img_col: Image column name, used only in warning messages.

    Returns:
        The embedding serialised with ``json.dumps`` for the first image
        that yields a vector, or ``None`` when no image does.
    """
    for filename in parse_img_field(raw_img_field):
        local_path = build_local_path(filename)

        # isfile() rather than exists(): a directory is not a usable image.
        if not os.path.isfile(local_path):
            print(f"\033[91m[WARN] {table_name}.{img_col} | 文件不存在: {local_path}\033[0m")
            continue

        try:
            vec = get_image_embedding(local_path)
        except Exception as e:
            # Best effort: one unreadable image must not abort the batch run.
            print(f"\033[91m[WARN] {table_name}.{img_col} | 推理异常 [{filename}]: {type(e).__name__}: {e}\033[0m")
            continue

        if vec is not None:
            return json.dumps(vec)

    return None


def _commit_pending(pending: int, label: str) -> int:
    """Commit the session; return how many of *pending* updates were persisted.

    On failure the session is rolled back, a warning is printed and 0 is
    returned so the caller can keep its counters truthful.
    """
    try:
        db.session.commit()
    except Exception as e:
        db.session.rollback()
        print(f"\n\033[91m[WARN] {label} | 提交失败,{pending} 条更新已回滚: {type(e).__name__}: {e}\033[0m")
        return 0
    return pending


# ============================================================================
# Entry point
# ============================================================================

def main():
    """Back-fill embedding columns for every entry in ``TARGET_TABLES``.

    For each configured (table, image column, vector column) triple, rows
    whose image column is set and whose vector column is NULL are selected,
    the first usable image is embedded and the vector is written back.
    Updates are committed in batches of ``_COMMIT_BATCH_SIZE``. Each UPDATE
    runs inside a SAVEPOINT, so a failing row is undone on its own and does
    not discard the not-yet-committed updates of earlier rows.
    """
    start = datetime.now()
    total_success = 0
    total_skip = 0
    total_fail = 0

    print("=" * 60)
    print("📦 全量历史图片向量初始化")
    print("=" * 60)
    print(f"图片目录: {UPLOADS_ROOT}")
    print(f"待处理表数: {len(TARGET_TABLES)}")
    print()

    # 1. Create the Flask app and load the CLIP model inside its context.
    app = create_app()
    with app.app_context():
        load_clip_model()
        print("✅ CLIP 模型加载完成")
        print()

        # 2. Walk through the configured tables.
        for config in TARGET_TABLES:
            table_name = config["table"]
            img_col = config["img_col"]
            vec_col = config["vec_col"]
            label = f"{table_name}/{img_col}"

            print(f"正在处理表: {table_name}, 字段: {img_col}")

            # 3. Select only rows that have not been processed yet.
            # Identifiers are interpolated from the hard-coded TARGET_TABLES
            # config above, never from external input.
            select_sql = text(f"""
                SELECT id, {img_col}
                FROM {table_name}
                WHERE {img_col} IS NOT NULL
                  AND {img_col} != '[]'
                  AND ({vec_col} IS NULL)
            """)
            rows = db.session.execute(select_sql).fetchall()

            if not rows:
                print(f"[{label}] ⏭ 无待处理记录")
                continue

            print(f"\n[{label}] 📋 待处理: {len(rows)} 条")

            # Loop-invariant statement, built once per table.
            update_sql = text(f"""
                UPDATE {table_name} SET {vec_col} = :vec_str WHERE id = :id
            """)

            # 4. Process row by row.
            success_count = 0  # updates that have actually been committed
            pending = 0        # updates executed but not yet committed

            for row in tqdm(rows, desc=label, unit="条"):
                record_id = row[0]
                raw_img = row[1]

                try:
                    vec_json = extract_first_valid_vector(raw_img, table_name, img_col)
                    if vec_json is None:
                        total_skip += 1
                        continue

                    # SAVEPOINT: a failing UPDATE is rolled back alone.
                    with db.session.begin_nested():
                        db.session.execute(update_sql, {"vec_str": vec_json, "id": record_id})
                except Exception as e:
                    # Never abort the run because of a single bad record.
                    print(f"\n\033[91m[WARN] {table_name}/{img_col} | ID={record_id} 处理异常: {type(e).__name__}: {e}\033[0m")
                    total_fail += 1
                    continue

                pending += 1
                if pending >= _COMMIT_BATCH_SIZE:
                    committed = _commit_pending(pending, label)
                    success_count += committed
                    total_fail += pending - committed
                    pending = 0
                    if committed:
                        print(f"\n ✅ 已提交 {success_count} 条")

            # Flush the last, partially filled batch.
            committed = _commit_pending(pending, label)
            success_count += committed
            total_fail += pending - committed

            total_success += success_count
            print(f"[{label}] ✅ 完成,成功 {success_count} 条 / 跳过 {len(rows) - success_count} 条")

    # 5. Summary report.
    elapsed = (datetime.now() - start).total_seconds()
    print()
    print("=" * 60)
    print(f"🏁 全部完成!总计耗时 {elapsed:.1f} 秒")
    print(f" ✅ 成功写入向量: {total_success} 条")
    print(f" ⏭ 无有效图片(跳过): {total_skip} 条")
    print(f" ❌ 写入失败: {total_fail} 条")
    print("=" * 60)


if __name__ == "__main__":
    main()