"""Grey-scale morphological processing of single bands of ENVI hyperspectral cubes.

The module provides :class:`HyperspectralMorphologyProcessor` (load a cube,
extract a band, apply a morphological operation, save the result as an ENVI
``.dat``/``.hdr`` pair, plot a comparison) plus a small command-line front end.
"""

import os
import re
import struct
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from scipy import ndimage

# The third-party packages below are only needed by part of the API (file
# loading, plotting).  Guarding the imports keeps the pure-numpy operations
# usable when one of them is not installed; the methods that need a missing
# package raise ImportError when they are called.
try:
    import spectral as spy
    from spectral import envi
except ImportError:
    spy = None
    envi = None

try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

try:
    from skimage import morphology
except ImportError:
    morphology = None

# NOTE: this silences every warning process-wide; kept because existing users
# of the module may rely on the quiet output.
warnings.filterwarnings('ignore')

# ENVI "data type" codes keyed by numpy dtype kind + item size.
_ENVI_DTYPE_CODES = {
    'u1': 1,    # unsigned 8-bit integer
    'i2': 2,    # signed 16-bit integer
    'i4': 3,    # signed 32-bit integer
    'f4': 4,    # 32-bit float
    'f8': 5,    # 64-bit float
    'c8': 6,    # complex, 2 x 32-bit float
    'c16': 9,   # complex, 2 x 64-bit float
    'u2': 12,   # unsigned 16-bit integer
    'u4': 13,   # unsigned 32-bit integer
    'i8': 14,   # signed 64-bit integer
    'u8': 15,   # unsigned 64-bit integer
}

# Matches an "interleave = xxx" line of an ENVI header, tolerant of spacing/case.
_INTERLEAVE_RE = re.compile(r'^\s*interleave\s*=\s*(bil|bip|bsq)\b',
                            re.IGNORECASE | re.MULTILINE)

_MORPH_OPERATIONS = ('dilation', 'erosion', 'opening', 'closing',
                     'gradient', 'tophat', 'bottomhat', 'reconstruction')


def _envi_data_type(dtype) -> Optional[int]:
    """Return the ENVI data-type code for *dtype*, or None if ENVI has none."""
    dt = np.dtype(dtype)
    return _ENVI_DTYPE_CODES.get(f"{dt.kind}{dt.itemsize}")


def _envi_byte_order(dtype) -> int:
    """Return the ENVI byte-order flag (0 little-endian, 1 big-endian) of *dtype*."""
    order = np.dtype(dtype).byteorder
    if order == '>':
        return 1
    if order == '<':
        return 0
    # '=' (native) or '|' (not applicable): fall back to the machine order.
    return 1 if sys.byteorder == 'big' else 0


