# -*- coding: utf-8 -*-
"""AI Vision module: CLIP vision-encoder inference through ONNX Runtime."""
import json
import os
import time
from typing import Optional

import numpy as np
import onnxruntime as ort
from PIL import Image

# ============================================================================
# Global model singleton (loaded once, reused by every request)
# ============================================================================
MODEL_PATH = os.path.join(
    os.path.dirname(__file__), '..', '..', 'models', 'clip_vision.onnx'
)

# Session options: enable every graph optimisation ONNX Runtime offers.
_session_options = ort.SessionOptions()
_session_options.graph_optimization_level = (
    ort.GraphOptimizationLevel.ORT_ENABLE_ALL
)

# Populated by load_clip_model(); None until the first successful load.
ort_session: Optional[ort.InferenceSession] = None


def load_clip_model():
    """Load the CLIP vision model into the module-level session and return it.

    The session is created at most once; later calls return the cached
    instance. Inference is pinned to the CPU execution provider.

    Returns:
        The shared ``ort.InferenceSession``.

    Raises:
        FileNotFoundError: if ``MODEL_PATH`` does not exist.
    """
    global ort_session
    if ort_session is not None:
        return ort_session

    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"CLIP Vision 模型未找到: {MODEL_PATH}")

    ort_session = ort.InferenceSession(
        MODEL_PATH,
        sess_options=_session_options,
        providers=['CPUExecutionProvider'],
    )
    print(f"✅ [AI Vision] CLIP 模型加载成功: {MODEL_PATH}")
    return ort_session


# ============================================================================
# Preprocessing constants
# ============================================================================
# NOTE(review): these are the ImageNet statistics. OpenAI's published CLIP
# preprocessing uses mean (0.48145466, 0.4578275, 0.40821073) and
# std (0.26862954, 0.26130258, 0.27577711) with bicubic resampling. The values
# are left untouched here because changing them would shift every embedding
# relative to vectors that may already be stored -- confirm against the
# exported model before switching.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Side length, in pixels, of the square model input.
INPUT_SIZE = 224


def _center_crop_and_resize(image: Image.Image) -> Image.Image:
    """Scale the shorter side to ``INPUT_SIZE`` and crop the central square.

    The aspect ratio is preserved while resizing (bilinear), then an
    ``INPUT_SIZE`` x ``INPUT_SIZE`` window is cut from the middle.

    Args:
        image: source PIL image.

    Returns:
        A new ``INPUT_SIZE`` x ``INPUT_SIZE`` PIL image.
    """
    w, h = image.size

    # The shorter side becomes exactly INPUT_SIZE; the longer side is scaled
    # by the same factor (truncated to an integer).
    if w < h:
        new_w = INPUT_SIZE
        new_h = int(h * INPUT_SIZE / w)
    else:
        new_h = INPUT_SIZE
        new_w = int(w * INPUT_SIZE / h)

    image = image.resize((new_w, new_h), Image.BILINEAR)

    left = (new_w - INPUT_SIZE) // 2
    top = (new_h - INPUT_SIZE) // 2
    return image.crop((left, top, left + INPUT_SIZE, top + INPUT_SIZE))


def _normalize(image_np: np.ndarray) -> np.ndarray:
    """Convert an HWC uint8 image to a mean/std-normalised CHW float32 array.

    Pixels are first scaled to [0, 1], then each channel has its
    ``IMAGENET_MEAN`` entry subtracted and is divided by its ``IMAGENET_STD``
    entry, so the result is NOT confined to [0, 1].

    Args:
        image_np: array of shape (H, W, 3) with values in [0, 255].

    Returns:
        C-contiguous float32 array of shape (3, H, W).
    """
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32)
    std = np.asarray(IMAGENET_STD, dtype=np.float32)

    scaled = image_np.astype(np.float32) / 255.0
    # mean/std have shape (3,), so they broadcast over the trailing channel
    # axis while the data is still in HWC layout.
    normalized = (scaled - mean) / std

    # HWC -> CHW
    return np.ascontiguousarray(normalized.transpose(2, 0, 1))


