import os
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from scipy import ndimage
from scipy.io import savemat

# Heavy third-party back-ends. They are only needed by some code paths
# (ENVI loading, OpenCV-based loading/thresholding), so a missing package
# must not prevent importing this module; the affected methods raise a clear
# ImportError when they are actually called.
try:
    import cv2
except ImportError:
    cv2 = None

try:
    import spectral
except ImportError:
    spectral = None

try:
    # Not used in this module; kept so the name stays importable from here.
    from skimage import filters
except ImportError:
    filters = None

warnings.filterwarnings('ignore')


def _require_module(module, name):
    """Return *module*, or raise ImportError if the optional import failed."""
    if module is None:
        raise ImportError(f"缺少依赖库: {name}")
    return module


@dataclass
class ThresholdSegmentationConfig:
    """Configuration for threshold-based image segmentation.

    All validation happens in ``__post_init__``; constructing an instance
    with an invalid combination raises ``ValueError`` (or
    ``FileNotFoundError`` when ``input_path`` does not exist).
    """

    # Input file
    input_path: Optional[str] = None
    band_index: int = 0  # zero-based band index

    # Output: at least one of the two must be given
    output_path: Optional[str] = None
    output_dir: Optional[str] = None

    # Segmentation method
    method: str = 'otsu'
    threshold_value: Optional[float] = None  # only used by the 'fixed' method

    # Batch processing: methods and bands are paired element-wise
    batch_methods: List[str] = field(default_factory=lambda: ['otsu'])
    batch_bands: List[int] = field(default_factory=lambda: [0])

    # Adaptive threshold parameters
    adaptive_block_size: int = 11  # must be odd and greater than 1
    adaptive_c: float = 2.0
    adaptive_method: str = 'gaussian'  # 'gaussian' or 'mean'

    # Iterative / ISODATA parameters
    max_iterations: int = 100
    convergence_threshold: float = 0.01

    # Histogram bimodal parameters
    histogram_bins: int = 256
    peak_min_distance: int = 10

    # Whether data outside 0-255 is min-max normalised to 0-255 before
    # segmentation (when False such data is clipped instead).
    otsu_normalize: bool = True

    def __post_init__(self):
        """Validate the parameters and broadcast the batch lists.

        Raises:
            ValueError: On a missing path, unsupported method or invalid
                numeric parameter.
            FileNotFoundError: If ``input_path`` does not exist.
        """
        # Required input path
        if not self.input_path:
            raise ValueError("必须指定输入文件路径(input_path)")

        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"输入文件不存在: {self.input_path}")

        # Output location
        if not self.output_path and not self.output_dir:
            raise ValueError("必须指定输出路径(output_path)或输出目录(output_dir)")

        # Segmentation method
        supported_methods = self._get_supported_methods()
        if self.method not in supported_methods:
            raise ValueError(
                f"不支持的分割方法: {self.method}。支持的方法: {list(supported_methods.keys())}"
            )

        # Batch parameters: a single method is broadcast over several bands
        # and a single band over several methods; any other length mismatch
        # is an error.
        if len(self.batch_methods) != len(self.batch_bands):
            if len(self.batch_methods) == 1 and len(self.batch_bands) > 1:
                self.batch_methods = self.batch_methods * len(self.batch_bands)
            elif len(self.batch_bands) == 1 and len(self.batch_methods) > 1:
                self.batch_bands = self.batch_bands * len(self.batch_methods)
            else:
                raise ValueError("batch_methods和batch_bands长度不匹配,请确保其中一个长度为1,或两者长度相等")

        for method in self.batch_methods:
            if method not in supported_methods:
                raise ValueError(f"批量处理中包含不支持的方法: {method}")

        # Adaptive threshold parameters
        if self.adaptive_block_size % 2 == 0:
            raise ValueError("adaptive_block_size必须是奇数")
        # OpenCV additionally requires blockSize > 1; negative odd numbers
        # also pass the parity check above, so reject them here.
        if self.adaptive_block_size < 3:
            raise ValueError("adaptive_block_size必须大于1")

        if self.adaptive_method.lower() not in ['gaussian', 'mean']:
            raise ValueError("adaptive_method必须是'gaussian'或'mean'")

        # Histogram parameters
        if self.histogram_bins <= 0:
            raise ValueError("histogram_bins必须大于0")

        if self.peak_min_distance <= 0:
            raise ValueError("peak_min_distance必须大于0")

        # Iteration parameters
        if self.max_iterations <= 0:
            raise ValueError("max_iterations必须大于0")

        if self.convergence_threshold <= 0:
            raise ValueError("convergence_threshold必须大于0")

    def _get_supported_methods(self) -> Dict[str, str]:
        """Return the supported method keys mapped to their display names."""
        methods = {
            'fixed': '固定阈值分割',
            'histogram_bimodal': '直方图双峰法',
            'adaptive': '自适应阈值分割',
            'iterative': '迭代法阈值分割',
            'otsu': 'Otsu大津法分割',
            'isodata': 'ISODATA分割'
        }
        return methods


