Files
micro_plastic/mask.py
2026-02-25 09:42:51 +08:00

1337 lines
45 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
微塑料检测完整流程 V2
包含:降采样 -> 滤纸分割(降采样) -> 上采样滤纸掩膜 -> 应用到原图像 -> Cellpose微塑料检测(原尺寸) -> 输出原尺寸掩膜
"""
import colorsys
import os
import warnings
from pathlib import Path
from typing import List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from cellpose import models
from cellpose.io import imread, imsave
from scipy import ndimage
warnings.filterwarnings('ignore')
def generate_colors(num_colors: int) -> List[Tuple[int, int, int]]:
    """Generate a reproducible list of random, vivid colors.

    Colors are sampled in HSV space (random hue, saturation in [0.6, 1.0),
    value in [0.7, 1.0)) and converted to OpenCV's BGR channel order.

    A private ``RandomState`` seeded with 42 is used instead of reseeding
    NumPy's global generator, so calling this function does not disturb
    random numbers drawn elsewhere in the program. The sequence is the one
    the global generator produces right after ``np.random.seed(42)``.

    Args:
        num_colors: Number of colors to generate; values <= 0 yield ``[]``.

    Returns:
        List of ``(B, G, R)`` tuples with integer components in 0..255.
    """
    rng = np.random.RandomState(42)
    colors = []
    for _ in range(num_colors):
        # The draw order (hue, saturation, value) is part of the palette
        # definition: changing it would change every generated color.
        hue = rng.random_sample()
        saturation = rng.uniform(0.6, 1.0)
        value = rng.uniform(0.7, 1.0)
        red, green, blue = colorsys.hsv_to_rgb(hue, saturation, value)
        # OpenCV expects BGR ordering.
        colors.append((int(blue * 255), int(green * 255), int(red * 255)))
    return colors
def create_colored_mask(mask: np.ndarray, num_particles: int) -> np.ndarray:
    """Render a label mask as a color image, one color per particle.

    Labels ``1..num_particles`` are painted with the colors returned by
    ``generate_colors``; every other pixel (background or labels above
    ``num_particles``) stays black.

    Args:
        mask: Label mask whose pixel values are particle ids.
        num_particles: Number of particle labels to colorize.

    Returns:
        ``uint8`` array of shape ``mask.shape + (3,)`` in BGR order.
    """
    canvas = np.zeros((*mask.shape, 3), dtype=np.uint8)
    if num_particles == 0:
        return canvas

    palette = generate_colors(num_particles)
    for label, color in enumerate(palette, start=1):
        canvas[mask == label] = color
    return canvas
def downsample_image(image: np.ndarray, scale_factor: float = 0.25) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Resize an image by ``scale_factor`` using area interpolation.

    Args:
        image: Input image (at least 2-D; the first two axes are height
            and width).
        scale_factor: Multiplicative factor applied to both width and
            height. Must be positive.

    Returns:
        ``(resized_image, original_size)`` where ``original_size`` is the
        ``(height, width)`` of the input.

    Raises:
        ValueError: If ``scale_factor`` is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor}")

    original_size = image.shape[:2]
    # Truncation can give 0 for very small inputs, which cv2.resize rejects;
    # keep at least one pixel per axis.
    new_width = max(1, int(image.shape[1] * scale_factor))
    new_height = max(1, int(image.shape[0] * scale_factor))
    downsampled = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
    return downsampled, original_size
def upsample_mask(mask: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
    """Resize a mask to ``target_size`` with nearest-neighbour interpolation.

    Nearest-neighbour interpolation is used so that no new (blended) mask
    values are introduced.

    Args:
        mask: Input mask.
        target_size: Desired size as ``(height, width)``.

    Returns:
        The resized mask.
    """
    height, width = target_size[0], target_size[1]
    # cv2.resize takes the size as (width, height).
    return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
def detect_filter_paper_downsampled(image: np.ndarray, method: str = 'shape') -> np.ndarray:
    """Detect the filter-paper region on a downsampled image.

    Dispatches to one of the ``detect_filter_*`` strategies by name.

    Args:
        image: Downsampled input image.
        method: Strategy name, one of:
            - 'shape': edge + shape based detection (default)
            - 'color': intensity-threshold based detection
            - 'edge': edge-detection based
            - 'hsv': HSV color-space based
            - 'otsu': Otsu automatic thresholding
            - 'hough': Hough circle transform
            - 'threshold': threshold-based segmentation (described by the
              original author as recommended for complex backgrounds)
            - 'combined': majority vote of several strategies (described by
              the original author as the most robust)

    Returns:
        The filter-paper mask produced by the selected strategy.

    Raises:
        ValueError: If ``method`` is not one of the names listed above.
    """
    known_methods = ('shape', 'color', 'edge', 'hsv', 'otsu', 'hough', 'threshold', 'combined')
    if method not in known_methods:
        raise ValueError(f"Unknown method: {method}")

    strategies = {
        'shape': detect_filter_by_shape,
        'color': detect_filter_by_color,
        'edge': detect_filter_by_edge,
        'hsv': detect_filter_by_hsv,
        'otsu': detect_filter_by_otsu,
        'hough': detect_filter_by_hough,
        'threshold': detect_filter_by_threshold,
        'combined': detect_filter_combined,
    }
    return strategies[method](image)
def detect_filter_by_threshold(image: np.ndarray) -> np.ndarray:
    """Segment the filter paper by thresholding.

    Pipeline:
        1. Binarize the blurred grayscale image twice (Otsu and adaptive
           Gaussian thresholding).
        2. Clean each candidate morphologically and take its largest
           connected component; keep the candidate whose component is larger.
        3. Fill holes.
        4. Replace the region by its convex hull to drop ragged border
           pixels, then smooth with a small closing.

    Args:
        image: Input image; converted to grayscale with ``COLOR_RGB2GRAY``.

    Returns:
        ``uint8`` filter-paper mask with values 0 / 255.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    smoothed = cv2.GaussianBlur(grayscale, (5, 5), 0)

    # Candidate 1: global Otsu threshold. Candidate 2: adaptive threshold.
    _, otsu_mask = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    adaptive_mask = cv2.adaptiveThreshold(
        smoothed,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        11,
        2,
    )

    denoise_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    winner = None
    winner_area = 0
    for candidate in (otsu_mask, adaptive_mask):
        # Open then close to remove speckle before labelling components.
        opened = cv2.morphologyEx(candidate, cv2.MORPH_OPEN, denoise_kernel, iterations=2)
        cleaned = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, denoise_kernel, iterations=2)

        n_labels, label_map, stats, _ = cv2.connectedComponentsWithStats(cleaned, connectivity=8)
        if n_labels <= 1:
            continue  # only background present

        component_areas = stats[1:, cv2.CC_STAT_AREA]  # skip background label 0
        top = np.argmax(component_areas)
        if component_areas[top] > winner_area:
            winner_area = component_areas[top]
            winner = (label_map == top + 1).astype(np.uint8) * 255

    if winner is None:
        # Neither candidate had a foreground component; fall back to raw Otsu.
        winner = otsu_mask

    filled = ndimage.binary_fill_holes(winner).astype(np.uint8) * 255

    contours, _ = cv2.findContours(filled, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        return filled

    hull = cv2.convexHull(max(contours, key=cv2.contourArea))
    hull_mask = np.zeros_like(filled)
    cv2.fillPoly(hull_mask, [hull], 255)

    smooth_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    return cv2.morphologyEx(hull_mask, cv2.MORPH_CLOSE, smooth_kernel, iterations=1)
def detect_filter_by_shape(image: np.ndarray) -> np.ndarray:
    """Detect the filter paper from its edges and overall shape.

    Canny edges of the heavily blurred grayscale image are closed into a
    connected outline, the interior is filled, the largest connected
    component is kept and finally regularized by ``optimize_circular_shape``.

    Args:
        image: Input image; converted to grayscale with ``COLOR_RGB2GRAY``.

    Returns:
        ``uint8`` filter-paper mask (0 / 255).
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    smoothed = cv2.GaussianBlur(grayscale, (15, 15), 0)
    edge_map = cv2.Canny(smoothed, 30, 100)

    # A large closing bridges gaps in the detected outline.
    bridge_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
    closed_edges = cv2.morphologyEx(edge_map, cv2.MORPH_CLOSE, bridge_kernel)

    region = ndimage.binary_fill_holes(closed_edges).astype(np.uint8) * 255

    n_labels, label_map, stats, _centroids = cv2.connectedComponentsWithStats(region, connectivity=8)
    if n_labels > 1:
        # stats[0] is the background, hence the +1 offset.
        biggest = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1
        region = (label_map == biggest).astype(np.uint8) * 255

    return optimize_circular_shape(region)
def detect_filter_by_color(image: np.ndarray) -> np.ndarray:
    """Detect the filter paper from intensity thresholds.

    The Otsu mask is intersected with the complement of an inverted
    adaptive-threshold mask, cleaned morphologically, hole-filled, reduced
    to its largest component and regularized by ``optimize_circular_shape``.
    The approach assumes the paper is brighter than the background.

    Args:
        image: Input image; converted RGB -> BGR -> grayscale.

    Returns:
        ``uint8`` filter-paper mask (0 / 255).
    """
    as_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    grayscale = cv2.cvtColor(as_bgr, cv2.COLOR_BGR2GRAY)

    _, otsu_mask = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    adaptive_inverted = cv2.adaptiveThreshold(
        grayscale, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
    )

    # Keep pixels that are bright globally and not flagged by the inverted
    # adaptive threshold.
    region = cv2.bitwise_and(otsu_mask, 255 - adaptive_inverted)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    region = cv2.morphologyEx(region, cv2.MORPH_CLOSE, kernel, iterations=2)
    region = cv2.morphologyEx(region, cv2.MORPH_OPEN, kernel, iterations=1)

    region = ndimage.binary_fill_holes(region).astype(np.uint8) * 255

    n_labels, label_map, stats, _ = cv2.connectedComponentsWithStats(region, connectivity=8)
    if n_labels > 1:
        biggest = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1  # skip background
        region = (label_map == biggest).astype(np.uint8) * 255

    return optimize_circular_shape(region)
def detect_filter_by_edge(image: np.ndarray) -> np.ndarray:
    """Detect the filter paper from multi-threshold Canny edges.

    Edges found with three Canny threshold pairs are merged, closed with a
    large kernel, filled, cleaned with a smaller kernel, and the largest
    connected component is returned.

    Args:
        image: Input image; converted to grayscale with ``COLOR_RGB2GRAY``.

    Returns:
        ``uint8`` filter-paper mask (0 / 255).
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    smoothed = cv2.GaussianBlur(grayscale, (5, 5), 0)

    # Union of edges detected at several sensitivity levels.
    edge_map = cv2.Canny(smoothed, 50, 150)
    for low, high in ((100, 200), (30, 90)):
        edge_map = cv2.bitwise_or(edge_map, cv2.Canny(smoothed, low, high))

    big_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (25, 25))
    small_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))

    closed_edges = cv2.morphologyEx(edge_map, cv2.MORPH_CLOSE, big_kernel)
    region = ndimage.binary_fill_holes(closed_edges).astype(np.uint8) * 255

    # Opening removes small protrusions; closing restores small gaps.
    region = cv2.morphologyEx(region, cv2.MORPH_OPEN, small_kernel, iterations=2)
    region = cv2.morphologyEx(region, cv2.MORPH_CLOSE, small_kernel, iterations=2)

    n_labels, label_map, stats, _ = cv2.connectedComponentsWithStats(region, connectivity=8)
    if n_labels > 1:
        biggest = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1  # skip background
        region = (label_map == biggest).astype(np.uint8) * 255
    return region