class HyperspectralMorphologyProcessor:
    """Apply grey-scale morphology to single bands of a hyperspectral image."""

    def __init__(self):
        self.data = None          # loaded cube, (rows, cols, bands)
        self.header = None        # ENVI header metadata of the loaded cube
        self.wavelengths = None   # band centre wavelengths, if the header has them
        self.shape = None         # shape of ``data``
        self.dtype = None         # dtype of ``data``

    def load_hyperspectral_image(self, file_path: str,
                                 format_type: Optional[str] = None
                                 ) -> Tuple[np.ndarray, Dict]:
        """Load an ENVI hyperspectral image.

        Parameters
        ----------
        file_path : str
            Path of the header (``.hdr``) or of the data file
            (``.bil``/``.bip``/``.bsq``/``.dat``).  For a data file the header
            is searched next to it.
        format_type : str, optional
            ``'bil'``, ``'bip'``, ``'bsq'`` or ``'dat'``.  Detected from the
            file when None.  It is only reported; the on-disk interleave is
            taken from the header by ``spectral``.

        Returns
        -------
        data : np.ndarray
            Data cube.  It is left in the (rows, cols, bands) layout returned
            by ``spectral``, whatever the on-disk interleave, because
            ``extract_band`` indexes the band on the last axis.
        header : Dict
            Header metadata.

        Raises
        ------
        ImportError
            If the ``spectral`` package is not installed.
        IOError
            If the image or its header cannot be found or read.
        """
        if envi is None:
            raise ImportError("未安装 spectral 库,无法加载ENVI图像")

        try:
            file_path = str(file_path)
            if format_type is None:
                format_type = self._detect_format(file_path)

            if file_path.lower().endswith('.hdr'):
                img = envi.open(file_path)
            else:
                hdr_file = self._find_header_file(file_path)
                if hdr_file is None:
                    raise FileNotFoundError(f"未找到头文件(.hdr)用于 {file_path}")
                # Name the data file explicitly instead of letting spectral
                # guess it from the header name.
                img = envi.open(hdr_file, image=file_path)

            data = img.load()
            header = dict(img.metadata)

            self.data = data
            self.header = header
            self.shape = data.shape
            self.dtype = data.dtype

            # A missing or non-numeric wavelength list must not abort loading.
            self.wavelengths = None
            if 'wavelength' in header:
                try:
                    self.wavelengths = np.array(
                        [float(w) for w in header['wavelength']])
                except (TypeError, ValueError):
                    self.wavelengths = None

            print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}, 格式={format_type}")
            if self.wavelengths is not None and self.wavelengths.size:
                print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")

            return data, header
        except Exception as e:
            raise IOError(f"加载图像失败: {e}") from e

    def _find_header_file(self, file_path: str) -> Optional[str]:
        """Return the first existing ``.hdr`` candidate for a data file, or None."""
        candidates = [
            os.path.splitext(file_path)[0] + '.hdr',
            file_path + '.hdr',
        ]
        if len(file_path) > 4:
            # Historical candidate: drop the last four characters.
            candidates.append(file_path[:-4] + '.hdr')
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        return None

    def _detect_format(self, file_path: str) -> str:
        """Guess the interleave of an image.

        A ``.hdr`` file is searched for its ``interleave`` entry; a data file
        is classified by its extension.  ``'bip'`` is returned when nothing
        can be determined.
        """
        ext = Path(file_path).suffix.lower()

        if ext == '.hdr':
            with open(file_path, 'r', errors='replace') as f:
                found = _INTERLEAVE_RE.search(f.read())
            return found.group(1).lower() if found else 'bip'
        if ext in ('.bil', '.bip', '.bsq', '.dat'):
            return ext[1:]  # extension without the dot
        return 'bip'

    def _convert_to_bil(self, data: np.ndarray) -> np.ndarray:
        """Return a (rows, bands, cols) copy of a (rows, cols, bands) cube."""
        return np.ascontiguousarray(np.transpose(np.asarray(data), (0, 2, 1)))

    def _convert_to_bsq(self, data: np.ndarray) -> np.ndarray:
        """Return a (bands, rows, cols) copy of a (rows, cols, bands) cube."""
        return np.ascontiguousarray(np.transpose(np.asarray(data), (2, 0, 1)))

    def extract_band(self, band_index: int) -> np.ndarray:
        """Return one band of the loaded image as a 2-D array.

        Parameters
        ----------
        band_index : int
            Zero-based band index.  A 2-D (single-band) image only has band 0.

        Raises
        ------
        ValueError
            If no image is loaded, the index is out of range, or the loaded
            data is neither 2-D nor 3-D.
        """
        if self.data is None:
            raise ValueError("请先加载图像数据")

        # Work on a plain ndarray view so that indexing has standard numpy
        # semantics even if ``data`` is an ndarray subclass.
        cube = np.asarray(self.data)
        if cube.ndim not in (2, 3):
            raise ValueError(f"不支持的数据形状: {cube.shape}")

        n_bands = cube.shape[2] if cube.ndim == 3 else 1
        if band_index < 0 or band_index >= n_bands:
            raise ValueError(f"波段索引 {band_index} 超出范围 [0, {n_bands-1}]")

        if cube.ndim == 3:
            print(f"提取前形状: {cube.shape}, band_index: {band_index}")
            # Integer indexing on the last axis always yields (rows, cols),
            # including for images with a single row or column.
            band_data = cube[:, :, band_index]
            print(f"提取后初始形状: {band_data.shape}")
        else:
            band_data = cube  # already a single band

        print(f"提取波段 {band_index}: 最终形状={band_data.shape}")
        return band_data

    def apply_morphology_operation(self,
                                   band_data: np.ndarray,
                                   operation: str = 'dilation',
                                   se_shape: str = 'disk',
                                   se_size: int = 3,
                                   **kwargs) -> np.ndarray:
        """Apply a grey-scale morphological operation to a 2-D band.

        Parameters
        ----------
        band_data : np.ndarray
            Single-band image (2-D).
        operation : str
            ``'dilation'``, ``'erosion'``, ``'opening'``, ``'closing'``,
            ``'gradient'`` (dilation - erosion), ``'tophat'`` (image - opening),
            ``'bottomhat'`` (closing - image) or ``'reconstruction'``
            (reconstruction by dilation of the eroded image under the image).
        se_shape : str
            Structuring element shape: ``'disk'``, ``'square'``,
            ``'rectangle'`` or ``'diamond'``.
        se_size : int
            Structuring element size (see ``_create_structuring_element``).
        **kwargs
            Accepted for backward compatibility and ignored.

        Returns
        -------
        np.ndarray
            ``gradient``/``tophat``/``bottomhat`` return the floating-point
            difference image; every other operation returns an array of the
            input dtype.

        Raises
        ------
        ValueError
            If the input is not 2-D or an option is not supported.
        """
        band = np.asarray(band_data)
        if band.ndim != 2:
            raise ValueError(f"输入数据必须是2D的,当前形状: {band.shape}")
        if operation not in _MORPH_OPERATIONS:
            raise ValueError(f"不支持的形态学操作: {operation}")

        selem = self._create_structuring_element(se_shape, se_size)

        # float32 is enough for 8/16-bit data, but 32/64-bit integers and
        # float64 need float64 to survive the round trip unchanged.
        work = band.astype(np.result_type(band.dtype, np.float32))

        def erode(img):
            return ndimage.grey_erosion(img, footprint=selem)

        def dilate(img):
            return ndimage.grey_dilation(img, footprint=selem)

        if operation == 'gradient':
            return dilate(work) - erode(work)
        if operation == 'tophat':
            return work - dilate(erode(work))
        if operation == 'bottomhat':
            return erode(dilate(work)) - work

        if operation == 'dilation':
            result = dilate(work)
        elif operation == 'erosion':
            result = erode(work)
        elif operation == 'opening':
            result = dilate(erode(work))
        elif operation == 'closing':
            result = erode(dilate(work))
        else:  # 'reconstruction'
            result = self._morphological_reconstruction(erode(work), work, selem)

        return self._convert_to_original_type(result, band.dtype)

    def _create_structuring_element(self, shape: str, size: int) -> np.ndarray:
        """Build a boolean structuring element.

        ``size`` is interpreted per shape: the diameter of a ``'disk'``
        (radius ``size // 2``), the side of a ``'square'``, the height of a
        ``'rectangle'`` (whose width is ``2 * size``) and the *radius* of a
        ``'diamond'`` (side ``2 * size + 1``).

        Raises
        ------
        ValueError
            If ``size`` is smaller than 1 or ``shape`` is not supported.
        """
        if size < 1:
            raise ValueError(f"结构元素大小必须为正整数: {size}")

        if shape == 'disk':
            radius = size // 2
            y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
            return x * x + y * y <= radius * radius
        if shape == 'square':
            return np.ones((size, size), dtype=bool)
        if shape == 'rectangle':
            return np.ones((size, size * 2), dtype=bool)
        if shape == 'diamond':
            # City-block ball of radius ``size``.
            y, x = np.ogrid[-size:size + 1, -size:size + 1]
            return np.abs(x) + np.abs(y) <= size
        raise ValueError(f"不支持的结构元素形状: {shape}")

    def _morphological_reconstruction(self,
                                      marker: np.ndarray,
                                      mask: np.ndarray,
                                      selem: np.ndarray) -> np.ndarray:
        """Reconstruct ``mask`` from ``marker`` by iterated geodesic dilation.

        The marker is clipped to the mask, then repeatedly dilated with
        ``selem`` and clipped again until it stops changing.
        """
        marker = np.minimum(marker, mask)

        # Each productive pass changes at least one pixel, so the pixel count
        # bounds the number of passes; the bound also guarantees termination
        # on pathological input.
        for _ in range(marker.size + 1):
            grown = np.minimum(ndimage.grey_dilation(marker, footprint=selem), mask)
            # equal_nan: NaN pixels never compare equal and would otherwise
            # keep the loop running.
            if np.array_equal(grown, marker, equal_nan=True):
                break
            marker = grown

        return marker

    def _convert_to_original_type(self, data: np.ndarray,
                                  original_dtype: np.dtype) -> np.ndarray:
        """Cast ``data`` to ``original_dtype``; integers are rounded and clipped."""
        if np.issubdtype(original_dtype, np.integer):
            limits = np.iinfo(original_dtype)
            return np.clip(np.rint(data), limits.min, limits.max).astype(original_dtype)
        return data.astype(original_dtype)

    def apply_to_multiple_bands(self,
                                band_indices: List[int],
                                operation: str,
                                se_shape: str = 'disk',
                                se_size: int = 3) -> Dict[int, np.ndarray]:
        """Apply one morphological operation to several bands.

        Returns
        -------
        Dict[int, np.ndarray]
            Processed band for every requested band index.
        """
        results = {}

        for band_idx in band_indices:
            print(f"处理波段 {band_idx}...")
            band_data = self.extract_band(band_idx)
            if band_data.ndim != 2:
                raise ValueError(f"波段 {band_idx} 数据不是2D的,形状: {band_data.shape}")
            results[band_idx] = self.apply_morphology_operation(
                band_data, operation, se_shape, se_size)

        return results

    def save_as_envi(self,
                     data: np.ndarray,
                     output_path: str,
                     description: str = "形态学处理结果") -> None:
        """Write ``data`` as a raw ENVI ``.dat`` file with a matching ``.hdr``.

        Parameters
        ----------
        data : np.ndarray
            Array to save (normally a 2-D single band).  A dtype without an
            ENVI type code is written as float32.
        output_path : str
            Output path; ``.dat`` is appended unless already present.
        description : str
            Text stored in the header's ``description`` field.

        Raises
        ------
        IOError
            If the files cannot be written.
        """
        try:
            array = np.asarray(data)
            if _envi_data_type(array.dtype) is None:
                # Keep the header and the bytes on disk in agreement.
                array = array.astype(np.float32)

            # dirname is '' for a bare file name, and makedirs('') raises.
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

            if output_path.lower().endswith('.dat'):
                dat_path = output_path
            else:
                dat_path = output_path + '.dat'

            array.tofile(dat_path)

            hdr_path = dat_path.rsplit('.', 1)[0] + '.hdr'
            self._create_morphology_hdr_file(hdr_path, array.shape,
                                             description, array.dtype)

            print("ENVI格式已保存:")
            print(f"  数据文件: {dat_path}")
            print(f"  头文件: {hdr_path}")
            print(f"  数据形状: {array.shape}, 数据类型: {array.dtype}")
        except Exception as e:
            raise IOError(f"保存ENVI文件失败: {e}") from e

    def _create_morphology_hdr_file(self,
                                    hdr_path: str,
                                    data_shape: Tuple,
                                    description: str,
                                    data_dtype: np.dtype = None) -> None:
        """Write the ENVI header describing a saved result.

        Image size, data type and byte order describe the array that was
        written (``data_shape``/``data_dtype``), not the source image: the
        result can differ from the source in dtype and is always written by
        numpy in the array's own byte order.

        Parameters
        ----------
        hdr_path : str
            Header file path.
        data_shape : Tuple
            Shape of the saved array: (lines, samples), (lines, samples, bands)
            or (samples,).
        description : str
            Text stored in the ``description`` field.
        data_dtype : np.dtype, optional
            dtype of the saved array; float32 when None.

        Raises
        ------
        IOError
            If the header cannot be written.
        """
        try:
            if len(data_shape) >= 2:
                lines, samples = data_shape[0], data_shape[1]
            else:
                lines, samples = 1, data_shape[0]
            bands = data_shape[2] if len(data_shape) >= 3 else 1
            # A C-ordered (lines, samples, bands) array is band-interleaved by pixel.
            interleave = 'bip' if bands > 1 else 'bsq'

            dtype = np.dtype(data_dtype) if data_dtype is not None else np.dtype(np.float32)
            data_type = _envi_data_type(dtype)
            if data_type is None:
                data_type = 4  # float32
            byte_order = _envi_byte_order(dtype)

            content = [
                "ENVI\n",
                "description = {\n",
                f"  {description}\n",
            ]
            if bands == 1:
                content.append("  单波段形态学处理结果\n")
            content += [
                "}\n",
                f"samples = {samples}\n",
                f"lines = {lines}\n",
                f"bands = {bands}\n",
                "header offset = 0\n",
                "file type = ENVI Standard\n",
                f"data type = {data_type}\n",
                f"interleave = {interleave}\n",
                f"byte order = {byte_order}\n",
            ]

            if self.header and 'wavelength' in self.header:
                if bands == 1:
                    # The label goes into "band names": "wavelength" must hold
                    # numbers, and readers (this module included) parse it so.
                    content += [
                        "band names = {\n",
                        "  形态学处理结果\n",
                        "}\n",
                    ]
                wavelength_units = self.header.get('wavelength units', 'nm')
                content.append(f"wavelength units = {wavelength_units}\n")

            with open(hdr_path, 'w') as f:
                f.writelines(content)
        except Exception as e:
            raise IOError(f"创建HDR文件失败: {e}") from e

    def visualize_results(self,
                          original_band: np.ndarray,
                          processed_band: np.ndarray,
                          operation_name: str,
                          save_path: Optional[str] = None) -> None:
        """Plot the original band, the processed band and their difference.

        Parameters
        ----------
        original_band : np.ndarray
            Band before processing.
        processed_band : np.ndarray
            Band after processing.
        operation_name : str
            Name shown in the titles.
        save_path : str, optional
            If given, the figure is also saved to this path.

        Raises
        ------
        ImportError
            If ``matplotlib`` is not installed.
        """
        if plt is None:
            raise ImportError("未安装 matplotlib 库,无法生成可视化结果")

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        im1 = axes[0].imshow(original_band, cmap='gray')
        axes[0].set_title('原始波段')
        axes[0].axis('off')
        plt.colorbar(im1, ax=axes[0], shrink=0.8)

        im2 = axes[1].imshow(processed_band, cmap='gray')
        axes[1].set_title(f'{operation_name} 结果')
        axes[1].axis('off')
        plt.colorbar(im2, ax=axes[1], shrink=0.8)

        diff = processed_band.astype(np.float32) - original_band.astype(np.float32)
        # Symmetric colour range around zero; fall back to 1 so that an
        # all-zero difference does not give a zero-width range.
        limit = float(np.abs(diff).max()) or 1.0
        im3 = axes[2].imshow(diff, cmap='RdBu_r', vmin=-limit, vmax=limit)
        axes[2].set_title('差异 (处理后 - 原始)')
        axes[2].axis('off')
        plt.colorbar(im3, ax=axes[2], shrink=0.8)

        plt.suptitle(f'形态学操作: {operation_name}', fontsize=16)
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"可视化结果已保存到: {save_path}")

        plt.show()
        # Release the figure; callers invoke this method in loops.
        plt.close(fig)