# ============================================================================
# Main entry point: image -> embedding
# ============================================================================
def get_image_embedding(image_path: str) -> list:
    """Return the CLIP image embedding for the file at ``image_path``.

    The model is loaded on first use. According to the original notes the
    embedding is 512-dimensional; the actual length is whatever the model's
    ``image_embeds`` output provides.

    Args:
        image_path: path of the image file to embed.

    Returns:
        The embedding as a flat list of floats.
    """
    session = load_clip_model()

    # 1. Preprocess. The context manager closes the underlying file handle
    #    as soon as the pixels have been converted to an in-memory RGB copy.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    image = _center_crop_and_resize(image)
    pixel_values = np.expand_dims(_normalize(np.array(image)), axis=0)
    # pixel_values: [1, 3, 224, 224], float32

    # 2. Placeholder text inputs: the exported graph also declares
    #    'input_ids' and 'attention_mask', so zero tensors are fed for them.
    dummy_ids = np.zeros((1, 77), dtype=np.int64)
    dummy_mask = np.zeros((1, 77), dtype=np.int64)

    # 3. Inference. The input names must match the exported model; if
    #    ONNX Runtime reports an unknown input, inspect session.get_inputs().
    outputs = session.run(
        ['image_embeds'],
        {
            'input_ids': dummy_ids,
            'pixel_values': pixel_values,
            'attention_mask': dummy_mask,
        },
    )
    return outputs[0][0].tolist()


# ============================================================================
# Fault-tolerant helper: messy photo reference -> embedding
# ============================================================================
# Directory (inside the container) where uploaded photos are stored.
_UPLOAD_DIR = '/app/uploads'
# 6 polls x 0.5 s: wait roughly 3 seconds for an upload to appear on disk.
_MAX_RETRIES = 6
_RETRY_DELAY_SECONDS = 0.5
# Characters left over from half-parsed JSON, e.g. '123.jpg"]'.
_STRAY_PUNCTUATION = str.maketrans('', '', '"\'[]')


def _extract_filename(photo_source) -> Optional[str]:
    """Reduce a photo reference to a bare, sanitised file name.

    Accepts a plain path, a JSON string, or a JSON list (first element is
    used). Returns None when no usable file name can be derived.
    """
    text = str(photo_source).strip()

    # Peel off a JSON wrapper when there is one; otherwise use the raw text.
    try:
        parsed = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        candidate = text
    else:
        if isinstance(parsed, list):
            candidate = parsed[0] if parsed else ""
        else:
            candidate = parsed if isinstance(parsed, str) else str(parsed)

    if not candidate:
        return None

    # Keep only the last path segment, then strip stray quotes/brackets.
    # str() guards against a non-string first list element.
    filename = str(candidate).split('/')[-1].translate(_STRAY_PUNCTUATION)

    # An empty name, '.' or '..' would resolve to a directory, not an image.
    if filename in ('', '.', '..'):
        return None
    return filename


def extract_and_embed(photo_source):
    """Best-effort embedding of an uploaded photo; never raises.

    ``photo_source`` may be a file name, a path, or a JSON-encoded string or
    list of them. The file is looked up under the upload directory and polled
    for a short while in case the upload has not reached the disk yet.

    Args:
        photo_source: loosely formatted reference to the uploaded photo.

    Returns:
        The embedding as a list of floats, or None when the reference is
        empty or unusable, the file never appears, or embedding fails.
    """
    if not photo_source:
        return None

    try:
        filename = _extract_filename(photo_source)
        if filename is None:
            return None

        file_path = os.path.join(_UPLOAD_DIR, filename)

        for attempt in range(1, _MAX_RETRIES + 1):
            if os.path.exists(file_path):
                return get_image_embedding(file_path)
            # The "0.5s" in this message mirrors _RETRY_DELAY_SECONDS.
            print(f"[AI 识图等待] 第 {attempt} 次尝试,未找到文件 {file_path},等待 0.5s...")
            time.sleep(_RETRY_DELAY_SECONDS)

        print(f"[AI 识图警告] 彻底失败!经过等待依然未找到图片: {file_path}")
    except Exception as e:
        # Top-level boundary: callers expect None rather than an exception.
        print(f"[AI 识图错误] 实时提取向量失败: {str(e)}")

    return None