def detect_filter_by_hsv(image: np.ndarray) -> np.ndarray:
    """Detect the filter paper in HSV color space.

    Pixels with any hue, saturation <= 30 and value >= 120 (i.e. bright,
    weakly saturated - white / light paper) are selected, then cleaned
    morphologically, hole-filled, reduced to the largest component and
    regularized by ``optimize_circular_shape``.

    To target colored paper instead, replace the range below with one or
    more hue-specific ``cv2.inRange`` masks (red, for example, needs two
    hue bands, 0-10 and 170-180, OR-ed together).

    Args:
        image: Input image; converted with ``COLOR_RGB2HSV``.

    Returns:
        ``uint8`` filter-paper mask (0 / 255).
    """
    hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

    paper_low = np.array([0, 0, 120])
    paper_high = np.array([180, 30, 255])
    region = cv2.inRange(hsv_image, paper_low, paper_high)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
    region = cv2.morphologyEx(region, cv2.MORPH_CLOSE, kernel, iterations=3)
    region = cv2.morphologyEx(region, cv2.MORPH_OPEN, kernel, iterations=2)

    region = ndimage.binary_fill_holes(region).astype(np.uint8) * 255

    n_labels, label_map, stats, _ = cv2.connectedComponentsWithStats(region, connectivity=8)
    if n_labels > 1:
        biggest = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1  # skip background
        region = (label_map == biggest).astype(np.uint8) * 255

    return optimize_circular_shape(region)
