Files
HSI/fliter_method/morphological_fliter.py

746 lines
28 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import os
import warnings
from typing import Tuple, List, Dict, Optional, Union
from pathlib import Path
from dataclasses import dataclass
import spectral as spy
from spectral import envi
import matplotlib.pyplot as plt
from skimage import morphology
from scipy import ndimage
import struct
warnings.filterwarnings('ignore')
class HyperspectralMorphologyProcessor:
    """Grey-scale morphological filtering of single bands of a hyperspectral cube.

    The in-memory cube is always indexed as (rows, cols, bands), independent
    of the interleave used by the file on disk.
    """

    # ENVI header "data type" codes keyed by NumPy dtype name.
    _ENVI_DATA_TYPES = {
        'uint8': 1,
        'int16': 2,
        'int32': 3,
        'float32': 4,
        'float64': 5,
        'complex64': 6,
        'complex128': 9,
        'uint16': 12,
        'uint32': 13,
        'int64': 14,
        'uint64': 15,
    }

    def __init__(self):
        self.data = None
        self.header = None
        self.wavelengths = None
        self.shape = None
        self.dtype = None

    def load_hyperspectral_image(self, file_path: str, format_type: str = None) -> Tuple[np.ndarray, Dict]:
        """Load an ENVI image given its header or its data file.

        Parameters
        ----------
        file_path : str
            Path to the ``.hdr`` file or to the data file
            (``.bil``, ``.bip``, ``.bsq``, ``.dat``).
        format_type : str, optional
            Interleave label ('bil', 'bip', 'bsq', 'dat').  Detected
            automatically when None.  It is only reported; it does not
            change the in-memory layout.

        Returns
        -------
        data : np.ndarray
            Data cube as returned by ``img.load()``, indexed (rows, cols, bands).
        header : Dict
            Header metadata.

        Raises
        ------
        IOError
            If the header cannot be found or the image cannot be read.
        """
        try:
            if format_type is None:
                format_type = self._detect_format(file_path)

            if file_path.lower().endswith('.hdr'):
                img = envi.open(file_path)
            else:
                hdr_file = self._find_header_file(file_path)
                if hdr_file is None:
                    raise FileNotFoundError(f"未找到头文件(.hdr)用于 {file_path}")
                # Pass the data file explicitly so it does not have to be guessed.
                img = envi.open(hdr_file, file_path)

            # The loaded cube is kept as (rows, cols, bands) whatever the
            # on-disk interleave is; extract_band() relies on that layout,
            # so the array must NOT be re-ordered here.
            data = img.load()
            header = dict(img.metadata)

            self.data = data
            self.header = header
            self.shape = data.shape
            self.dtype = data.dtype
            self.wavelengths = self._parse_wavelengths(header)

            print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}, 格式={format_type}")
            if self.wavelengths is not None and self.wavelengths.size > 0:
                print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
            return data, header
        except Exception as e:
            raise IOError(f"加载图像失败: {e}") from e

    @staticmethod
    def _find_header_file(file_path: str) -> Optional[str]:
        """Return the first existing ``.hdr`` companion of a data file, or None."""
        candidates = (
            os.path.splitext(file_path)[0] + '.hdr',
            file_path + '.hdr',
        )
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        return None

    @staticmethod
    def _parse_wavelengths(header: Dict) -> Optional[np.ndarray]:
        """Return the header's wavelengths as floats, or None if absent or non-numeric."""
        if 'wavelength' not in header:
            return None
        try:
            return np.array([float(w) for w in header['wavelength']])
        except (TypeError, ValueError):
            # Headers may carry a non-numeric placeholder in this field.
            return None

    def _detect_format(self, file_path: str) -> str:
        """Guess the interleave label from the extension or the header content.

        Data-file extensions map to themselves ('bil', 'bip', 'bsq', 'dat').
        For a ``.hdr`` file the ``interleave`` entry is parsed, tolerating any
        spacing around ``=`` and any letter case.  Falls back to 'bip'.
        """
        ext = Path(file_path).suffix.lower()
        if ext in ('.bil', '.bip', '.bsq', '.dat'):
            return ext[1:]
        if ext == '.hdr':
            with open(file_path, 'r', errors='replace') as f:
                for line in f:
                    key, sep, value = line.partition('=')
                    if sep and key.strip().lower() == 'interleave':
                        value = value.strip().lower()
                        if value in ('bil', 'bsq', 'bip'):
                            return value
                        break
        return 'bip'

    def _convert_to_bil(self, data: np.ndarray) -> np.ndarray:
        """Return a (rows, bands, cols) copy of a (rows, cols, bands) cube."""
        return np.ascontiguousarray(np.transpose(np.asarray(data), (0, 2, 1)))

    def _convert_to_bsq(self, data: np.ndarray) -> np.ndarray:
        """Return a (bands, rows, cols) copy of a (rows, cols, bands) cube."""
        return np.ascontiguousarray(np.transpose(np.asarray(data), (2, 0, 1)))

    def extract_band(self, band_index: int) -> np.ndarray:
        """Return one band of the loaded cube as a 2-D (rows, cols) array.

        Parameters
        ----------
        band_index : int
            Zero-based band index.  For 2-D (single band) data only 0 is valid.

        Raises
        ------
        ValueError
            If no data is loaded, the index is out of range, or the data is
            neither 2-D nor 3-D.
        """
        if self.data is None:
            raise ValueError("请先加载图像数据")

        shape = self.data.shape
        if len(shape) == 2:
            # Already a single band.
            if band_index != 0:
                raise ValueError(f"波段索引 {band_index} 超出范围 [0, 0]")
            band_data = self.data
        elif len(shape) == 3:
            if band_index < 0 or band_index >= shape[2]:
                raise ValueError(f"波段索引 {band_index} 超出范围 [0, {shape[2]-1}]")
            # reshape (instead of squeeze) drops a trailing singleton axis
            # without collapsing images that have a single row or column.
            band_data = np.asarray(self.data[:, :, band_index]).reshape(shape[0], shape[1])
        else:
            raise ValueError(f"不支持的数据形状: {self.data.shape}")

        print(f"提取波段 {band_index}: 最终形状={band_data.shape}")
        return band_data

    def apply_morphology_operation(self, band_data: np.ndarray,
                                   operation: str = 'dilation',
                                   se_shape: str = 'disk',
                                   se_size: int = 3,
                                   **kwargs) -> np.ndarray:
        """Apply a grey-scale morphological operation to a 2-D band.

        Parameters
        ----------
        band_data : np.ndarray
            Single band image (2-D).
        operation : str
            One of 'dilation', 'erosion', 'opening', 'closing',
            'gradient' (dilation - erosion), 'tophat' (image - opening),
            'bottomhat' (closing - image), 'reconstruction'
            (reconstruction by dilation, using the eroded image as marker).
        se_shape : str
            'disk', 'square', 'rectangle' or 'diamond'.
        se_size : int
            Structuring element size.
        **kwargs
            Accepted for compatibility; unused.

        Returns
        -------
        np.ndarray
            Floating-point array for 'gradient', 'tophat' and 'bottomhat';
            otherwise an array of the input dtype.

        Raises
        ------
        ValueError
            If the input is not 2-D or the operation/shape is unknown.
        """
        image = np.asarray(band_data)
        if image.ndim != 2:
            raise ValueError(f"输入数据必须是2D的当前形状: {image.shape}")

        footprint = self._create_structuring_element(se_shape, se_size)

        # float32 represents 8/16-bit integers and float16/32 exactly; wider
        # types (int32, float64, ...) need float64 to avoid losing precision.
        if np.can_cast(image.dtype, np.float32, casting='safe'):
            work = image.astype(np.float32)
        else:
            work = image.astype(np.float64)

        def dilate(arr):
            return ndimage.grey_dilation(arr, footprint=footprint)

        def erode(arr):
            return ndimage.grey_erosion(arr, footprint=footprint)

        if operation == 'dilation':
            result = dilate(work)
        elif operation == 'erosion':
            result = erode(work)
        elif operation == 'opening':
            result = dilate(erode(work))
        elif operation == 'closing':
            result = erode(dilate(work))
        elif operation == 'gradient':
            result = dilate(work) - erode(work)
        elif operation == 'tophat':
            result = work - dilate(erode(work))
        elif operation == 'bottomhat':
            result = erode(dilate(work)) - work
        elif operation == 'reconstruction':
            result = self._morphological_reconstruction(erode(work), work, footprint)
        else:
            raise ValueError(f"不支持的形态学操作: {operation}")

        if operation in ('gradient', 'tophat', 'bottomhat'):
            # Difference images are returned in the floating working dtype.
            return result
        return self._convert_to_original_type(result, image.dtype)

    def _create_structuring_element(self, shape: str, size: int) -> np.ndarray:
        """Build a boolean structuring element.

        ``size`` is the nominal width: 'disk' and 'diamond' use radius
        ``size // 2``, 'square' is ``size x size`` and 'rectangle' is
        ``size x 2*size``.

        Raises
        ------
        ValueError
            If ``size`` is smaller than 1 or ``shape`` is unknown.
        """
        if size < 1:
            raise ValueError(f"结构元素大小必须为正整数: {size}")

        if shape == 'disk':
            radius = size // 2
            y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
            return x * x + y * y <= radius * radius
        if shape == 'square':
            return np.ones((size, size), dtype=bool)
        if shape == 'rectangle':
            return np.ones((size, size * 2), dtype=bool)
        if shape == 'diamond':
            # Same size convention as 'disk' (radius = size // 2), so that
            # se_size means the same thing for every shape.
            radius = size // 2
            y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
            return np.abs(x) + np.abs(y) <= radius
        raise ValueError(f"不支持的结构元素形状: {shape}")

    def _morphological_reconstruction(self, marker: np.ndarray, mask: np.ndarray,
                                      selem: np.ndarray) -> np.ndarray:
        """Reconstruction by dilation of ``marker`` under ``mask``.

        The marker is first clipped to the mask, then repeatedly dilated with
        ``selem`` and clipped again until it no longer changes.
        """
        current = np.minimum(marker, mask)
        while True:
            grown = np.minimum(ndimage.grey_dilation(current, footprint=selem), mask)
            if self._same_values(grown, current):
                return grown
            current = grown

    @staticmethod
    def _same_values(a: np.ndarray, b: np.ndarray) -> bool:
        """Element-wise equality that treats NaN as equal to NaN.

        A plain ``array_equal`` never succeeds when NaNs are present, which
        would keep the reconstruction loop running forever.
        """
        if np.array_equal(a, b):
            return True
        if a.shape == b.shape and a.dtype.kind in 'fc':
            return bool(np.all((a == b) | (np.isnan(a) & np.isnan(b))))
        return False

    def _convert_to_original_type(self, data: np.ndarray, original_dtype: np.dtype) -> np.ndarray:
        """Cast ``data`` to ``original_dtype``.

        Integer targets are rounded to nearest and saturated at the dtype
        limits; other targets are cast directly.
        """
        values = np.asarray(data)
        target = np.dtype(original_dtype)
        if not np.issubdtype(target, np.integer):
            return values.astype(target)

        info = np.iinfo(target)
        if not np.issubdtype(values.dtype, np.floating):
            values = values.astype(np.float64)
        rounded = np.rint(values)
        # The integer limits may not be representable in the float dtype
        # (e.g. int32 max in float32), so the cast alone can overflow at the
        # ends; saturate those entries explicitly afterwards.
        with np.errstate(invalid='ignore'):
            converted = np.clip(rounded, info.min, info.max).astype(target)
        converted[rounded >= info.max] = info.max
        converted[rounded <= info.min] = info.min
        return converted

    def apply_to_multiple_bands(self, band_indices: List[int], operation: str,
                                se_shape: str = 'disk', se_size: int = 3) -> Dict[int, np.ndarray]:
        """Apply one morphological operation to several bands.

        Returns
        -------
        Dict[int, np.ndarray]
            Result per band index.
        """
        results = {}
        for band_idx in band_indices:
            print(f"处理波段 {band_idx}...")
            band_data = self.extract_band(band_idx)
            if len(band_data.shape) != 2:
                raise ValueError(f"波段 {band_idx} 数据不是2D的形状: {band_data.shape}")
            results[band_idx] = self.apply_morphology_operation(
                band_data, operation, se_shape, se_size
            )
        return results

    def save_as_envi(self, data: np.ndarray, output_path: str,
                     description: str = "形态学处理结果") -> None:
        """Write ``data`` as a raw ``.dat`` file plus a matching ENVI ``.hdr``.

        Parameters
        ----------
        data : np.ndarray
            Array to save (normally a 2-D band).  Dtypes without an ENVI
            data-type code are converted (1-byte bool/int8 to int16, anything
            else to float32) before writing.
        output_path : str
            Output path; ``.dat`` is appended when missing.
        description : str
            Text stored in the header's description field.

        Raises
        ------
        IOError
            If the files cannot be written.
        """
        try:
            array = np.asarray(data)
            if array.dtype.name not in self._ENVI_DATA_TYPES:
                narrow_int = array.dtype.kind in 'bi' and array.dtype.itemsize == 1
                array = array.astype(np.int16 if narrow_int else np.float32)

            # dirname is '' for a bare file name; makedirs('') would fail.
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

            if output_path.lower().endswith('.dat'):
                dat_path = output_path
            else:
                dat_path = output_path + '.dat'
            array.tofile(dat_path)

            hdr_path = dat_path.rsplit('.', 1)[0] + '.hdr'
            self._create_morphology_hdr_file(hdr_path, array.shape, description, array.dtype)

            print("ENVI格式已保存:")
            print(f" 数据文件: {dat_path}")
            print(f" 头文件: {hdr_path}")
            print(f" 数据形状: {array.shape}, 数据类型: {array.dtype}")
        except Exception as e:
            raise IOError(f"保存ENVI文件失败: {e}") from e

    def _create_morphology_hdr_file(self, hdr_path: str, data_shape: Tuple,
                                    description: str, data_dtype: np.dtype = None) -> None:
        """Write an ENVI header describing an array saved with ``tofile``.

        Dimensions, data type and byte order are derived from the array being
        written rather than copied from the source image header, so the
        header always matches the bytes on disk.

        Parameters
        ----------
        hdr_path : str
            Header file path.
        data_shape : Tuple
            Shape of the saved array: (lines, samples) or
            (lines, samples, bands); a 1-D shape is written as one line.
        description : str
            Description text.
        data_dtype : np.dtype, optional
            Dtype of the saved array; float32 is assumed when None.

        Raises
        ------
        IOError
            If the dtype has no ENVI code or the file cannot be written.
        """
        try:
            dtype = np.dtype(np.float32 if data_dtype is None else data_dtype)
            if dtype.name not in self._ENVI_DATA_TYPES:
                raise ValueError(f"不支持的ENVI数据类型: {dtype}")
            data_type = self._ENVI_DATA_TYPES[dtype.name]

            if len(data_shape) >= 3:
                lines, samples, bands = data_shape[0], data_shape[1], data_shape[2]
                # A C-ordered (lines, samples, bands) dump is band-interleaved-by-pixel.
                interleave = 'bip'
            elif len(data_shape) == 2:
                lines, samples = data_shape
                bands = 1
                interleave = 'bsq'
            else:
                lines, samples, bands = 1, data_shape[0], 1
                interleave = 'bsq'

            # ENVI: 0 = little endian, 1 = big endian.
            if dtype.byteorder == '>':
                byte_order = 1
            elif dtype.byteorder == '<':
                byte_order = 0
            else:
                byte_order = 0 if np.dtype('<i4').isnative else 1

            # Braces delimit header values and must not appear inside them.
            safe_description = str(description).replace('{', '(').replace('}', ')')

            entries = [
                "ENVI",
                "description = {",
                f" {safe_description}",
            ]
            if bands == 1:
                entries.append(" 单波段形态学处理结果")
            entries += [
                "}",
                f"samples = {samples}",
                f"lines = {lines}",
                f"bands = {bands}",
                "header offset = 0",
                "file type = ENVI Standard",
                f"data type = {data_type}",
                f"interleave = {interleave}",
                f"byte order = {byte_order}",
            ]
            # No 'wavelength' entry is written: the wavelength of the
            # processed band is not known here, and a non-numeric placeholder
            # cannot be parsed back as a number.
            with open(hdr_path, 'w') as f:
                f.write("\n".join(entries) + "\n")
        except Exception as e:
            raise IOError(f"创建HDR文件失败: {e}") from e

    def visualize_results(self, original_band: np.ndarray,
                          processed_band: np.ndarray,
                          operation_name: str,
                          save_path: Optional[str] = None) -> None:
        """Show the original band, the processed band and their difference.

        Parameters
        ----------
        original_band : np.ndarray
            Band before processing.
        processed_band : np.ndarray
            Band after processing.
        operation_name : str
            Operation name used in the titles.
        save_path : str, optional
            When given, the figure is also saved to this path.
        """
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        im1 = axes[0].imshow(original_band, cmap='gray')
        axes[0].set_title('原始波段')
        axes[0].axis('off')
        plt.colorbar(im1, ax=axes[0], shrink=0.8)

        im2 = axes[1].imshow(processed_band, cmap='gray')
        axes[1].set_title(f'{operation_name} 结果')
        axes[1].axis('off')
        plt.colorbar(im2, ax=axes[1], shrink=0.8)

        diff = processed_band.astype(np.float32) - original_band.astype(np.float32)
        # Symmetric colour range around zero; fall back to 1 for an all-zero
        # difference so zero stays at the centre of the diverging colormap.
        limit = float(np.abs(diff).max()) if diff.size else 0.0
        if limit == 0:
            limit = 1.0
        im3 = axes[2].imshow(diff, cmap='RdBu_r', vmin=-limit, vmax=limit)
        axes[2].set_title('差异 (处理后 - 原始)')
        axes[2].axis('off')
        plt.colorbar(im3, ax=axes[2], shrink=0.8)

        plt.suptitle(f'形态学操作: {operation_name}', fontsize=16)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"可视化结果已保存到: {save_path}")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def main():
    """Command-line entry point for hyperspectral morphological processing.

    Parses the arguments, loads the image, applies the requested operation to
    one band (``--band``) or several bands (``--bands``), saves each result as
    ENVI ``.dat``/``.hdr`` and optionally writes visualizations.

    Returns
    -------
    int
        0 on success, 1 if processing failed.
    """
    import argparse

    morph_operations = [
        'dilation', 'erosion', 'opening', 'closing',
        'gradient', 'tophat', 'bottomhat', 'reconstruction'
    ]
    se_shapes = ['disk', 'square', 'rectangle', 'diamond']

    parser = argparse.ArgumentParser(description='高光谱图像形态学处理工具')
    parser.add_argument('input_file', help='输入高光谱图像文件路径')
    parser.add_argument('--format', '-f', default='auto',
                        choices=['auto', 'bil', 'bip', 'bsq', 'dat'],
                        help='图像格式 (默认: auto)')
    parser.add_argument('--band', '-b', type=int, default=0,
                        help='要处理的波段索引 (默认: 0)')
    parser.add_argument('--bands', '-B', type=str, default=None,
                        help='多个波段索引,用逗号分隔,如 "0,10,20"')
    parser.add_argument('--operation', '-o', default='dilation',
                        choices=morph_operations,
                        help='形态学操作类型 (默认: dilation)')
    parser.add_argument('--se_shape', '-s', default='disk',
                        choices=se_shapes,
                        help='结构元素形状 (默认: disk)')
    parser.add_argument('--se_size', '-S', type=int, default=3,
                        help='结构元素大小 (默认: 3)')
    parser.add_argument('--output_dir', '-d', default='output',
                        help='输出目录 (默认: output)')
    parser.add_argument('--output_name', '-n', default='morphology_result',
                        help='输出文件名 (不含扩展名) (默认: morphology_result)')
    parser.add_argument('--visualize', '-v', action='store_true',
                        help='是否生成可视化结果')
    args = parser.parse_args()

    # Reject sizes that cannot describe a structuring element up front.
    if args.se_size < 1:
        parser.error('--se_size 必须为正整数')

    try:
        processor = HyperspectralMorphologyProcessor()
        format_type = None if args.format == 'auto' else args.format

        print(f"加载图像: {args.input_file}")
        processor.load_hyperspectral_image(args.input_file, format_type)

        if args.bands:
            # Ignore empty tokens such as the one left by a trailing comma.
            band_indices = [int(tok) for tok in args.bands.split(',') if tok.strip()]
            print(f"处理波段: {band_indices}")
            single_band = False
        else:
            band_indices = [args.band]
            print(f"处理波段: {args.band}")
            single_band = True

        if single_band:
            band_data = processor.extract_band(args.band)
            print(f"主函数中波段数据形状: {band_data.shape}")
            result = processor.apply_morphology_operation(
                band_data, args.operation, args.se_shape, args.se_size
            )

            output_path = os.path.join(args.output_dir, f"{args.output_name}_band{args.band}")
            processor.save_as_envi(result, output_path,
                                   f"{args.operation.capitalize()} 处理结果 - 波段 {args.band}")

            if args.visualize:
                vis_path = os.path.join(args.output_dir, f"visualization_band{args.band}.png")
                processor.visualize_results(band_data, result, args.operation, vis_path)

            print("\n=== 统计信息 ===")
            print(f"原始波段 {args.band}: min={band_data.min():.4f}, max={band_data.max():.4f}, "
                  f"mean={band_data.mean():.4f}")
            print(f"处理后: min={result.min():.4f}, max={result.max():.4f}, "
                  f"mean={result.mean():.4f}")
        else:
            results = processor.apply_to_multiple_bands(
                band_indices, args.operation, args.se_shape, args.se_size
            )
            for band_idx, result in results.items():
                output_path = os.path.join(args.output_dir, f"{args.output_name}_band{band_idx}")
                processor.save_as_envi(result, output_path,
                                       f"{args.operation.capitalize()} 处理结果 - 波段 {band_idx}")

            # A combined 2x2 figure is only drawn for up to four bands.
            if args.visualize and len(band_indices) <= 4:
                fig, axes = plt.subplots(2, 2, figsize=(12, 10))
                axes = axes.flatten()
                for ax, band_idx in zip(axes, band_indices):
                    ax.imshow(results[band_idx], cmap='gray')
                    ax.set_title(f'波段 {band_idx} - {args.operation}')
                    ax.axis('off')
                # Hide the panels that received no band.
                for ax in axes[len(band_indices):]:
                    ax.axis('off')
                plt.suptitle(f'多波段形态学处理: {args.operation}', fontsize=16)
                plt.tight_layout()
                vis_path = os.path.join(args.output_dir, "multi_band_visualization.png")
                plt.savefig(vis_path, dpi=300, bbox_inches='tight')
                print(f"多波段可视化结果已保存到: {vis_path}")
                plt.show()
                plt.close(fig)
            elif args.visualize:
                print("波段数超过4, 跳过组合可视化")

        print("\n✓ 形态学处理完成!")
        print(f"输出目录: {args.output_dir}")
        return 0
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the exit code.
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
# Usage example
def run_example():
    """Run every supported operation on band 50 of a placeholder image.

    Each result is saved under ``output/`` and visualized.  A failing
    operation is reported and skipped.

    Returns
    -------
    dict
        Result array per operation name, for the operations that succeeded.
    """
    # Replace with the path of a real image (.hdr, .bil, .bip, .bsq or .dat).
    input_file = "path/to/your/hyperspectral_image.hdr"
    band_idx = 50

    processor = HyperspectralMorphologyProcessor()
    processor.load_hyperspectral_image(input_file, format_type='bip')
    band_data = processor.extract_band(band_idx)

    results = {}
    for operation in ('dilation', 'erosion', 'opening', 'closing',
                      'gradient', 'tophat', 'bottomhat', 'reconstruction'):
        print(f"\n应用 {operation} 操作...")
        try:
            filtered = processor.apply_morphology_operation(
                band_data, operation, se_shape='disk', se_size=3
            )
            results[operation] = filtered
            processor.save_as_envi(
                filtered,
                f"output/{operation}_band{band_idx}",
                f"{operation.capitalize()} 处理结果",
            )
            processor.visualize_results(
                band_data, filtered, operation,
                save_path=f"output/visualization_{operation}.png",
            )
        except Exception as e:
            print(f"操作 {operation} 失败: {e}")
    return results
@dataclass
class MorphologicalFilterConfig:
    """Settings for one morphological filtering run.

    The first five fields are required, in this order; the remaining three
    have defaults.
    """

    # Required fields.
    input_path: str
    operation: str
    band_index: int
    kernel_size: int
    output_dir: str

    # Optional fields.
    se_shape: str = 'disk'
    se_size: int = 3
    format_type: str = 'auto'
if __name__ == '__main__':
    # exit() is a convenience injected by the site module and is not always
    # available; SystemExit passes main()'s status to the interpreter directly.
    raise SystemExit(main())