280 lines
9.2 KiB
Python
280 lines
9.2 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
AI Vision 模块 - CLIP Vision Encoder ONNX 推理
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import time
|
||
import numpy as np
|
||
from PIL import Image
|
||
import onnxruntime as ort
|
||
import cv2
|
||
|
||
# ============================================================================
# Global model singleton (created once and reused by every call)
# ============================================================================

# Exported CLIP ONNX model: <dir of this file>/../../models/clip_vision.onnx.
# Resolved to an absolute path once at import time so a later change of the
# working directory cannot break loading, and log/error messages show a clean path.
MODEL_PATH = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '..', '..', 'models', 'clip_vision.onnx')
)

# Session options shared by the InferenceSession created in load_clip_model().
# Only the graph optimisation level is configured here (all optimisations on);
# the CPU execution provider is selected where the session is created.
_session_options = ort.SessionOptions()
_session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# Cached inference session; stays None until load_clip_model() succeeds.
ort_session: "ort.InferenceSession | None" = None
|
||
|
||
|
||
def load_clip_model():
    """
    Return the module-wide CLIP ONNX session, creating it on first use.

    The session is stored in the global ``ort_session``; later calls return the
    same object without touching the file system.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist when the session is first created.
    """
    global ort_session
    if ort_session is None:
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"CLIP Vision 模型未找到: {MODEL_PATH}")
        cpu_only = ['CPUExecutionProvider']
        ort_session = ort.InferenceSession(MODEL_PATH, sess_options=_session_options, providers=cpu_only)
        print(f"✅ [AI Vision] CLIP 模型加载成功: {MODEL_PATH}")
    return ort_session
|
||
|
||
|
||
# ============================================================================
# Preprocessing constants
# ============================================================================

# Per-channel RGB normalisation statistics, applied to pixels scaled to [0, 1].
# NOTE(review): these are the ImageNet statistics, NOT the ones used by OpenAI
# CLIP (mean = (0.48145466, 0.4578275, 0.40821073),
#       std  = (0.26862954, 0.26130258, 0.27577711)).
# They are kept unchanged deliberately: switching changes every embedding, so
# any vectors already stored would have to be recomputed at the same time.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Side length, in pixels, of the square image fed to the model.
INPUT_SIZE = 224

# ============================================================================
# Background removal: HSV thresholds
# ============================================================================
# OpenCV 8-bit HSV ranges: H in [0, 180], S in [0, 255], V in [0, 255].
# OpenCV stores hue halved (degrees / 2), so H = 60 here is 120 degrees on the
# usual 0-360 colour wheel.

# Green-screen background:
#   H 35..90  -> about 70..180 degrees (yellow-green through green towards cyan)
#   S 35..255 -> excludes washed-out, nearly grey pixels
#   V 30..255 -> excludes near-black pixels
BG_GREEN_LOWER = np.array([35, 35, 30])
BG_GREEN_UPPER = np.array([90, 255, 255])

# White / very light background:
#   H 0..180  -> any hue
#   S 0..40   -> almost colourless
#   V 180..255 -> bright
BG_WHITE_LOWER = np.array([0, 0, 180])
BG_WHITE_UPPER = np.array([180, 40, 255])

# Neutral grey used to paint removed background. All three channels are equal,
# so the value is the same in BGR and RGB channel order.
NEUTRAL_GRAY_BGR = (128, 128, 128)
|
||
|
||
|
||
def _remove_background(image: Image.Image) -> Image.Image:
    """
    Replace a green-screen or white/light background with neutral grey.

    A pixel is treated as background when its OpenCV HSV value lies inside the
    green range (BG_GREEN_LOWER..BG_GREEN_UPPER) or the white range
    (BG_WHITE_LOWER..BG_WHITE_UPPER). The union of both masks is cleaned with a
    5x5 elliptical morphological closing followed by an opening. Every pixel
    still marked as background is then painted with NEUTRAL_GRAY_BGR.

    Note: foreground regions whose colour falls inside either range (e.g. a green
    or white object) are painted grey as well.

    Args:
        image: PIL image; expected to be RGB uint8 (the caller uses ``.convert('RGB')``).

    Returns:
        A new RGB PIL image of the same size. The input image is not modified.
    """
    # np.array() copies the pixel data, so it can be edited in place below.
    rgb = np.array(image)

    # RGB -> HSV directly yields the same HSV values as RGB -> BGR -> HSV.
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)

    green = cv2.inRange(hsv, BG_GREEN_LOWER, BG_GREEN_UPPER)
    white = cv2.inRange(hsv, BG_WHITE_LOWER, BG_WHITE_UPPER)
    background = green | white  # both masks are 0/255 uint8

    # Close first (fill small holes in the background), then open (drop specks).
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    for morph_op in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        background = cv2.morphologyEx(background, morph_op, ellipse)

    # The grey is channel-symmetric, so the BGR constant is valid in RGB as well.
    rgb[background > 0] = NEUTRAL_GRAY_BGR
    return Image.fromarray(rgb)