def detect_filter_by_otsu(image: np.ndarray) -> np.ndarray:
    """Detect the filter paper with Otsu-style automatic thresholding.

    Three binarizations of the blurred grayscale image are tried - Otsu,
    inverted Otsu (for paper darker than the background) and the triangle
    method. Each is cleaned and hole-filled; the one containing the largest
    connected component wins. That component is then regularized by
    ``optimize_circular_shape``.

    Args:
        image: Input image; converted to grayscale with ``COLOR_RGB2GRAY``.

    Returns:
        ``uint8`` filter-paper mask (0 / 255).
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    smoothed = cv2.GaussianBlur(grayscale, (5, 5), 0)

    _, otsu_mask = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    _, otsu_inverted = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    _, triangle_mask = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_TRIANGLE)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    winner = None
    winner_area = 0
    for candidate in (otsu_mask, otsu_inverted, triangle_mask):
        cleaned = cv2.morphologyEx(candidate, cv2.MORPH_CLOSE, kernel, iterations=2)
        cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel, iterations=1)
        filled = ndimage.binary_fill_holes(cleaned).astype(np.uint8) * 255

        n_labels, _labels, stats, _ = cv2.connectedComponentsWithStats(filled, connectivity=8)
        if n_labels <= 1:
            continue  # nothing but background
        component_areas = stats[1:, cv2.CC_STAT_AREA]
        biggest_area = np.max(component_areas) if len(component_areas) > 0 else 0
        if biggest_area > winner_area:
            winner_area = biggest_area
            winner = filled

    if winner is None:
        winner = otsu_mask  # default to the plain Otsu result

    # Reduce the winning mask to its largest connected component.
    n_labels, label_map, stats, _ = cv2.connectedComponentsWithStats(winner, connectivity=8)
    if n_labels > 1:
        biggest = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1
        winner = (label_map == biggest).astype(np.uint8) * 255

    return optimize_circular_shape(winner)
def refine_circle_radius(image: np.ndarray, center: Tuple[int, int], radius: int) -> int:
    """Adjust a circle's radius so that it lies on the strongest edge ring.

    Canny edges are computed once; candidate radii from 90 % up to (but not
    including) 105 % of ``radius`` are tested in steps of 2 px. For each
    candidate the edge pixels inside the 2-px-wide ring
    ``(r - 2, r]`` around ``center`` are counted (for ``r <= 2`` the whole
    disk is used), and the candidate with the highest count wins. Ties keep
    the smaller radius; if no candidate has any edge pixel the input radius
    is kept.

    The squared-distance map and the edge-pixel distances are computed once
    outside the search loop, so each candidate only costs a comparison over
    the edge pixels instead of two full-image masks.

    Args:
        image: Input image (2-D grayscale, or 3-D converted with
            ``COLOR_RGB2GRAY``).
        center: Circle centre as ``(x, y)``.
        radius: Initial radius in pixels.

    Returns:
        The refined radius as a builtin ``int``. If ``center`` lies outside
        the image, ``radius`` is returned unchanged.
    """
    x, y = center
    h, w = image.shape[:2]

    # A centre outside the image cannot be refined meaningfully.
    if x < 0 or x >= w or y < 0 or y >= h:
        return radius

    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image

    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(blurred, 50, 150)

    # Loop-invariant: squared distance of every edge pixel to the centre.
    rows, cols = np.ogrid[:h, :w]
    dist_sq = (cols - x) ** 2 + (rows - y) ** 2
    edge_dist_sq = dist_sq[edges > 0]

    best_radius = radius
    max_edge_pixels = 0
    for test_radius in range(int(radius * 0.90), int(radius * 1.05), 2):
        if test_radius <= 0:
            continue

        in_ring = edge_dist_sq <= test_radius ** 2
        if test_radius > 2:
            # Exclude the inner disk so only the outermost 2-px ring counts.
            in_ring &= edge_dist_sq > (test_radius - 2) ** 2
        edge_pixels = int(np.count_nonzero(in_ring))

        if edge_pixels > max_edge_pixels:
            max_edge_pixels = edge_pixels
            best_radius = test_radius

    return int(best_radius)
def contract_circle_mask(image: np.ndarray, mask: np.ndarray, shrink_percent: float = 0.05) -> np.ndarray:
    """Shrink a roughly circular mask to remove background at its rim.

    The circle centre is the (integer-truncated) centroid of the mask and
    its radius ``R`` the largest centroid-to-pixel distance. The result is a
    filled disk of radius ``int(R * (1 - shrink_percent))``.

    If the rim of the mask is clearly darker than its interior (mean rim
    brightness below 70 % of the mean interior brightness), the radius is
    reduced by a further 1 % of ``R``. That *additional* reduction is not
    allowed to take the radius below 90 % of ``R``; the floor never makes
    the disk larger than the radius requested through ``shrink_percent``
    (previously a ``shrink_percent`` above 0.10 combined with a dark rim
    enlarged the disk back to 90 % of ``R``).

    Args:
        image: Original image (2-D grayscale, or 3-D converted with
            ``COLOR_RGB2GRAY``); used only for the brightness comparison.
        mask: Initial mask; non-zero pixels are foreground.
        shrink_percent: Fraction (0.0-1.0) by which to reduce the radius.

    Returns:
        ``uint8`` disk mask (0 / 255) with the shape of ``mask``. An empty
        input mask is returned unchanged.
    """
    if mask.sum() == 0:
        return mask

    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image

    y_coords, x_coords = np.where(mask > 0)
    if len(y_coords) == 0:
        return mask

    # Centroid as circle centre, farthest foreground pixel as radius.
    center_x = int(np.mean(x_coords))
    center_y = int(np.mean(y_coords))
    distances = np.sqrt((x_coords - center_x) ** 2 + (y_coords - center_y) ** 2)
    current_radius = np.max(distances)

    contracted_radius = int(current_radius * (1 - shrink_percent))

    h, w = mask.shape
    y, x = np.ogrid[:h, :w]
    dist_sq = (x - center_x) ** 2 + (y - center_y) ** 2

    # Rim: foreground pixels beyond 90 % of the contracted radius.
    edge_mask = (dist_sq > (contracted_radius * 0.9) ** 2) & (mask > 0)
    # Interior reference: everything within 70 % of the contracted radius.
    inner_mask = dist_sq <= (contracted_radius * 0.7) ** 2

    inner_brightness = np.mean(gray[inner_mask]) if inner_mask.sum() > 0 else np.mean(gray)
    edge_brightness = np.mean(gray[edge_mask]) if edge_mask.sum() > 0 else 0

    if edge_brightness < inner_brightness * 0.70:
        # Rim is at least 30 % darker than the interior: likely background.
        additional_shrink = int(current_radius * 0.01)
        # Floor for the extra shrink only; must never exceed the radius
        # already requested by the caller.
        min_radius = min(contracted_radius, int(current_radius * 0.90))
        contracted_radius = max(contracted_radius - additional_shrink, min_radius)

    contracted_mask = dist_sq <= contracted_radius ** 2
    return contracted_mask.astype(np.uint8) * 255
def detect_filter_by_hough(image: np.ndarray) -> np.ndarray:
    """Detect a circular filter paper with the Hough circle transform.

    Steps:
        1. Run ``cv2.HoughCircles`` (``HOUGH_GRADIENT_ALT``) on the blurred
           grayscale image, searching radii between 1/4 and 1/2 of the
           shorter image side, and take the circle with the largest radius.
        2. If no circle is found, fall back to the minimum enclosing circle
           of the largest Canny contour; if there are no contours either,
           return ``detect_filter_by_shape(image)``.
        3. Refine the radius with ``refine_circle_radius`` (90 %-105 %
           search on the strongest edge ring).
        4. Draw the disk, smooth it with a small closing, and shrink it by
           2 % with ``contract_circle_mask`` to drop background at the rim.

    Args:
        image: Input image; converted to grayscale with ``COLOR_RGB2GRAY``.

    Returns:
        ``uint8`` circular filter-paper mask (0 / 255).
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    smoothed = cv2.GaussianBlur(grayscale, (9, 9), 2)

    height, width = grayscale.shape
    short_side = min(height, width)

    found = cv2.HoughCircles(
        smoothed,
        cv2.HOUGH_GRADIENT_ALT,
        dp=1.5,                       # inverse accumulator resolution
        minDist=short_side,           # minimum distance between centres
        param1=50,                    # upper edge-detector threshold
        param2=0.8,                   # accumulator threshold (lower finds more circles)
        minRadius=short_side // 4,
        maxRadius=short_side // 2,
    )

    disk = np.zeros_like(grayscale)

    if found is not None:
        candidates = np.round(found[0, :]).astype("int")
        # The largest circle is taken to be the filter paper.
        cx, cy, r = candidates[np.argmax(candidates[:, 2])]
        tuned_r = refine_circle_radius(image, (cx, cy), r)
        cv2.circle(disk, (cx, cy), tuned_r, 255, -1)
    else:
        # Fallback: enclosing circle of the largest edge contour.
        edge_map = cv2.Canny(smoothed, 50, 150)
        contours, _ = cv2.findContours(edge_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            # Last resort: the shape-based detector.
            return detect_filter_by_shape(image)

        (fx, fy), enclosing_r = cv2.minEnclosingCircle(max(contours, key=cv2.contourArea))
        centre = (int(fx), int(fy))
        tuned_r = refine_circle_radius(image, centre, int(enclosing_r))
        cv2.circle(disk, centre, tuned_r, 255, -1)

    # Both branches smooth the drawn disk the same way.
    smooth_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    disk = cv2.morphologyEx(disk, cv2.MORPH_CLOSE, smooth_kernel, iterations=1)

    # Conservative 2 % contraction to avoid including background.
    return contract_circle_mask(image, disk, shrink_percent=0.02)
def detect_filter_combined(image: np.ndarray) -> np.ndarray:
    """Detect the filter paper by majority vote of three detectors.

    The shape, HSV and edge detectors are run; a pixel is kept when at
    least two of them mark it. The voted mask is cleaned morphologically,
    hole-filled, reduced to its largest connected component and regularized
    by ``optimize_circular_shape``.

    Args:
        image: Input image.

    Returns:
        ``uint8`` filter-paper mask (0 / 255).
    """
    detections = (
        detect_filter_by_shape(image),
        detect_filter_by_hsv(image),
        detect_filter_by_edge(image),
    )

    # Per-pixel vote count; at least 2 of 3 detectors must agree.
    votes = sum((detected > 0).astype(int) for detected in detections)
    voted = (votes >= 2).astype(np.uint8) * 255

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    voted = cv2.morphologyEx(voted, cv2.MORPH_CLOSE, kernel, iterations=2)
    voted = cv2.morphologyEx(voted, cv2.MORPH_OPEN, kernel, iterations=1)

    voted = ndimage.binary_fill_holes(voted).astype(np.uint8) * 255

    n_labels, label_map, stats, _ = cv2.connectedComponentsWithStats(voted, connectivity=8)
    if n_labels > 1:
        biggest = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1  # skip background
        voted = (label_map == biggest).astype(np.uint8) * 255

    return optimize_circular_shape(voted)
def optimize_circular_shape(mask: np.ndarray) -> np.ndarray:
    """Regularize a mask so that it looks more like a disk.

    After a small opening/closing, the largest external contour is
    measured. If its circularity ``4*pi*area / perimeter**2`` is below 0.7,
    the contour's convex hull is filled and passed to
    ``fit_circle_to_mask``; otherwise the filled contour itself is returned.

    Args:
        mask: Input ``uint8`` mask.

    Returns:
        The regularized mask. If no contour is found, the morphologically
        cleaned mask is returned.
    """
    small_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    # Opening removes small protrusions, closing fills small holes.
    opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, small_kernel, iterations=2)
    cleaned = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, small_kernel, iterations=2)

    contours, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return cleaned

    outline = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(outline)
    perimeter = cv2.arcLength(outline, True)
    if perimeter > 0:
        circularity = 4 * np.pi * area / (perimeter * perimeter)
    else:
        circularity = 0

    if circularity < 0.7:  # tunable threshold
        # Irregular outline: use the convex hull and fit a circle to it.
        hull_mask = np.zeros_like(mask)
        cv2.fillPoly(hull_mask, [cv2.convexHull(outline)], 255)
        return fit_circle_to_mask(hull_mask)

    # Already close to circular: keep the contour as it is.
    contour_mask = np.zeros_like(mask)
    cv2.fillPoly(contour_mask, [outline], 255)
    return contour_mask
