Files
KCGL/inventory-backend/app/utils/ai_vision.py

132 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
AI Vision 模块 - CLIP Vision Encoder ONNX 推理
"""
import os
import numpy as np
from PIL import Image
import onnxruntime as ort
# ============================================================================
# 全局模型单例(项目启动时加载一次)
# ============================================================================
MODEL_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'models', 'clip_vision.onnx')
# 加载选项CPU 推理,禁用依赖库的启动开销
_session_options = ort.SessionOptions()
_session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
ort_session: ort.InferenceSession = None
def load_clip_model():
"""启动时调用:全局加载 CLIP Vision 模型"""
global ort_session
if ort_session is not None:
return ort_session
if not os.path.exists(MODEL_PATH):
raise FileNotFoundError(f"CLIP Vision 模型未找到: {MODEL_PATH}")
ort_session = ort.InferenceSession(MODEL_PATH, sess_options=_session_options, providers=['CPUExecutionProvider'])
print(f"✅ [AI Vision] CLIP 模型加载成功: {MODEL_PATH}")
return ort_session
# ============================================================================
# CLIP 预处理常量
# ============================================================================
# ImageNet 标准归一化CLIP 官方)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# 模型输入尺寸
INPUT_SIZE = 224
def _center_crop_and_resize(image: Image.Image) -> Image.Image:
"""
CLIP 官方预处理:中心裁剪抗干扰
- 将图片最短边缩放到 224
- 从正中间切取 224x224 区域
"""
w, h = image.size
# 计算缩放后的目标尺寸
if w < h:
new_w = INPUT_SIZE
new_h = int(h * INPUT_SIZE / w)
else:
new_h = INPUT_SIZE
new_w = int(w * INPUT_SIZE / h)
# 缩放
image = image.resize((new_w, new_h), Image.BILINEAR)
# 中心裁剪
left = (new_w - INPUT_SIZE) // 2
top = (new_h - INPUT_SIZE) // 2
right = left + INPUT_SIZE
bottom = top + INPUT_SIZE
return image.crop((left, top, right, bottom))
def _normalize(image_np: np.ndarray) -> np.ndarray:
"""
对 224x224x3 图像进行 CLIP 标准归一化
image_np: shape (H, W, C), dtype uint8, 值域 [0, 255]
返回: shape (C, H, W), dtype float32, 值域 [0, 1]
"""
# HWC -> CHW
image_np = image_np.transpose(2, 0, 1).astype(np.float32) / 255.0
# 归一化
for i, (mean, std) in enumerate(zip(IMAGENET_MEAN, IMAGENET_STD)):
image_np[i] = (image_np[i] - mean) / std
return image_np
# ============================================================================
# 主函数:提取图像 embedding
# ============================================================================
def get_image_embedding(image_path: str) -> list:
"""
提取图像的 512 维 CLIP embedding 向量
参数:
image_path: 图像文件路径
返回:
list: 512 维浮点向量
"""
if ort_session is None:
load_clip_model()
# 1. 图片预处理
image = Image.open(image_path).convert('RGB')
image = _center_crop_and_resize(image)
input_data = _normalize(np.array(image))
input_data = np.expand_dims(input_data, axis=0) # [1, 3, 224, 224]
# 2. 构造占位符输入 (关键修复)
dummy_ids = np.zeros((1, 77), dtype=np.int64)
dummy_mask = np.zeros((1, 77), dtype=np.int64)
# 3. 传入模型进行推理
# 注意: 模型输入名在你的模型里必须叫 'pixel_values', 'input_ids', 'attention_mask'
# 如果报错找不到输入名,请打印 ort_session.get_inputs()[0].name 确认
outputs = ort_session.run(
['image_embeds'],
{
'input_ids': dummy_ids,
'pixel_values': input_data.astype(np.float32),
'attention_mask': dummy_mask
}
)
return outputs[0][0].tolist()