Merge remote-tracking branch 'origin/3.0AI添加' into 3.0AI添加

This commit is contained in:
dxc
2026-05-21 18:29:48 +08:00
13 changed files with 1072 additions and 34 deletions

View File

@ -90,6 +90,17 @@ def create_app():
except ImportError as e:
print(f"❌ 错误: Upload 模块导入失败: {e}")
# -----------------------------------------------------
# 2.4 注册以图搜图模块 (Image Search)
# -----------------------------------------------------
try:
from app.api.v1.common.image_search import image_search_bp
app.register_blueprint(image_search_bp, url_prefix='/api/v1/common')
app.register_blueprint(image_search_bp, url_prefix='/api/common', name='image_search_legacy')
print("✅ Image Search 模块注册成功")
except ImportError as e:
print(f"❌ 错误: Image Search 模块导入失败: {e}")
# -----------------------------------------------------
# 2.4 注册业务操作模块 (Transactions - 借还/维修/报废)
# -----------------------------------------------------

View File

@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
"""
以图搜图 API - CLIP Vision Embedding + pgvector 余弦距离检索
"""
import os
import uuid
import json
from flask import Blueprint, request, jsonify
from sqlalchemy import text
from app.extensions import db
from app.utils.ai_vision import load_clip_model, get_image_embedding
# 注册蓝图
image_search_bp = Blueprint('image_search', __name__)
# ============================================================================
# POST /api/v1/common/image-search
# 以图搜图:上传图片 → CLIP embedding → pgvector 余弦相似度检索
# ============================================================================
@image_search_bp.route('/image-search', methods=['POST'])
def image_search():
# ---------------------------------------------------------
# 1. 检查文件
# ---------------------------------------------------------
if 'file' not in request.files:
return jsonify({"code": 400, "msg": "未找到图片文件"}), 400
file = request.files['file']
if file.filename == '':
return jsonify({"code": 400, "msg": "未选择文件"}), 400
# ---------------------------------------------------------
# 2. 安全保存临时文件
# ---------------------------------------------------------
ext = file.filename.rsplit('.', 1)[-1].lower()
if ext not in {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}:
return jsonify({"code": 400, "msg": "不支持的图片格式"}), 400
tmp_filename = f"{uuid.uuid4().hex}.{ext}"
tmp_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'uploads')
os.makedirs(tmp_dir, exist_ok=True)
tmp_path = os.path.join(tmp_dir, tmp_filename)
try:
file.save(tmp_path)
print(f"💾 [ImageSearch] 临时文件已保存: {tmp_path}")
# ---------------------------------------------------------
# 3. 提取 CLIP embedding
# ---------------------------------------------------------
load_clip_model()
embedding = get_image_embedding(tmp_path)
print(f"✅ [ImageSearch] Embedding 提取成功,维度: {len(embedding)}")
except Exception as e:
print(f"❌ [ImageSearch] 图像处理失败: {e}")
return jsonify({"code": 500, "msg": f"图像处理失败: {str(e)}"}), 500
finally:
# ---------------------------------------------------------
# 4. 无论成功与否,都删除临时文件
# ---------------------------------------------------------
if os.path.exists(tmp_path):
try:
os.remove(tmp_path)
print(f"🗑️ [ImageSearch] 临时文件已清理: {tmp_path}")
except Exception as e:
print(f"⚠️ [ImageSearch] 临时文件删除失败: {e}")
# ---------------------------------------------------------
# 5. pgvector 余弦相似度检索(跨表联合检索)
# ---------------------------------------------------------
try:
query_vector_str = '[' + ','.join(str(v) for v in embedding) + ']'
sql = text("""
SELECT id, name, spec_model, image_url,
(1 - (vec <=> :query_vector)) AS similarity
FROM (
SELECT id,
COALESCE(name, '') AS name,
COALESCE(spec, '') AS spec_model,
COALESCE(product_image, '') AS image_url,
img_embedding AS vec
FROM material_base
WHERE img_embedding IS NOT NULL
UNION ALL
SELECT id,
'采购入库' AS name,
'到货照片' AS spec_model,
COALESCE(arrival_photo, '') AS image_url,
arrival_image_embedding AS vec
FROM stock_buy
WHERE arrival_image_embedding IS NOT NULL
UNION ALL
SELECT id,
'采购入库' AS name,
'质检报告' AS spec_model,
COALESCE(qc_report, '') AS image_url,
qc_report_image_embedding AS vec
FROM stock_buy
WHERE qc_report_image_embedding IS NOT NULL
) AS combined
ORDER BY vec <=> :query_vector
LIMIT 10
""")
result = db.session.execute(sql, {"query_vector": query_vector_str})
rows = result.fetchall()
results = []
for row in rows:
item_id = row[0]
item_name = row[1] or ""
spec_model = row[2] or ""
raw_image = row[3]
# 解析图片 URL 列表,取第一张
image_url = ""
if raw_image:
try:
image_list = json.loads(raw_image)
if image_list and len(image_list) > 0:
image_url = image_list[0]
except Exception:
# 纯字符串直接使用
image_url = str(raw_image)
results.append({
"id": item_id,
"name": item_name,
"spec_model": spec_model,
"image_url": image_url,
"similarity": round(float(row[4]), 4)
})
print(f"✅ [ImageSearch] 跨表检索完成,命中 {len(results)} 条结果")
return jsonify({
"code": 200,
"msg": "检索成功",
"data": results
})
except Exception as e:
print(f"❌ [ImageSearch] 数据库检索失败: {e}")
return jsonify({"code": 500, "msg": f"检索失败: {str(e)}"}), 500