def fit_circle_to_mask(mask: np.ndarray) -> np.ndarray:
    """Replace a filled mask by a disk of matching position and size.

    The centre is the centroid of the foreground pixels. The radius is
    estimated robustly from the median centroid-to-pixel distance: for a
    uniformly filled disk of radius ``R`` half of the pixels lie within
    ``R / sqrt(2)``, so ``radius = sqrt(2) * median``. (Using the bare
    median, as before, returned a disk with only about half the area of the
    input region.) The median keeps the estimate insensitive to thin
    protrusions.

    Args:
        mask: 2-D input mask; pixels > 0 are foreground.

    Returns:
        ``uint8`` disk mask (0 / 255) of the same shape. Masks with fewer
        than three foreground pixels are returned unchanged.
    """
    points = np.column_stack(np.where(mask > 0))  # rows of (y, x)
    if len(points) < 3:
        return mask

    center_x = np.mean(points[:, 1])
    center_y = np.mean(points[:, 0])

    distances = np.sqrt((points[:, 1] - center_x) ** 2 + (points[:, 0] - center_y) ** 2)
    # sqrt(2) converts the median interior distance into the disk radius.
    radius = np.sqrt(2.0) * np.median(distances)

    h, w = mask.shape
    y, x = np.ogrid[:h, :w]
    circle_mask = ((x - center_x) ** 2 + (y - center_y) ** 2) <= radius ** 2
    return circle_mask.astype(np.uint8) * 255
