132 lines
4.0 KiB
Python
132 lines
4.0 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
AI Vision 模块 - CLIP Vision Encoder ONNX 推理
|
||
"""
|
||
|
||
import os
|
||
import numpy as np
|
||
from PIL import Image
|
||
import onnxruntime as ort
|
||
|
||
# ============================================================================
# Global model singleton (loaded once per process: at startup or on first use)
# ============================================================================

MODEL_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'models', 'clip_vision.onnx')

# Session options: enable every ONNX Runtime graph optimization.
# Note: this setting does not pick the device. CPU-only inference comes from
# the providers list given when the InferenceSession is created.
_session_options = ort.SessionOptions()
_session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# The shared inference session. It stays None until load_clip_model() creates
# it, hence the Optional annotation. The string form keeps the annotation
# valid on interpreters without runtime support for `X | None`.
ort_session: "ort.InferenceSession | None" = None


def load_clip_model():
    """Create the module-wide CLIP Vision ONNX session, or return the existing one.

    Meant to be called at startup. Repeated calls are cheap: once the session
    exists, it is returned as-is without touching the filesystem.

    Returns:
        The shared ``ort.InferenceSession`` (CPU execution provider).

    Raises:
        FileNotFoundError: if there is no file at ``MODEL_PATH``.
    """
    global ort_session

    if ort_session is None:
        # Check explicitly so the error names the expected model location.
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"CLIP Vision 模型未找到: {MODEL_PATH}")

        ort_session = ort.InferenceSession(
            MODEL_PATH,
            sess_options=_session_options,
            providers=['CPUExecutionProvider'],
        )
        print(f"✅ [AI Vision] CLIP 模型加载成功: {MODEL_PATH}")

    return ort_session


# ============================================================================
# CLIP preprocessing constants
# ============================================================================

# Per-channel (R, G, B) normalization statistics of OpenAI CLIP. These are
# also the defaults of Hugging Face's CLIPImageProcessor. They differ from the
# ImageNet statistics (mean 0.485/0.456/0.406, std 0.229/0.224/0.225).
# Feeding ImageNet-normalized pixels to a CLIP encoder shifts every embedding.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

# Legacy names, kept so existing importers and _normalize() keep working.
# Despite the "IMAGENET" prefix, they must hold CLIP's statistics.
IMAGENET_MEAN = CLIP_MEAN
IMAGENET_STD = CLIP_STD

# Model input resolution: side length, in pixels, of the square input image.
INPUT_SIZE = 224


def _center_crop_and_resize(image: Image.Image) -> Image.Image:
    """Apply CLIP's resize-then-center-crop preprocessing.

    1. Resize so the shorter side equals ``INPUT_SIZE``, preserving aspect
       ratio. The longer side is truncated to an int.
    2. Cut out the central ``INPUT_SIZE`` x ``INPUT_SIZE`` square.

    Bicubic resampling is used because CLIP's reference preprocessing uses
    it. Bilinear gives slightly different pixels and therefore slightly
    different embeddings.

    Args:
        image: Input PIL image of any size and aspect ratio.

    Returns:
        A new ``INPUT_SIZE`` x ``INPUT_SIZE`` PIL image.
    """
    w, h = image.size

    # Shorter side -> INPUT_SIZE; the longer side scales proportionally and
    # is therefore always >= INPUT_SIZE, so the crop below stays in bounds.
    if w < h:
        new_w = INPUT_SIZE
        new_h = int(h * INPUT_SIZE / w)
    else:
        new_h = INPUT_SIZE
        new_w = int(w * INPUT_SIZE / h)

    image = image.resize((new_w, new_h), Image.BICUBIC)

    # Top-left corner of the centered square. With floor division, an odd
    # leftover trims one extra pixel from the right/bottom edge.
    left = (new_w - INPUT_SIZE) // 2
    top = (new_h - INPUT_SIZE) // 2
    return image.crop((left, top, left + INPUT_SIZE, top + INPUT_SIZE))


def _normalize(image_np: np.ndarray) -> np.ndarray:
|
||
"""
|
||
对 224x224x3 图像进行 CLIP 标准归一化
|
||
image_np: shape (H, W, C), dtype uint8, 值域 [0, 255]
|
||
返回: shape (C, H, W), dtype float32, 值域 [0, 1]
|
||
"""
|
||
# HWC -> CHW
|
||
image_np = image_np.transpose(2, 0, 1).astype(np.float32) / 255.0
|
||
|
||
# 归一化
|
||
for i, (mean, std) in enumerate(zip(IMAGENET_MEAN, IMAGENET_STD)):
|
||
image_np[i] = (image_np[i] - mean) / std
|
||
|
||
return image_np
|
||
|
||
|
||
# ============================================================================
# Entry point: extract an image embedding
# ============================================================================


def get_image_embedding(image_path: str) -> list:
    """Return the CLIP image embedding for the image file at *image_path*.

    The ONNX model is loaded lazily on first use via ``load_clip_model()``.

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        The model's ``image_embeds`` output for the single input image, as a
        list of Python floats (512 values for the model this module targets).

    Raises:
        FileNotFoundError: if the ONNX model file is missing.
        Any error Pillow raises for an unreadable or non-image file, and any
        ONNX Runtime error if the model's inputs/outputs don't match.
    """
    # load_clip_model() returns the cached session once it exists.
    session = load_clip_model()

    # 1. Preprocess. The `with` closes the file handle Image.open() keeps;
    #    convert() returns an independent, fully loaded copy.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = _center_crop_and_resize(image)
    pixel_values = np.expand_dims(_normalize(np.array(image)), axis=0)  # [1, 3, 224, 224]

    # 2. Candidate feeds. A full CLIP export also declares text inputs
    #    (77-token context). Placeholder zeros are passed for them, on the
    #    assumption that the image branch does not depend on them.
    candidate_inputs = {
        'input_ids': np.zeros((1, 77), dtype=np.int64),
        'pixel_values': pixel_values.astype(np.float32, copy=False),
        'attention_mask': np.zeros((1, 77), dtype=np.int64),
    }

    # Feed only the inputs the model actually declares. ONNX Runtime rejects
    # unknown input names, so this also supports a vision-only export.
    declared = {model_input.name for model_input in session.get_inputs()}
    feed = {name: value for name, value in candidate_inputs.items() if name in declared}

    # 3. Inference: fetch only the image embedding output.
    outputs = session.run(['image_embeds'], feed)
    return outputs[0][0].tolist()