View File

@ -11,6 +11,8 @@ Dify 智能客服权限服务层
- 跨模块越权查询:直接阻断,返回角色专属的错误信息给大模型
"""
from typing import Optional
from flask import g, current_app
from flask_jwt_extended import decode_token
from app.models.system import SysRolePermission
@ -185,7 +187,7 @@ class DifyPermissionService:
返回:
{
'blocked': bool, # 是否被拦截
'message': str | None, # AI 应返回给用户的错误信息(如果有)
'message': Optional[str], # AI 应返回给用户的错误信息(如果有)
}
"""
if DifyPermissionService.is_super_admin(role):

View File

@ -20,6 +20,8 @@ import logging
from threading import Thread
from datetime import datetime
from typing import Optional
from openpyxl import Workbook
from openpyxl.styles import Font, PatternFill, Alignment, Border, Side
@ -346,7 +348,7 @@ def get_task_status(task_id: str) -> dict:
# 获取导出文件路径(供下载接口调用)
# =============================================================================
def get_export_filepath(task_id: str) -> str | None:
def get_export_filepath(task_id: str) -> Optional[str]:
"""
根据 task_id 返回已生成文件的完整路径。
未完成或不存在返回 None。

View File

@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
"""
AI Vision 模块 - CLIP Vision Encoder ONNX 推理
"""
import os
import numpy as np
from PIL import Image
import onnxruntime as ort
# ============================================================================
# 全局模型单例(项目启动时加载一次)
# ============================================================================
MODEL_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'models', 'clip_vision.onnx')
# 加载选项CPU 推理,禁用依赖库的启动开销
_session_options = ort.SessionOptions()
_session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
ort_session: ort.InferenceSession = None
def load_clip_model():
"""启动时调用:全局加载 CLIP Vision 模型"""
global ort_session
if ort_session is not None:
return ort_session
if not os.path.exists(MODEL_PATH):
raise FileNotFoundError(f"CLIP Vision 模型未找到: {MODEL_PATH}")
ort_session = ort.InferenceSession(MODEL_PATH, sess_options=_session_options, providers=['CPUExecutionProvider'])
print(f"✅ [AI Vision] CLIP 模型加载成功: {MODEL_PATH}")
return ort_session
# ============================================================================
# CLIP 预处理常量
# ============================================================================
# ImageNet 标准归一化CLIP 官方)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# 模型输入尺寸
INPUT_SIZE = 224
def _center_crop_and_resize(image: Image.Image) -> Image.Image:
"""
CLIP 官方预处理:中心裁剪抗干扰
- 将图片最短边缩放到 224
- 从正中间切取 224x224 区域
"""
w, h = image.size
# 计算缩放后的目标尺寸
if w < h:
new_w = INPUT_SIZE
new_h = int(h * INPUT_SIZE / w)
else:
new_h = INPUT_SIZE
new_w = int(w * INPUT_SIZE / h)
# 缩放
image = image.resize((new_w, new_h), Image.BILINEAR)
# 中心裁剪
left = (new_w - INPUT_SIZE) // 2
top = (new_h - INPUT_SIZE) // 2
right = left + INPUT_SIZE
bottom = top + INPUT_SIZE
return image.crop((left, top, right, bottom))
def _normalize(image_np: np.ndarray) -> np.ndarray:
"""
对 224x224x3 图像进行 CLIP 标准归一化
image_np: shape (H, W, C), dtype uint8, 值域 [0, 255]
返回: shape (C, H, W), dtype float32, 值域 [0, 1]
"""
# HWC -> CHW
image_np = image_np.transpose(2, 0, 1).astype(np.float32) / 255.0
# 归一化
for i, (mean, std) in enumerate(zip(IMAGENET_MEAN, IMAGENET_STD)):
image_np[i] = (image_np[i] - mean) / std
return image_np
# ============================================================================
# 主函数:提取图像 embedding
# ============================================================================
def get_image_embedding(image_path: str) -> list:
"""
提取图像的 512 维 CLIP embedding 向量
参数:
image_path: 图像文件路径
返回:
list: 512 维浮点向量
"""
if ort_session is None:
load_clip_model()
# 1. 图片预处理
image = Image.open(image_path).convert('RGB')
image = _center_crop_and_resize(image)
input_data = _normalize(np.array(image))
input_data = np.expand_dims(input_data, axis=0) # [1, 3, 224, 224]
# 2. 构造占位符输入 (关键修复)
dummy_ids = np.zeros((1, 77), dtype=np.int64)
dummy_mask = np.zeros((1, 77), dtype=np.int64)
# 3. 传入模型进行推理
# 注意: 模型输入名在你的模型里必须叫 'pixel_values', 'input_ids', 'attention_mask'
# 如果报错找不到输入名,请打印 ort_session.get_inputs()[0].name 确认
outputs = ort_session.run(
['image_embeds'],
{
'input_ids': dummy_ids,
'pixel_values': input_data.astype(np.float32),
'attention_mask': dummy_mask
}
)
return outputs[0][0].tolist()