def apply_filter_mask_to_original(original_image: np.ndarray, filter_mask: np.ndarray) -> np.ndarray:
    """Zero out every pixel of the image that lies outside the mask.

    Implemented with NumPy indexing so that it works for grayscale images
    and for any number of channels or dtype (the former
    ``GRAY2BGR`` + ``bitwise_and`` approach required a 3-channel ``uint8``
    image). For 3-channel ``uint8`` input the result is identical.

    Args:
        original_image: Image of shape ``(H, W)`` or ``(H, W, C)``.
        filter_mask: Mask of shape ``(H, W)``, already at the image's
            resolution; any value > 0 counts as "keep".

    Returns:
        A new array with the shape and dtype of ``original_image`` in which
        pixels outside the mask are 0. The input is not modified.

    Raises:
        ValueError: If the mask's shape differs from the image's height
            and width.
    """
    keep = np.asarray(filter_mask) > 0
    if keep.shape != tuple(original_image.shape[:2]):
        raise ValueError(
            f"filter_mask shape {keep.shape} does not match image size "
            f"{tuple(original_image.shape[:2])}"
        )

    masked_image = np.array(original_image)  # copy; leave the input intact
    # A 2-D boolean index clears all channels of the selected pixels.
    masked_image[~keep] = 0
    return masked_image