|
||
|
||
|
||
def _letterbox_image(image: Image.Image, size: int = 224) -> Image.Image:
    """
    Letterbox an image into a ``size`` x ``size`` square without distorting it.

    The longer side is scaled to exactly ``size`` pixels and the shorter side by
    the same factor (rounded down, never below 1 pixel). The resized image is
    pasted at the centre of a grey RGB(128, 128, 128) canvas, which fills the
    remaining border.

    Args:
        image: PIL image to letterbox (pasted onto an RGB canvas).
        size: side length of the output square, default 224.

    Returns:
        A new ``size`` x ``size`` RGB PIL image.
    """
    w, h = image.size
    longest = max(w, h)

    # Exact integer arithmetic: the long side becomes exactly `size` (float
    # scaling such as int(w * (size / w)) can truncate to size - 1), and
    # max(1, ...) prevents a zero-sized resize target for extreme aspect ratios.
    new_w = max(1, w * size // longest)
    new_h = max(1, h * size // longest)

    resized = image.resize((new_w, new_h), Image.LANCZOS)

    canvas = Image.new('RGB', (size, size), (128, 128, 128))
    # Centre the resized image; any odd leftover pixel goes to the right/bottom.
    paste_x = (size - new_w) // 2
    paste_y = (size - new_h) // 2
    canvas.paste(resized, (paste_x, paste_y))

    return canvas
|
||
|
||
|
||
def _normalize(image_np: np.ndarray, mean=None, std=None) -> np.ndarray:
    """
    Convert an HWC image in [0, 255] into a normalised CHW float32 array.

    Pixels are scaled to [0, 1], then each channel ``c`` becomes
    ``(x - mean[c]) / std[c]``. The output is therefore NOT confined to [0, 1].
    With the default statistics it spans roughly [-2.12, 2.64].

    Args:
        image_np: array of shape (H, W, C) with values in [0, 255] (uint8 from PIL).
        mean: per-channel means of length C; defaults to IMAGENET_MEAN.
        std: per-channel standard deviations of length C; defaults to IMAGENET_STD.

    Returns:
        float32 array of shape (C, H, W).
    """
    if mean is None:
        mean = IMAGENET_MEAN
    if std is None:
        std = IMAGENET_STD

    # HWC -> CHW, scaled to [0, 1].
    chw = image_np.transpose(2, 0, 1).astype(np.float32) / 255.0

    # Statistics shaped (C, 1, 1) broadcast over H and W. Keeping them float32
    # keeps the result float32, matching per-channel float32 arithmetic.
    mean_arr = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std_arr = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (chw - mean_arr) / std_arr
|
||
|
||
|
||
# ============================================================================
# Main entry point: image embedding extraction
# ============================================================================

def get_image_embedding(image_path: str) -> list:
    """
    Compute the CLIP image embedding of an image file.

    Pipeline:
      1. load the file as RGB;
      2. replace a green/white background with grey (``_remove_background``);
      3. letterbox to INPUT_SIZE x INPUT_SIZE (``_letterbox_image``);
      4. normalise into a [1, 3, H, W] float32 tensor (``_normalize``);
      5. run the ONNX session and read its ``image_embeds`` output.

    Args:
        image_path: path of the image file to read.

    Returns:
        The embedding of the image as a list of floats. Its length is whatever
        the exported model produces (512 for a ViT-B/32 CLIP; confirm against
        the model actually deployed).

    Raises:
        FileNotFoundError: if the model file is missing (raised by load_clip_model).
        Any error PIL or onnxruntime raise for unreadable images or invalid feeds.
    """
    # Returns the cached session after the first call.
    session = load_clip_model()

    # Image.open keeps the file handle open for lazy loading. The context manager
    # closes it, and convert() has already read the pixel data by then.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = _remove_background(image)
    image = _letterbox_image(image, INPUT_SIZE)
    pixel_values = np.expand_dims(_normalize(np.array(image)), axis=0)  # [1, 3, 224, 224]

    # The exported graph may be a full CLIP model that also declares text inputs.
    # Those must be fed even though only the image branch is used, so placeholder
    # ids/mask with sequence length 77 are supplied. Only names the model actually
    # declares are fed, because onnxruntime rejects unknown feed names. This keeps
    # working if the export drops the text inputs.
    candidate_feeds = {
        'input_ids': np.zeros((1, 77), dtype=np.int64),
        'pixel_values': pixel_values.astype(np.float32),
        'attention_mask': np.zeros((1, 77), dtype=np.int64),
    }
    declared_inputs = {inp.name for inp in session.get_inputs()}
    feeds = {name: value for name, value in candidate_feeds.items() if name in declared_inputs}

    outputs = session.run(['image_embeds'], feeds)
    # outputs[0] is the batch of embeddings; take the only item.
    return outputs[0][0].tolist()
|
||
|
||
|
||
# ============================================================================
# Defensive wrapper: raw photo reference -> embedding
# ============================================================================

def _resolve_upload_filename(photo_source) -> str:
    """
    Extract a bare upload file name from a raw photo reference.

    Accepts a plain path / file name, or a JSON-encoded string or list (the first
    list element is used). Only the segment after the last '/' is kept, and stray
    quote / bracket characters are removed. Returns "" when nothing usable remains,
    including the directory references "." and "..".
    """
    text = str(photo_source).strip()
    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):
        # Not JSON (json.JSONDecodeError subclasses ValueError): use the text as-is.
        parsed = text

    if isinstance(parsed, list):
        parsed = parsed[0] if parsed else ""
    # JSON null (or a null list element) means "no photo", not a file called "None".
    raw_path = "" if parsed is None else str(parsed)

    # Keep only the last path segment, then strip leftover JSON punctuation,
    # e.g. `123.jpg"]` or `"123.jpg"` both become `123.jpg`.
    name = raw_path.split('/')[-1].translate(str.maketrans('', '', '"\'[]'))

    # "." / ".." would point at the upload directory or its parent, never a file.
    return "" if name in ('.', '..') else name


def extract_and_embed(photo_source, *, upload_dir='/app/uploads', max_retries=6, retry_interval=0.5):
    """
    Best-effort: turn a raw photo reference into an image embedding.

    The file name is extracted from ``photo_source`` (see _resolve_upload_filename)
    and looked up in ``upload_dir``. The upload may still be in flight, so the
    lookup is retried up to ``max_retries`` times, ``retry_interval`` seconds apart
    (about 3 s in total with the defaults).

    Args:
        photo_source: file name, path, or JSON string / list holding one.
        upload_dir: directory the uploaded files live in.
        max_retries: number of existence checks before giving up.
        retry_interval: seconds to wait between checks.

    Returns:
        The embedding as a list of floats, or None when the input is empty, the
        file never appears, or extraction fails. Failures are printed, never raised.
    """
    if not photo_source:
        return None

    try:
        pure_filename = _resolve_upload_filename(photo_source)
        if not pure_filename:
            return None

        file_path = os.path.join(upload_dir, pure_filename)

        for attempt in range(1, max_retries + 1):
            # isfile (not exists): a directory can never be embedded.
            if os.path.isfile(file_path):
                vec = get_image_embedding(file_path)
                return vec.tolist() if isinstance(vec, np.ndarray) else vec
            # No point sleeping after the final check.
            if attempt < max_retries:
                print(f"[AI 识图等待] 第 {attempt} 次尝试,未找到文件 {file_path},等待 {retry_interval}s...")
                time.sleep(retry_interval)

        print(f"[AI 识图警告] 彻底失败!经过等待依然未找到图片: {file_path}")

    except Exception as e:
        # Top-level boundary of a best-effort helper: log and fall through to None.
        print(f"[AI 识图错误] 实时提取向量失败: {str(e)}")

    return None
|