class ThresholdSegmenter:
    """Threshold segmentation pipeline: load, segment and save one band."""

    def __init__(self, config: ThresholdSegmentationConfig):
        self.config = config

        # Data storage, filled by load_data()
        self.data = None            # 2-D array, pixels x bands
        self.selected_band = None   # 2-D array, lines x samples
        self.original_shape = None
        self.data_type = None       # 'hyperspectral' or 'rgb'

    def load_data(self) -> None:
        """Load ``config.input_path`` and extract ``config.band_index``.

        Raises:
            ValueError: If the file extension is not supported.
        """
        print(f"正在加载数据文件: {self.config.input_path}")

        file_ext = os.path.splitext(self.config.input_path)[1].lower()

        if file_ext in ['.dat', '.img', '.hdr']:
            self._load_hyperspectral_data(self.config.input_path)
            self.data_type = 'hyperspectral'
        elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
            self._load_rgb_data(self.config.input_path)
            self.data_type = 'rgb'
        else:
            raise ValueError(f"不支持的文件格式: {file_ext}")

        self._extract_band(self.config.band_index)

        print("数据加载完成")

    def _load_hyperspectral_data(self, file_path):
        """Load an ENVI image through its ``.hdr`` header.

        The header is looked up as ``<stem>.hdr`` first and, if that file
        does not exist, as ``<file_path>.hdr`` (the other common ENVI naming
        convention).
        """
        envi = _require_module(spectral, 'spectral')

        base_path = os.path.splitext(file_path)[0]
        hdr_file = base_path + '.hdr'
        alt_hdr_file = file_path + '.hdr'
        if not os.path.exists(hdr_file) and os.path.exists(alt_hdr_file):
            hdr_file = alt_hdr_file

        img = envi.open_image(hdr_file)
        self.data = img.load()
        self.original_shape = self.data.shape

        # Flatten to a 2-D array (pixels x bands)
        self.data = self.data.reshape(-1, self.data.shape[2])

    def _load_rgb_data(self, file_path):
        """Load an ordinary image file and reduce it to one grey band.

        Raises:
            ValueError: If OpenCV cannot read the file.
        """
        cv = _require_module(cv2, 'opencv-python (cv2)')

        img = cv.imread(file_path, cv.IMREAD_UNCHANGED)
        if img is None:
            raise ValueError(f"无法读取图像文件: {file_path}")

        # IMREAD_UNCHANGED keeps an alpha channel if present, so the
        # conversion code must match the channel count.
        if len(img.shape) == 3:
            channels = img.shape[2]
            if channels == 4:
                img = cv.cvtColor(img, cv.COLOR_BGRA2GRAY)
            elif channels == 1:
                img = img[:, :, 0]
            else:
                img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)

        self.original_shape = img.shape
        self.data = img.reshape(-1, 1)  # single-band data

    def _extract_band(self, band_index):
        """Store band ``band_index`` of the loaded data in ``selected_band``.

        For 'rgb' data there is only one band; a non-zero index is ignored
        with a warning.

        Raises:
            ValueError: If ``band_index`` is beyond the last band.
        """
        if self.data_type == 'hyperspectral':
            if band_index >= self.data.shape[1]:
                raise ValueError(f"波段索引超出范围: {band_index}, 总波段数: {self.data.shape[1]}")
            self.selected_band = self.data[:, band_index].reshape(
                self.original_shape[0], self.original_shape[1]
            )
        else:
            if band_index != 0:
                print(f"警告: RGB图像只有一个波段,忽略波段索引 {band_index}")
            self.selected_band = self.data.reshape(self.original_shape)

    def _preprocess_image_for_segmentation(self, image):
        """Convert ``image`` to ``uint8`` in the 0-255 range.

        * values within [0, 1] are scaled by 255;
        * values already within [0, 255] are cast unchanged;
        * anything else (e.g. raw reflectance counts, negative values) is
          min-max normalised to 0-255 when ``config.otsu_normalize`` is
          true, otherwise clipped to 0-255. A plain cast would wrap around.
        """
        image_min, image_max = np.min(image), np.max(image)

        if image_max <= 1.0 and image_min >= 0.0:
            processed_image = (image * 255).astype(np.uint8)
            print("检测到0-1范围的像素值,已转换为0-255范围进行分割")
            return processed_image

        if image.dtype == np.uint8:
            return image

        if image_min >= 0 and image_max <= 255:
            return image.astype(np.uint8)

        as_float = np.asarray(image, dtype=np.float64)
        if getattr(self.config, 'otsu_normalize', True):
            span = float(image_max) - float(image_min)
            if span == 0:
                # Constant image: nothing to stretch.
                return np.zeros(as_float.shape, dtype=np.uint8)
            scaled = (as_float - float(image_min)) / span * 255.0
            return scaled.astype(np.uint8)
        return np.clip(as_float, 0, 255).astype(np.uint8)

    def apply_segmentation(self) -> Tuple[np.ndarray, float]:
        """Segment ``selected_band`` with ``config.method``.

        The band is first converted to 0-255 ``uint8``; thresholds
        (including ``config.threshold_value`` for 'fixed') are therefore on
        that scale, except for 'adaptive', which reports a 0-1 reference
        value.

        Returns:
            Tuple[np.ndarray, float]: (binary mask of 0/1, threshold used)

        Raises:
            ValueError: If no band is loaded or the method is unknown.
        """
        print(f"正在应用分割方法: {self.config.method}")

        if self.selected_band is None:
            raise ValueError("请先加载数据")

        processed_band = self._preprocess_image_for_segmentation(self.selected_band)

        method = self.config.method
        if method == 'fixed':
            threshold = self.config.threshold_value
            if threshold is None:
                # No threshold given: fall back to the image median.
                threshold = np.median(processed_band)
                print(f"使用默认阈值: {threshold}")
            result = self._fixed_threshold(processed_band, threshold)
        else:
            automatic_methods = {
                'histogram_bimodal': self._histogram_bimodal_threshold,
                'adaptive': self._adaptive_threshold,
                'iterative': self._iterative_threshold,
                'otsu': self._otsu_threshold,
                'isodata': self._isodata_threshold,
            }
            if method not in automatic_methods:
                raise ValueError(f"不支持的分割方法: {self.config.method}")
            result, threshold = automatic_methods[method](processed_band)

        return result.astype(np.uint8), threshold

    def _fixed_threshold(self, image, threshold):
        """Return a 0/1 mask of the pixels strictly above ``threshold``."""
        return (image > threshold).astype(np.uint8)

    def _histogram_bimodal_threshold(self, image):
        """Threshold halfway between the two histogram peaks.

        The peaks are found as bin indices and converted to intensities via
        the bin centres before averaging, so the threshold is expressed in
        the same units as ``image``. Falls back to the median when fewer
        than two peaks are found.

        Returns:
            (mask, threshold)
        """
        hist, bin_edges = np.histogram(
            image.ravel(),
            bins=self.config.histogram_bins,
            range=(float(np.min(image)), float(np.max(image))),
        )

        peaks = self._find_histogram_peaks(hist)

        if len(peaks) < 2:
            threshold = np.median(image)
            print("未找到明显的双峰,使用中位数作为阈值")
        else:
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
            threshold = (bin_centers[peaks[0]] + bin_centers[peaks[1]]) / 2.0

        result = (image > threshold).astype(np.uint8)
        return result, threshold

    def _find_histogram_peaks(self, hist):
        """Return the bin indices of the first (at most) two peaks.

        A peak is an interior bin strictly higher than both neighbours and
        higher than half the mean bin count. Peaks closer than
        ``config.peak_min_distance`` bins to the previous peak are merged,
        keeping the taller one.
        """
        min_height = np.mean(hist) * 0.5  # loop invariant
        peaks = []
        for i in range(1, len(hist) - 1):
            is_local_max = hist[i] > hist[i - 1] and hist[i] > hist[i + 1]
            if not (is_local_max and hist[i] > min_height):
                continue
            if not peaks or abs(i - peaks[-1]) > self.config.peak_min_distance:
                peaks.append(i)
            elif hist[i] > hist[peaks[-1]]:
                peaks[-1] = i

        return peaks[:2]

    def _adaptive_threshold(self, image):
        """Locally adaptive threshold via ``cv2.adaptiveThreshold``.

        Returns:
            (mask, reference) where ``reference`` is the image mean divided
            by 255. It is informational only: the real threshold varies per
            pixel.

        Raises:
            ValueError: If ``config.adaptive_method`` is unknown.
        """
        cv = _require_module(cv2, 'opencv-python (cv2)')

        adaptive_method = self.config.adaptive_method.lower()
        if adaptive_method == 'gaussian':
            adaptive_type = cv.ADAPTIVE_THRESH_GAUSSIAN_C
        elif adaptive_method == 'mean':
            adaptive_type = cv.ADAPTIVE_THRESH_MEAN_C
        else:
            raise ValueError(f"不支持的自适应方法: {self.config.adaptive_method}")

        # The image was already converted to uint8 by the preprocessing step.
        binary = cv.adaptiveThreshold(
            image, 255,
            adaptive_type,
            cv.THRESH_BINARY,
            self.config.adaptive_block_size,
            self.config.adaptive_c
        )

        result = (binary > 0).astype(np.uint8)
        threshold = np.mean(image) / 255.0  # 0-1 reference value

        return result, threshold

    def _mean_split_threshold(self, image):
        """Iterate ``t = (mean(above t) + mean(at or below t)) / 2``.

        Starts from the image mean and stops after
        ``config.max_iterations`` rounds, when one side becomes empty, or
        when the update is smaller than ``config.convergence_threshold``.

        Returns:
            (mask, threshold)
        """
        threshold = np.mean(image)

        for _ in range(self.config.max_iterations):
            foreground = image[image > threshold]
            background = image[image <= threshold]

            if len(foreground) == 0 or len(background) == 0:
                break

            new_threshold = (np.mean(foreground) + np.mean(background)) / 2

            if abs(new_threshold - threshold) < self.config.convergence_threshold:
                break

            threshold = new_threshold

        result = (image > threshold).astype(np.uint8)
        return result, threshold

    def _iterative_threshold(self, image):
        """Iterative mean-split threshold; returns (mask, threshold)."""
        return self._mean_split_threshold(image)

    def _otsu_threshold(self, image):
        """Otsu threshold via ``cv2.threshold``; returns (mask, threshold)."""
        cv = _require_module(cv2, 'opencv-python (cv2)')

        # The image was already converted to uint8 by the preprocessing step.
        threshold, binary = cv.threshold(
            image, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU
        )
        result = (binary > 0).astype(np.uint8)

        return result, threshold

    def _isodata_threshold(self, image):
        """ISODATA threshold; returns (mask, threshold).

        Uses the same mean-split iteration as ``_iterative_threshold``.
        """
        return self._mean_split_threshold(image)

    def _generate_output_filename(self, method, band_idx=None):
        """Build the output ``.dat`` file name.

        Args:
            method: Segmentation method name.
            band_idx: Band index; defaults to ``config.band_index``.

        Returns:
            str: ``<input stem>_<method>_band<idx>.dat`` for hyperspectral
            data, ``<input stem>_<method>.dat`` otherwise.
        """
        input_basename = os.path.splitext(os.path.basename(self.config.input_path))[0]

        if band_idx is None:
            band_idx = getattr(self.config, 'band_index', None)

        if self.data_type == 'hyperspectral' and band_idx is not None:
            filename = f'{input_basename}_{method}_band{band_idx}.dat'
        else:
            filename = f'{input_basename}_{method}.dat'

        return filename

    def save_results(self, segmented_data, threshold, output_path, method):
        """Write the mask as a raw ``.dat`` file plus an ENVI ``.hdr``.

        Args:
            segmented_data: 2-D mask to save.
            threshold: Threshold that produced the mask (written to the
                header description).
            output_path: Target file, or an existing directory in which a
                name derived from the input file is generated.
            method: Segmentation method name.
        """
        if os.path.isdir(output_path):
            filename = self._generate_output_filename(method)
            output_path = os.path.join(output_path, filename)

        # Make sure the target directory exists before opening the file.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        output_base = os.path.splitext(output_path)[0]

        self._save_dat_file(segmented_data, output_path)

        hdr_path = output_base + '.hdr'
        self._save_hdr_file(hdr_path, segmented_data.shape, method, threshold)

        print(f"分割结果已保存: {output_path}")
        print(f"头文件已保存: {hdr_path}")
        print(f"使用的阈值: {threshold:.4f}")

    def _save_dat_file(self, data, file_path):
        """Write ``data`` as raw ``uint8`` bytes."""
        with open(file_path, 'wb') as f:
            data.astype(np.uint8).tofile(f)

    def _save_hdr_file(self, hdr_path, data_shape, method, threshold):
        """Write the ENVI header describing a single-band byte image."""
        header_content = f"""ENVI
description = {{{method} Segmentation Result [Threshold: {threshold:.4f}]}}
samples = {data_shape[1]}
lines = {data_shape[0]}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 1
interleave = bsq
byte order = 0
band names = {{Segmented_Result}}
classes = 2
class names = {{Background, Foreground}}
class lookup = {{0,0,0, 255,255,255}}
"""
        with open(hdr_path, 'w', encoding='utf-8') as f:
            f.write(header_content)

    def batch_process(self, methods, bands, output_dir):
        """Run several (method, band) pairs and save every result.

        A failing pair is reported and skipped; ``config.method`` and
        ``config.band_index`` are always restored afterwards.

        Args:
            methods: Segmentation method names.
            bands: Band indices, paired element-wise with ``methods``.
            output_dir: Directory for the output files (created if needed).

        Returns:
            dict: One entry per successful pair with the keys ``'data'``,
            ``'threshold'`` and ``'band'``. The entry key is the method
            name; a method that occurs more than once in ``methods`` is
            keyed ``'<method>_band<idx>'`` so that its runs do not
            overwrite each other.
        """
        os.makedirs(output_dir, exist_ok=True)

        methods = list(methods)
        bands = list(bands)
        repeated_methods = {m for m in methods if methods.count(m) > 1}

        # Data summary
        if self.data is not None:
            print("\n=== 数据信息 ===")
            print(f"数据类型: {self.data_type}")
            if self.data_type == 'hyperspectral':
                print(f"图像尺寸: {self.original_shape[0]}x{self.original_shape[1]}")
                print(f"波段数: {self.original_shape[2]}")
            else:
                print(f"图像尺寸: {self.original_shape[0]}x{self.original_shape[1]}")

        results = {}

        for method, band_idx in zip(methods, bands):
            print(f"\n正在处理: {method} (波段: {band_idx})")

            try:
                # Temporarily point the config at the current pair.
                original_method = self.config.method
                original_band = self.config.band_index
                self.config.method = method
                self.config.band_index = band_idx
                try:
                    self._extract_band(band_idx)
                    segmented_data, threshold = self.apply_segmentation()
                finally:
                    self.config.method = original_method
                    self.config.band_index = original_band

                # Output name: input stem plus method (and band) name.
                filename = self._generate_output_filename(method, band_idx)
                output_path = os.path.join(output_dir, filename)
                self.save_results(segmented_data, threshold, output_path, method)

                if method in repeated_methods:
                    result_key = f"{method}_band{band_idx}"
                else:
                    result_key = method
                results[result_key] = {
                    'data': segmented_data,
                    'threshold': threshold,
                    'band': band_idx
                }

            except Exception as e:
                # Best effort: report the failure and continue with the rest.
                print(f"处理 {method} 时出错: {str(e)}")

        return results

    @staticmethod
    def print_method_info():
        """Print a description of every available segmentation method."""
        method_info = {
            'fixed': {
                'name': '固定阈值分割 (Fixed Threshold)',
                'description': '使用预设的固定阈值进行分割',
                'advantages': ['简单快速', '可控性强', '适合已知阈值的情况'],
                'limitations': ['需要预知合适阈值', '对图像变化敏感'],
                'best_for': ['对比度高的图像', '已知目标特征的分割']
            },
            'histogram_bimodal': {
                'name': '直方图双峰法 (Histogram Bimodal)',
                'description': '基于图像直方图的双峰分布寻找最优阈值',
                'advantages': ['自动寻找阈值', '适合双峰分布图像'],
                'limitations': ['仅适用于双峰直方图', '对噪声敏感'],
                'best_for': ['具有明显双峰分布的图像']
            },
            'adaptive': {
                'name': '自适应阈值分割 (Adaptive Threshold)',
                'description': '根据图像局部区域的统计特性动态调整阈值',
                'advantages': ['适应图像局部变化', '适合不均匀照明'],
                'limitations': ['计算复杂度较高', '参数选择重要'],
                'best_for': ['光照不均匀的图像', '复杂背景']
            },
            'iterative': {
                'name': '迭代法阈值分割 (Iterative Threshold)',
                'description': '通过迭代优化找到前景和背景的最佳分离阈值',
                'advantages': ['自动收敛', '理论基础扎实'],
                'limitations': ['迭代计算', '可能收敛到局部最优'],
                'best_for': ['一般图像分割', '前景背景区分明显']
            },
            'otsu': {
                'name': 'Otsu大津法 (Otsu Method)',
                'description': '基于最大类间方差准则寻找最优阈值',
                'advantages': ['理论严谨', '自动寻找最优阈值', '应用广泛'],
                'limitations': ['假设双峰分布', '对噪声敏感'],
                'best_for': ['大多数图像分割任务', '学术研究']
            },
            'isodata': {
                'name': 'ISODATA分割 (ISODATA)',
                'description': '基于迭代自组织数据分析的阈值选择方法',
                'advantages': ['适应性强', '自动聚类'],
                'limitations': ['计算复杂度较高', '参数敏感'],
                'best_for': ['复杂图像', '多峰分布数据']
            }
        }

        print("\n=== 阈值分割方法详细说明 ===\n")
        for method, info in method_info.items():
            print(f"{method} - {info['name']}")
            print(f" 描述: {info['description']}")
            print(f" 优势: {', '.join(info['advantages'])}")
            print(f" 局限性: {', '.join(info['limitations'])}")
            print(f" 适用场景: {', '.join(info['best_for'])}")
            print()