class MicroplasticDetectorV2:
    """Microplastic detector V2.

    The filter paper is located on a downsampled copy of the image; the
    resulting mask is upsampled, applied to the full-resolution image, and
    Cellpose then segments particles on that masked full-resolution image.
    """

    def __init__(self, model_path: str = None, device: str = 'auto', filter_method: str = 'shape'):
        """Store the configuration; the model itself is loaded lazily.

        Args:
            model_path: Path to custom Cellpose weights. If ``None`` or the
                path does not exist, the pretrained model is used.
            device: 'cpu', 'cuda' or 'auto'. The GPU is used whenever CUDA
                is available and ``device`` is not 'cpu'.
            filter_method: Filter-paper detection strategy, one of 'shape',
                'color', 'edge', 'hsv', 'otsu', 'hough', 'threshold',
                'combined' (see ``detect_filter_paper_downsampled``).
        """
        self.device = device
        self.model_path = model_path
        self.model = None
        self.scale_factor = 0.25  # downsampling factor for filter detection
        self.filter_method = filter_method

    def _use_gpu(self) -> bool:
        """Return True when CUDA is available and the CPU was not forced."""
        return torch.cuda.is_available() and self.device != 'cpu'

    def load_model(self):
        """Load the Cellpose model once; later calls are no-ops."""
        if self.model is not None:
            return

        use_gpu = self._use_gpu()
        if self.model_path and Path(self.model_path).exists():
            # Custom weights are passed through the constructor's
            # ``pretrained_model`` argument.
            # NOTE(review): the previous code called
            # ``CellposeModel.load_model(path)``, which does not appear to
            # be part of the public CellposeModel API - confirm against the
            # installed Cellpose version.
            self.model = models.CellposeModel(
                gpu=use_gpu,
                pretrained_model=self.model_path
            )
        else:
            # Pretrained model.
            self.model = models.CellposeModel(
                gpu=use_gpu,
                model_type='cpsam'
            )
        print(f"模型已加载,使用设备: {'GPU' if use_gpu else 'CPU'}")

    def detect_microplastics(self, image_path: str, output_dir: str = None,
                             diameter: float = 30, flow_threshold: float = 0.4,
                             cellprob_threshold: float = 0.0, debug: bool = False) -> dict:
        """Run the full pipeline on one image file.

        Args:
            image_path: Path of the input image (read with ``cv2.imread``).
            output_dir: Directory for output files; nothing is written when
                falsy.
            diameter: Particle diameter passed to Cellpose.
            flow_threshold: Flow threshold passed to Cellpose.
            cellprob_threshold: Cell-probability threshold passed to Cellpose.
            debug: When True, print progress and save all intermediate
                images, prefixed with the input file's stem.

        Returns:
            Dict with keys 'masks', 'particles_info', 'num_microplastics',
            'filter_mask', 'masked_image', 'flows' and 'styles'.

        Raises:
            ValueError: If the image cannot be read.
        """
        self.load_model()

        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"无法读取图像: {image_path}")
        if debug:
            print(f"原始图像尺寸: {image.shape}")

        # 1. Downsample.
        downsampled_image, original_size = downsample_image(image, self.scale_factor)
        if debug:
            print(f"降采样图像尺寸: {downsampled_image.shape}")

        # 2. Detect the filter paper on the small image.
        # NOTE(review): cv2.imread returns BGR while the detectors convert
        # with COLOR_RGB2* codes - confirm the intended channel order.
        if debug:
            print(f"在降采样图像上检测滤纸 (方法: {self.filter_method})...")
        filter_mask_downsampled = detect_filter_paper_downsampled(downsampled_image, method=self.filter_method)

        # 3. Bring the mask back to full resolution.
        if debug:
            print("将滤纸掩膜上采样到原尺寸...")
        filter_mask_original = upsample_mask(filter_mask_downsampled, original_size)

        # 4. Blank everything outside the filter paper.
        if debug:
            print("将滤纸掩膜应用到原图像...")
        masked_image = apply_filter_mask_to_original(image, filter_mask_original)

        # 5. Segment particles at full resolution.
        if debug:
            print("在原尺寸图像上运行Cellpose检测...")
        masks, flows, styles = self.model.eval(
            masked_image,
            diameter=diameter,
            flow_threshold=flow_threshold,
            cellprob_threshold=cellprob_threshold)

        # 6. Per-particle statistics.
        if debug:
            print("分析检测结果...")
        particles_info = self._analyze_particles(masks)
        num_microplastics = len(particles_info) if particles_info else 0
        if debug:
            print(f"检测到 {num_microplastics} 个微塑料颗粒")

        # 7. Persist results.
        if output_dir:
            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)
            input_file_stem = Path(image_path).stem

            # cv2.imwrite cannot store uint32 label images, so convert.
            # NOTE: labels above 65535 would wrap around in this file.
            mask_to_save = masks.astype(np.uint16)
            primary_name = f"{input_file_stem}_microplastic_mask.png" if debug else "microplastic_mask.png"
            cv2.imwrite(str(output_path / primary_name), mask_to_save)

            if debug:
                cv2.imwrite(str(output_path / f"{input_file_stem}_microplastic_mask_original.png"), mask_to_save)
                if num_microplastics > 0:
                    # Colour by highest label so every particle gets a colour
                    # even if label ids are not contiguous.
                    colored_mask = create_colored_mask(masks, int(masks.max()))
                    cv2.imwrite(str(output_path / f"{input_file_stem}_microplastic_mask_visual.png"), colored_mask)
                cv2.imwrite(str(output_path / f"{input_file_stem}_filter_mask.png"), filter_mask_original)
                cv2.imwrite(str(output_path / f"{input_file_stem}_masked_image.png"), masked_image)
                cv2.imwrite(str(output_path / f"{input_file_stem}_downsampled_image.png"), downsampled_image)
                cv2.imwrite(str(output_path / f"{input_file_stem}_filter_mask_downsampled.png"), filter_mask_downsampled)
                print(f"调试结果已保存到: {output_path}")
            else:
                print(f"微塑料掩膜已保存到: {output_path / 'microplastic_mask.png'}")

        return {
            'masks': masks,
            'particles_info': particles_info,
            'num_microplastics': num_microplastics,
            'filter_mask': filter_mask_original,
            'masked_image': masked_image,
            'flows': flows,
            'styles': styles
        }

    def _analyze_particles(self, masks: np.ndarray) -> List[dict]:
        """Compute simple geometric statistics for every labelled particle.

        ``scipy.ndimage.find_objects`` provides each label's bounding box,
        so every particle is examined only inside its own box instead of
        scanning the whole image once per label.

        Args:
            masks: 2-D integer label image (0 = background), as returned by
                Cellpose; may be ``None``.

        Returns:
            One dict per label that is present, ordered by label id, with
            keys 'id', 'center' ``(x, y)``, 'bbox'
            ``(min_x, min_y, max_x, max_y)`` (inclusive), 'area' (pixels)
            and 'equivalent_diameter'. Empty list when there are no labels.
        """
        particles_info = []
        if masks is None or masks.max() == 0:
            return particles_info

        for particle_id, box in enumerate(ndimage.find_objects(masks), start=1):
            if box is None:
                continue  # label id not present in the image

            row_slice, col_slice = box
            # Other labels may fall inside the box; select this one only.
            local_rows, local_cols = np.nonzero(masks[box] == particle_id)
            y_coords = local_rows + row_slice.start
            x_coords = local_cols + col_slice.start

            area = y_coords.size
            particles_info.append({
                'id': particle_id,
                'center': (int(np.mean(x_coords)), int(np.mean(y_coords))),
                'bbox': (np.min(x_coords), np.min(y_coords), np.max(x_coords), np.max(y_coords)),
                'area': int(area),
                # Diameter of the circle having the same area.
                'equivalent_diameter': 2 * np.sqrt(area / np.pi)
            })
        return particles_info

    def predict_mask(self, image_path: str, diameter: float = 30,
                     flow_threshold: float = 0.4, cellprob_threshold: float = 0.0) -> np.ndarray:
        """Return only the particle label mask, without writing any file.

        Args:
            image_path: Path of the input image.
            diameter: Particle diameter passed to Cellpose.
            flow_threshold: Flow threshold passed to Cellpose.
            cellprob_threshold: Cell-probability threshold passed to Cellpose.

        Returns:
            The label mask (the 'masks' entry of ``detect_microplastics``).
        """
        results = self.detect_microplastics(
            image_path=image_path,
            output_dir=None,  # do not write files
            diameter=diameter,
            flow_threshold=flow_threshold,
            cellprob_threshold=cellprob_threshold,
            debug=False
        )
        return results['masks']