def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse; ``sys.argv[1:]`` when None.

    Returns
    -------
    int
        0 on success, 1 if processing failed.
    """
    import argparse

    morph_operations = list(_MORPH_OPERATIONS)
    se_shapes = ['disk', 'square', 'rectangle', 'diamond']

    parser = argparse.ArgumentParser(description='高光谱图像形态学处理工具')
    parser.add_argument('input_file', help='输入高光谱图像文件路径')
    parser.add_argument('--format', '-f', default='auto',
                        choices=['auto', 'bil', 'bip', 'bsq', 'dat'],
                        help='图像格式 (默认: auto)')
    parser.add_argument('--band', '-b', type=int, default=0,
                        help='要处理的波段索引 (默认: 0)')
    parser.add_argument('--bands', '-B', type=str, default=None,
                        help='多个波段索引,用逗号分隔,如 "0,10,20"')
    parser.add_argument('--operation', '-o', default='dilation',
                        choices=morph_operations,
                        help='形态学操作类型 (默认: dilation)')
    parser.add_argument('--se_shape', '-s', default='disk',
                        choices=se_shapes,
                        help='结构元素形状 (默认: disk)')
    parser.add_argument('--se_size', '-S', type=int, default=3,
                        help='结构元素大小 (默认: 3)')
    parser.add_argument('--output_dir', '-d', default='output',
                        help='输出目录 (默认: output)')
    parser.add_argument('--output_name', '-n', default='morphology_result',
                        help='输出文件名 (不含扩展名) (默认: morphology_result)')
    parser.add_argument('--visualize', '-v', action='store_true',
                        help='是否生成可视化结果')

    args = parser.parse_args(argv)

    try:
        processor = HyperspectralMorphologyProcessor()

        format_type = None if args.format == 'auto' else args.format

        print(f"加载图像: {args.input_file}")
        processor.load_hyperspectral_image(args.input_file, format_type)

        if args.bands:
            band_indices = [int(b.strip()) for b in args.bands.split(',')]
            print(f"处理波段: {band_indices}")
            single_band = False
        else:
            band_indices = [args.band]
            print(f"处理波段: {args.band}")
            single_band = True

        if single_band:
            band_data = processor.extract_band(args.band)
            print(f"主函数中波段数据形状: {band_data.shape}")

            result = processor.apply_morphology_operation(
                band_data, args.operation, args.se_shape, args.se_size
            )

            output_path = os.path.join(args.output_dir,
                                       f"{args.output_name}_band{args.band}")
            processor.save_as_envi(
                result, output_path,
                f"{args.operation.capitalize()} 处理结果 - 波段 {args.band}")

            if args.visualize:
                vis_path = os.path.join(args.output_dir,
                                        f"visualization_band{args.band}.png")
                processor.visualize_results(band_data, result,
                                            args.operation, vis_path)

            print("\n=== 统计信息 ===")
            print(f"原始波段 {args.band}: min={band_data.min():.4f}, max={band_data.max():.4f}, "
                  f"mean={band_data.mean():.4f}")
            print(f"处理后: min={result.min():.4f}, max={result.max():.4f}, "
                  f"mean={result.mean():.4f}")
        else:
            results = processor.apply_to_multiple_bands(
                band_indices, args.operation, args.se_shape, args.se_size
            )

            for band_idx, result in results.items():
                output_path = os.path.join(args.output_dir,
                                           f"{args.output_name}_band{band_idx}")
                processor.save_as_envi(
                    result, output_path,
                    f"{args.operation.capitalize()} 处理结果 - 波段 {band_idx}")

            # A combined figure is only produced for up to four bands.
            if len(band_indices) <= 4 and args.visualize:
                if plt is None:
                    raise ImportError("未安装 matplotlib 库,无法生成可视化结果")

                fig, axes = plt.subplots(2, 2, figsize=(12, 10))
                axes = axes.flatten()

                for ax, band_idx in zip(axes, band_indices):
                    ax.imshow(results[band_idx], cmap='gray')
                    ax.set_title(f'波段 {band_idx} - {args.operation}')
                    ax.axis('off')
                # Hide the panels that have no band to show.
                for ax in axes[len(band_indices):]:
                    ax.axis('off')

                plt.suptitle(f'多波段形态学处理: {args.operation}', fontsize=16)
                plt.tight_layout()

                vis_path = os.path.join(args.output_dir, "multi_band_visualization.png")
                plt.savefig(vis_path, dpi=300, bbox_inches='tight')
                print(f"多波段可视化结果已保存到: {vis_path}")
                plt.show()
                plt.close(fig)

        print("\n✓ 形态学处理完成!")
        print(f"输出目录: {args.output_dir}")

        return 0

    except Exception as e:
        # Top-level boundary of the CLI: report the failure and signal it
        # through the exit status.
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1


def run_example():
    """Run every supported operation on one band of a sample image.

    The input path is a placeholder and must be replaced with a real file.

    Returns
    -------
    dict
        Processed band for every operation that succeeded.
    """
    processor = HyperspectralMorphologyProcessor()

    # Replace with a real file (.hdr, .bil, .bip, .bsq or .dat).
    input_file = "path/to/your/hyperspectral_image.hdr"
    processor.load_hyperspectral_image(input_file, format_type='bip')

    band_idx = 50
    band_data = processor.extract_band(band_idx)

    operations = ['dilation', 'erosion', 'opening', 'closing',
                  'gradient', 'tophat', 'bottomhat', 'reconstruction']

    results = {}
    for operation in operations:
        print(f"\n应用 {operation} 操作...")
        try:
            result = processor.apply_morphology_operation(
                band_data, operation, se_shape='disk', se_size=3
            )
            results[operation] = result

            output_path = f"output/{operation}_band{band_idx}"
            processor.save_as_envi(result, output_path,
                                   f"{operation.capitalize()} 处理结果")

            processor.visualize_results(
                band_data, result, operation,
                save_path=f"output/visualization_{operation}.png")
        except Exception as e:
            # Best effort: one failing operation must not stop the others.
            print(f"操作 {operation} 失败: {e}")

    return results


@dataclass
class MorphologicalFilterConfig:
    """Settings describing one morphological filtering run."""

    input_path: str            # image to process
    operation: str             # morphological operation name
    band_index: int            # zero-based band index
    kernel_size: int           # kernel size
    output_dir: str            # directory for the results
    se_shape: str = 'disk'     # structuring element shape
    se_size: int = 3           # structuring element size
    format_type: str = 'auto'  # image format, 'auto' to detect


if __name__ == '__main__':
    sys.exit(main())