# Usage example
def main():
    """Print the method overview and run the hyperspectral batch example."""
    print("=== 阈值分割方法介绍 ===")
    ThresholdSegmenter.print_method_info()

    # Example 1: hyperspectral image
    print("=== 处理高光谱图像数据示例 ===")
    try:
        hyperspectral_config = ThresholdSegmentationConfig(
            input_path=r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip.hdr",
            band_index=50,  # the 51st band
            method='otsu',  # default method
            output_dir='segmentation_results_hyperspectral'
        )
        hyperspectral_segmenter = ThresholdSegmenter(hyperspectral_config)

        hyperspectral_segmenter.load_data()

        # Apply several methods to the same band
        methods = ['otsu', 'isodata', 'adaptive', 'iterative', 'fixed', 'histogram_bimodal']
        bands = [25, 25, 25, 25, 25, 25]
        results = hyperspectral_segmenter.batch_process(
            methods, bands, 'segmentation_results_hyperspectral'
        )

    except FileNotFoundError:
        print("高光谱文件未找到,跳过高光谱示例...")

    # # Example 2: RGB image
    # print("\n=== 处理RGB图像示例 ===")
    # try:
    #     rgb_config = ThresholdSegmentationConfig(
    #         input_path=r"example_image.jpg",  # replace with a real image path
    #         band_index=0,  # RGB images have a single band here
    #         method='otsu',
    #         output_dir='segmentation_results_rgb'
    #     )
    #     rgb_segmenter = ThresholdSegmenter(rgb_config)
    #
    #     rgb_segmenter.load_data()
    #
    #     methods = ['otsu', 'adaptive', 'iterative']
    #     bands = [0, 0, 0]
    #     results = rgb_segmenter.batch_process(methods, bands, 'segmentation_results_rgb')
    #
    # except FileNotFoundError:
    #     print("RGB图像文件未找到,跳过RGB示例...")
    #
    # print("\n处理完成!")


if __name__ == "__main__":
    main()