def detect_microplastic_mask(image_path: str, filter_method: str = 'shape',
                             diameter: float = None, flow_threshold: float = 0.4,
                             cellprob_threshold: float = 0.0, model_path: str = None,
                             device: str = 'auto') -> np.ndarray:
    """Detect microplastic particles in an image file (one-call interface).

    Creates a ``MicroplasticDetectorV2`` and returns the result of its
    ``predict_mask``; no files are written.

    Args:
        image_path: Path of the input image.
        filter_method: Filter-paper detection strategy ('shape', 'color',
            'edge', 'hsv', 'otsu', 'hough', 'threshold', 'combined').
        diameter: Particle diameter forwarded to Cellpose; ``None`` leaves
            the choice to Cellpose.
        flow_threshold: Flow threshold forwarded to Cellpose.
        cellprob_threshold: Cell-probability threshold forwarded to Cellpose.
        model_path: Path to custom Cellpose weights; ``None`` uses the
            pretrained model.
        device: 'cpu', 'cuda' or 'auto'.

    Returns:
        Label mask in which each particle has its own integer id.

    Example:
        >>> mask = detect_microplastic_mask('image.png')
        >>> # each pixel of ``mask`` holds the id of the particle it belongs to
    """
    detector_options = {
        'model_path': model_path,
        'device': device,
        'filter_method': filter_method,
    }
    prediction_options = {
        'image_path': image_path,
        'diameter': diameter,
        'flow_threshold': flow_threshold,
        'cellprob_threshold': cellprob_threshold,
    }
    return MicroplasticDetectorV2(**detector_options).predict_mask(**prediction_options)
