Files
KCGL/inventory-backend/app/utils/ai_vision.py

280 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
AI Vision module - CLIP vision encoder inference via ONNX Runtime
"""
import os
import json
import time
import numpy as np
from PIL import Image
import onnxruntime as ort
import cv2
# ============================================================================
# Global model singleton (loaded once, at project start-up)
# ============================================================================
# Path of the exported CLIP vision model: two directories above this file,
# inside "models/".
MODEL_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'models', 'clip_vision.onnx')
# Session options shared by the inference session; every ONNX Runtime graph
# optimisation level is enabled. The execution provider (CPU) is chosen in
# load_clip_model().
_session_options = ort.SessionOptions()
_session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Process-wide inference session. It stays None until load_clip_model() runs,
# so the annotation is effectively "InferenceSession or None".
ort_session: ort.InferenceSession = None
def load_clip_model():
    """Create the global CLIP vision inference session on first use.

    Intended to be called at start-up. Later calls return the session that
    is already cached in the module-level ``ort_session``.

    Returns:
        The shared ``ort.InferenceSession``.

    Raises:
        FileNotFoundError: if ``MODEL_PATH`` does not exist.
    """
    global ort_session
    if ort_session is None:
        model_missing = not os.path.exists(MODEL_PATH)
        if model_missing:
            raise FileNotFoundError(f"CLIP Vision 模型未找到: {MODEL_PATH}")
        # CPU-only inference using the module-level session options.
        ort_session = ort.InferenceSession(
            MODEL_PATH,
            sess_options=_session_options,
            providers=['CPUExecutionProvider'],
        )
        print(f"✅ [AI Vision] CLIP 模型加载成功: {MODEL_PATH}")
    return ort_session
# ============================================================================
# CLIP preprocessing constants
# ============================================================================
# Per-channel (R, G, B) mean / std applied by _normalize().
# NOTE(review): these are the standard ImageNet statistics. The preprocessing
# published with OpenAI CLIP uses mean ~ (0.4815, 0.4578, 0.4082) and
# std ~ (0.2686, 0.2613, 0.2758) - confirm which values the exported ONNX
# model expects. Changing them alters every embedding, so any vectors that
# were presumably stored from earlier runs would need to be recomputed.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Side length (pixels) of the square image handed to the model
INPUT_SIZE = 224
# ============================================================================
# Background-removal configuration (HSV colour-space thresholds)
# ============================================================================
# OpenCV 8-bit HSV: H spans 0-179 (degrees / 2), S and V span 0-255.
# Note: this halved hue scale is OpenCV's own convention; it corresponds to
# the 0-360 degree hue wheel used in graphics tools.
# Green background thresholds (typical green-screen colours)
#   H: 35~90   green band (light green to dark green)
#   S: 35~255  from low to full saturation
#   V: 30~255  almost any brightness
BG_GREEN_LOWER = np.array([35, 35, 30])
BG_GREEN_UPPER = np.array([90, 255, 255])
# White / light background thresholds (bright, weakly saturated areas)
#   H: unrestricted (0~180); only S and V matter
#   S: 0~40     very low saturation, i.e. close to pure gray / white
#   V: 180~255  high brightness
BG_WHITE_LOWER = np.array([0, 0, 180])
BG_WHITE_UPPER = np.array([180, 40, 255])
# Neutral-gray fill colour in BGR order; all channels are equal, so it is
# also (128, 128, 128) after conversion to RGB.
NEUTRAL_GRAY_BGR = (128, 128, 128)
def _remove_background(image: Image.Image) -> Image.Image:
    """Paint green-screen and white/light backgrounds neutral gray.

    Pixels are classified in OpenCV's HSV space against two threshold
    ranges: ``BG_GREEN_*`` (green backdrop) and ``BG_WHITE_*`` (bright,
    weakly saturated backdrop). A pixel matching either range is treated as
    background and overwritten with ``NEUTRAL_GRAY_BGR``.

    Args:
        image: PIL image, expected to be RGB / uint8.

    Returns:
        A new RGB PIL image with the detected background replaced.
    """
    # PIL delivers RGB; the OpenCV pipeline below works on BGR.
    bgr_pixels = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    hsv_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2HSV)

    # Union of the two background masks (255 = background, 0 = keep).
    background = cv2.bitwise_or(
        cv2.inRange(hsv_pixels, BG_GREEN_LOWER, BG_GREEN_UPPER),
        cv2.inRange(hsv_pixels, BG_WHITE_LOWER, BG_WHITE_UPPER),
    )

    # Clean the mask with a 5x5 elliptical kernel: closing first (fills
    # small holes), then opening (drops small isolated specks).
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        background = cv2.morphologyEx(background, operation, ellipse)

    # Overwrite the masked pixels on a copy, then hand back an RGB image.
    painted = bgr_pixels.copy()
    painted[background > 0] = NEUTRAL_GRAY_BGR
    return Image.fromarray(cv2.cvtColor(painted, cv2.COLOR_BGR2RGB))
def _letterbox_image(image: Image.Image, size: int = 224) -> Image.Image:
    """Fit ``image`` into a ``size`` x ``size`` square without distortion.

    The longest side is scaled to exactly ``size`` and the other side by the
    same factor (truncated to an integer, but never below 1 pixel). The
    result is centred on an RGB canvas filled with gray (128, 128, 128).

    Args:
        image: PIL image to letterbox.
        size: Side length of the square output, 224 by default.

    Returns:
        A new ``size`` x ``size`` RGB PIL image.
    """
    w, h = image.size
    longest = max(w, h)
    scale = size / longest

    # The longest side is pinned to `size` because float rounding in
    # `w * (size / w)` can otherwise truncate to size - 1. The shorter side
    # is clamped to 1 because very elongated images (aspect ratio above
    # size:1) would truncate to 0, which Image.resize rejects.
    new_w = size if w == longest else max(1, int(w * scale))
    new_h = size if h == longest else max(1, int(h * scale))

    resized = image.resize((new_w, new_h), Image.LANCZOS)

    # Gray canvas with the resized image pasted in the centre.
    canvas = Image.new('RGB', (size, size), (128, 128, 128))
    offset = ((size - new_w) // 2, (size - new_h) // 2)
    canvas.paste(resized, offset)
    return canvas
def _normalize(image_np: np.ndarray) -> np.ndarray:
    """Turn an HWC uint8 image into a mean/std-normalised CHW float32 array.

    Pixels are first scaled to [0, 1]. Each channel then has its
    ``IMAGENET_MEAN`` entry subtracted and is divided by its ``IMAGENET_STD``
    entry, so the output is NOT confined to [0, 1] (with the current
    constants it spans roughly -2.1 to 2.6).

    Args:
        image_np: Array of shape (H, W, 3), dtype uint8, values 0-255.

    Returns:
        A new float32 array of shape (3, H, W).
    """
    # Shape (3, 1, 1) so the statistics broadcast over height and width.
    # float32 keeps the arithmetic identical to scalar-by-channel updates.
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(IMAGENET_STD, dtype=np.float32).reshape(-1, 1, 1)

    # HWC -> CHW, then scale to [0, 1].
    chw = image_np.transpose(2, 0, 1).astype(np.float32) / 255.0
    return (chw - mean) / std
# ============================================================================
# Main entry point: extract an image embedding
# ============================================================================
def get_image_embedding(image_path: str) -> list:
    """Compute the CLIP image embedding of the file at ``image_path``.

    Pipeline: load as RGB -> replace green/white background with gray ->
    letterbox to ``INPUT_SIZE`` -> mean/std normalise -> run the ONNX model
    and read its ``image_embeds`` output.

    Args:
        image_path: Path of the image file to embed.

    Returns:
        The embedding of the single input image as a list of floats
        (512 values are expected for this model; the length is not checked).

    Raises:
        FileNotFoundError: if the model file is missing on first use, or
            if ``image_path`` does not exist.
    """
    if ort_session is None:
        load_clip_model()

    # Step 1: decode. Image.open is lazy and keeps the file open, so the
    # context manager closes it as soon as the RGB copy exists.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    # Step 2: background removal (HSV thresholds, background -> gray).
    image = _remove_background(image)
    # Step 3: aspect-preserving letterbox resize.
    image = _letterbox_image(image, INPUT_SIZE)
    # Step 4: normalise and add the batch axis -> [1, 3, 224, 224].
    pixel_values = np.expand_dims(_normalize(np.array(image)), axis=0)

    # Zero-filled placeholders for the text inputs that the model is fed
    # alongside the image. NOTE(review): (1, 77) presumably matches CLIP's
    # text context length - confirm against ort_session.get_inputs().
    dummy_ids = np.zeros((1, 77), dtype=np.int64)
    dummy_mask = np.zeros((1, 77), dtype=np.int64)

    # The input names must match the exported graph exactly; if the run
    # fails on an unknown name, inspect ort_session.get_inputs().
    outputs = ort_session.run(
        ['image_embeds'],
        {
            'input_ids': dummy_ids,
            # _normalize already yields float32; copy=False avoids a
            # redundant copy while still guaranteeing the dtype.
            'pixel_values': pixel_values.astype(np.float32, copy=False),
            'attention_mask': dummy_mask,
        },
    )
    return outputs[0][0].tolist()
# ============================================================================
# General-purpose embedding helper: tolerant of malformed input and errors
# ============================================================================
def extract_and_embed(photo_source):
    """Best-effort: resolve a photo reference to an uploaded file and embed it.

    ``photo_source`` may be a bare filename, a path/URL, or a JSON-encoded
    string or list (only the first list entry is used). Only the last
    ``/``-separated segment is kept, stray quotes and brackets are removed,
    and the file is looked up in ``/app/uploads``. The lookup is retried up
    to 6 times, 0.5 s apart, in case the upload has not finished yet.

    Args:
        photo_source: Photo reference in any of the forms above; falsy
            values are rejected immediately.

    Returns:
        The embedding as a list of floats, or ``None`` when the input is
        empty or unusable, the file never appears, or any error occurs.
        This function never raises; failures are printed.
    """
    if not photo_source:
        return None
    try:
        # 1. Work on a stripped string form of the input.
        photo_source_str = str(photo_source).strip()

        # 2. Peel off a JSON wrapper if there is one.
        try:
            parsed = json.loads(photo_source_str)
        except ValueError:
            # Not JSON (json.JSONDecodeError is a ValueError): plain path.
            raw_path = photo_source_str
        else:
            if isinstance(parsed, list):
                parsed = parsed[0] if parsed else None
            if parsed is None:
                # JSON null / empty list: there is no path at all. Without
                # this, str(None) would become the bogus filename "None".
                raw_path = ""
            elif isinstance(parsed, str):
                raw_path = parsed
            else:
                # Numbers, nested containers, ...: keep a string so the
                # path handling below cannot fail on a non-str value.
                raw_path = str(parsed)
        if not raw_path:
            return None

        # 3. Keep only the final path segment (the bare filename).
        pure_filename = raw_path.split('/')[-1]

        # 4. Remove leftover quoting, e.g. `123.jpg"]` or `"123.jpg"`.
        for junk in ('"', "'", '[', ']'):
            pure_filename = pure_filename.replace(junk, '')

        # An empty name, "." or ".." would resolve to a directory instead
        # of an uploaded file.
        if pure_filename in ('', '.', '..'):
            return None

        # 5. Real location of uploads inside the Docker container.
        file_path = os.path.join('/app/uploads', pure_filename)

        # 6. Poll for the file (6 attempts, 0.5 s apart).
        max_retries = 6
        for attempt in range(1, max_retries + 1):
            if os.path.exists(file_path):
                vec = get_image_embedding(file_path)
                # Defensive: always hand back a plain list.
                return vec.tolist() if isinstance(vec, np.ndarray) else vec
            print(f"[AI 识图等待] 第 {attempt} 次尝试,未找到文件 {file_path},等待 0.5s...")
            time.sleep(0.5)
        print(f"[AI 识图警告] 彻底失败!经过等待依然未找到图片: {file_path}")
    except Exception as e:
        # Deliberate catch-all boundary: embedding is optional for callers.
        print(f"[AI 识图错误] 实时提取向量失败: {str(e)}")
    return None