def detect_microplastic_mask_from_array(image, filter_method: str = 'shape',
                                        diameter: float = None, flow_threshold: float = 0.4,
                                        cellprob_threshold: float = 0.0, model_path: str = None,
                                        device: str = 'auto', model: models.CellposeModel = None,
                                        detect_filter: bool = True) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Detect microplastic particles from a path, a NumPy array or a PIL image.

    Args:
        image: Input image, one of:
            - file path (``str``), read with ``cv2.imread``
            - NumPy array in BGR order (as returned by ``cv2.imread``)
            - PIL image (any object with a ``mode`` attribute); it is
              converted to RGB if necessary and then to BGR
        filter_method: Filter-paper detection strategy ('shape', 'color',
            'edge', 'hsv', 'otsu', 'hough', 'threshold', 'combined').
        diameter: Particle diameter forwarded to Cellpose; ``None`` leaves
            the choice to Cellpose.
        flow_threshold: Flow threshold forwarded to Cellpose.
        cellprob_threshold: Cell-probability threshold forwarded to Cellpose.
        model_path: Path to custom Cellpose weights; ``None`` uses the
            pretrained model.
        device: 'cpu', 'cuda' or 'auto'.
        model: Already-loaded Cellpose model; when given it is used as is,
            which avoids reloading weights on every call.
        detect_filter: If True, detect the filter paper and blank everything
            outside it before running Cellpose; if False, run Cellpose on
            the image directly.

    Returns:
        ``(masks, filter_mask)``: the particle label mask and the
        full-resolution filter-paper mask. ``filter_mask`` is ``None`` when
        ``detect_filter`` is False.

    Raises:
        FileNotFoundError: If ``image`` is a path that does not exist.
        ValueError: If the file exists but cannot be decoded as an image.

    Example:
        >>> masks, filter_mask = detect_microplastic_mask_from_array('image.png')
        >>> import cv2
        >>> masks, filter_mask = detect_microplastic_mask_from_array(cv2.imread('image.png'))
        >>> from PIL import Image
        >>> masks, filter_mask = detect_microplastic_mask_from_array(Image.open('image.png'))
        >>> # skip filter-paper detection; filter_mask is None
        >>> masks, filter_mask = detect_microplastic_mask_from_array('image.png', detect_filter=False)
    """
    # Path input: keep the path in its own variable so the error message
    # can still show it after imread returned None.
    if isinstance(image, str):
        source_path = image
        if not os.path.exists(source_path):
            raise FileNotFoundError(f"图像文件不存在: {source_path}")
        image = cv2.imread(source_path)
        if image is None:
            raise ValueError(f"无法读取图像文件: {source_path}")

    # PIL input (duck-typed through the ``mode`` attribute).
    if hasattr(image, 'mode'):
        if image.mode != 'RGB':
            # Grayscale / palette / RGBA images would otherwise fail in the
            # 3-channel RGB -> BGR conversion below.
            image = image.convert('RGB')
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    detector = MicroplasticDetectorV2(
        model_path=model_path,
        device=device,
        filter_method=filter_method
    )
    if model is not None:
        detector.model = model
    # No-op when a preloaded model was supplied.
    detector.load_model()

    if detect_filter:
        # Detect the paper on a small copy, then apply it at full size.
        downsampled_image, original_size = downsample_image(image, detector.scale_factor)
        filter_mask_downsampled = detect_filter_paper_downsampled(downsampled_image, method=detector.filter_method)
        filter_mask_original = upsample_mask(filter_mask_downsampled, original_size)
        masked_image = apply_filter_mask_to_original(image, filter_mask_original)
    else:
        filter_mask_original = None
        masked_image = image

    masks, flows, styles = detector.model.eval(
        masked_image,
        diameter=diameter,
        flow_threshold=flow_threshold,
        cellprob_threshold=cellprob_threshold
    )
    return masks, filter_mask_original
def main():
    """Example entry point: run the pipeline on one hard-coded image.

    The banner and status messages are derived from ``filter_method`` so
    that they always name the strategy that is actually run.
    """
    image_path = r"D:\Data\rgb_output\MPData7.png"  # replace with your image path
    output_dir = r"D:\Data\rgb_output\output"
    filter_method = 'threshold'

    # Example 1: full pipeline with debug output.
    print(f"=== {filter_method} 方法检测 ===")
    try:
        detector = MicroplasticDetectorV2(filter_method=filter_method)
        results = detector.detect_microplastics(
            image_path=image_path,
            output_dir=f"{output_dir}_{filter_method}",
            diameter=None,
            flow_threshold=0.4,
            cellprob_threshold=0,
            debug=True
        )
        print(f"\n{filter_method} 方法检测完成!")
        print(f"检测到 {results['num_microplastics']} 个微塑料颗粒")
    except Exception as e:
        # Top-level boundary of the example script: report and continue.
        print(f"{filter_method} 方法检测过程中出现错误: {e}")

    # Example 2 (disabled): one-call interface from a file path.
    #
    #     mask = detect_microplastic_mask(
    #         image_path=image_path,
    #         filter_method='combined',
    #         diameter=None,
    #         flow_threshold=0.4,
    #         cellprob_threshold=0.0,
    #     )
    #     out = Path(output_dir + "_simple")
    #     out.mkdir(parents=True, exist_ok=True)
    #     cv2.imwrite(str(out / "microplastic_mask.png"), mask.astype(np.uint16))
    #
    # Example 3 (disabled): pass an already-loaded array. Note that
    # detect_microplastic_mask_from_array returns a (masks, filter_mask) tuple.
    #
    #     image = cv2.imread(image_path)
    #     if image is not None:
    #         mask, filter_mask = detect_microplastic_mask_from_array(
    #             image=image,
    #             filter_method='combined',
    #             diameter=None,
    #             flow_threshold=0.4,
    #             cellprob_threshold=0.0,
    #         )
    #         out = Path(output_dir + "_array")
    #         out.mkdir(parents=True, exist_ok=True)
    #         cv2.imwrite(str(out / "microplastic_mask.png"), mask.astype(np.uint16))


if __name__ == "__main__":
    main()