746 lines
28 KiB
Python
746 lines
28 KiB
Python
import numpy as np
|
||
import os
|
||
import warnings
|
||
from typing import Tuple, List, Dict, Optional, Union
|
||
from pathlib import Path
|
||
from dataclasses import dataclass
|
||
import spectral as spy
|
||
from spectral import envi
|
||
import matplotlib.pyplot as plt
|
||
from skimage import morphology
|
||
from scipy import ndimage
|
||
import struct
|
||
|
||
warnings.filterwarnings('ignore')
|
||
|
||
|
||
class HyperspectralMorphologyProcessor:
|
||
"""高光谱图像形态学处理类"""
|
||
|
||
def __init__(self):
|
||
self.data = None
|
||
self.header = None
|
||
self.wavelengths = None
|
||
self.shape = None
|
||
self.dtype = None
|
||
|
||
def load_hyperspectral_image(self, file_path: str, format_type: str = None) -> Tuple[np.ndarray, Dict]:
|
||
"""
|
||
加载高光谱图像(支持bil、bip、bsq、dat格式)
|
||
|
||
Parameters:
|
||
-----------
|
||
file_path : str
|
||
图像文件路径(可以是.hdr、.bil、.bip、.bsq、.dat)
|
||
format_type : str, optional
|
||
格式类型('bil'、'bip'、'bsq'、'dat'),如果为None则自动检测
|
||
|
||
Returns:
|
||
--------
|
||
data : np.ndarray
|
||
高光谱数据立方体 (rows, cols, bands)
|
||
header : Dict
|
||
头文件信息
|
||
"""
|
||
try:
|
||
# 自动检测格式
|
||
if format_type is None:
|
||
format_type = self._detect_format(file_path)
|
||
|
||
# 如果是.hdr文件,直接使用spectral加载
|
||
if file_path.lower().endswith('.hdr'):
|
||
img = envi.open(file_path)
|
||
data = img.load()
|
||
header = dict(img.metadata)
|
||
|
||
# 如果是其他格式,尝试查找对应的.hdr文件
|
||
else:
|
||
# 查找可能的hdr文件
|
||
hdr_candidates = [
|
||
file_path.rsplit('.', 1)[0] + '.hdr',
|
||
file_path + '.hdr',
|
||
file_path[:-4] + '.hdr' if len(file_path) > 4 else None
|
||
]
|
||
|
||
hdr_file = None
|
||
for candidate in hdr_candidates:
|
||
if candidate and os.path.exists(candidate):
|
||
hdr_file = candidate
|
||
break
|
||
|
||
if hdr_file is None:
|
||
raise FileNotFoundError(f"未找到头文件(.hdr)用于 {file_path}")
|
||
|
||
# 加载图像
|
||
img = envi.open(hdr_file)
|
||
data = img.load()
|
||
header = dict(img.metadata)
|
||
|
||
# 根据格式调整数据布局
|
||
if format_type.lower() == 'bil':
|
||
# BIL: 波段交错按行
|
||
data = self._convert_to_bil(data)
|
||
elif format_type.lower() == 'bip':
|
||
# BIP: 波段交错按像素 (spectral默认)
|
||
pass # spectral默认就是BIP
|
||
elif format_type.lower() == 'bsq':
|
||
# BSQ: 波段顺序
|
||
data = self._convert_to_bsq(data)
|
||
|
||
self.data = data
|
||
self.header = header
|
||
self.shape = data.shape
|
||
self.dtype = data.dtype
|
||
|
||
# 提取波长信息
|
||
if 'wavelength' in header:
|
||
self.wavelengths = np.array([float(w) for w in header['wavelength']])
|
||
else:
|
||
self.wavelengths = None
|
||
|
||
print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}, 格式={format_type}")
|
||
if self.wavelengths is not None:
|
||
print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
|
||
|
||
return data, header
|
||
|
||
except Exception as e:
|
||
raise IOError(f"加载图像失败: {e}")
|
||
|
||
def _detect_format(self, file_path: str) -> str:
|
||
"""自动检测图像格式"""
|
||
ext = Path(file_path).suffix.lower()
|
||
|
||
if ext == '.hdr':
|
||
# 读取hdr文件内容检测数据格式
|
||
with open(file_path, 'r') as f:
|
||
content = f.read()
|
||
if 'interleave = bil' in content.lower():
|
||
return 'bil'
|
||
elif 'interleave = bsq' in content.lower():
|
||
return 'bsq'
|
||
elif 'interleave = bip' in content.lower():
|
||
return 'bip'
|
||
else:
|
||
# 默认假设为BIP
|
||
return 'bip'
|
||
elif ext in ['.bil', '.bip', '.bsq', '.dat']:
|
||
# 从扩展名判断
|
||
return ext[1:] # 去掉点号
|
||
else:
|
||
# 默认假设为BIP
|
||
return 'bip'
|
||
|
||
def _convert_to_bil(self, data: np.ndarray) -> np.ndarray:
|
||
"""将数据转换为BIL格式"""
|
||
rows, cols, bands = data.shape
|
||
bil_data = np.zeros((rows, bands, cols), dtype=data.dtype)
|
||
for i in range(rows):
|
||
for b in range(bands):
|
||
bil_data[i, b, :] = data[i, :, b]
|
||
return bil_data
|
||
|
||
def _convert_to_bsq(self, data: np.ndarray) -> np.ndarray:
|
||
"""将数据转换为BSQ格式"""
|
||
rows, cols, bands = data.shape
|
||
bsq_data = np.zeros((bands, rows, cols), dtype=data.dtype)
|
||
for b in range(bands):
|
||
bsq_data[b, :, :] = data[:, :, b]
|
||
return bsq_data
|
||
|
||
def extract_band(self, band_index: int) -> np.ndarray:
|
||
"""
|
||
提取指定波段
|
||
|
||
Parameters:
|
||
-----------
|
||
band_index : int
|
||
波段索引(从0开始)
|
||
|
||
Returns:
|
||
--------
|
||
band_data : np.ndarray
|
||
单波段图像 (rows, cols)
|
||
"""
|
||
if self.data is None:
|
||
raise ValueError("请先加载图像数据")
|
||
|
||
if band_index < 0 or band_index >= self.shape[2]:
|
||
raise ValueError(f"波段索引 {band_index} 超出范围 [0, {self.shape[2]-1}]")
|
||
|
||
# 根据数据布局提取波段
|
||
if len(self.data.shape) == 3:
|
||
# 标准形状 (rows, cols, bands)
|
||
# 使用squeeze()来确保移除单维度
|
||
band_data = np.squeeze(self.data[:, :, band_index])
|
||
print(f"提取前形状: {self.data.shape}, band_index: {band_index}")
|
||
print(f"提取后初始形状: {band_data.shape}")
|
||
|
||
# 确保最终是2D
|
||
if len(band_data.shape) != 2:
|
||
raise ValueError(f"无法提取2D波段数据,最终形状: {band_data.shape}")
|
||
|
||
elif len(self.data.shape) == 2:
|
||
# 已经是单波段
|
||
band_data = self.data
|
||
else:
|
||
raise ValueError(f"不支持的数据形状: {self.data.shape}")
|
||
|
||
print(f"提取波段 {band_index}: 最终形状={band_data.shape}")
|
||
|
||
return band_data
|
||
|
||
def apply_morphology_operation(self, band_data: np.ndarray,
|
||
operation: str = 'dilation',
|
||
se_shape: str = 'disk',
|
||
se_size: int = 3,
|
||
**kwargs) -> np.ndarray:
|
||
"""
|
||
应用形态学操作
|
||
|
||
Parameters:
|
||
-----------
|
||
band_data : np.ndarray
|
||
单波段图像数据 (2D)
|
||
operation : str
|
||
形态学操作类型:
|
||
- 'dilation': 膨胀
|
||
- 'erosion': 腐蚀
|
||
- 'opening': 开运算
|
||
- 'closing': 闭运算
|
||
- 'gradient': 形态学梯度 (膨胀 - 腐蚀)
|
||
- 'tophat': 顶帽变换 (原图 - 开运算)
|
||
- 'bottomhat': 底帽变换 (闭运算 - 原图)
|
||
- 'reconstruction': 形态学重建
|
||
se_shape : str
|
||
结构元素形状: 'disk', 'square', 'rectangle', 'diamond'
|
||
se_size : int
|
||
结构元素大小
|
||
|
||
Returns:
|
||
--------
|
||
result : np.ndarray
|
||
处理后的图像
|
||
"""
|
||
# 确保输入数据是2D的
|
||
if len(band_data.shape) != 2:
|
||
raise ValueError(f"输入数据必须是2D的,当前形状: {band_data.shape}")
|
||
|
||
# 创建结构元素
|
||
selem = self._create_structuring_element(se_shape, se_size)
|
||
|
||
# 转换为浮点数以保证精度
|
||
data_float = band_data.astype(np.float32)
|
||
|
||
# 应用形态学操作
|
||
if operation == 'dilation':
|
||
result = ndimage.grey_dilation(data_float, footprint=selem)
|
||
|
||
elif operation == 'erosion':
|
||
result = ndimage.grey_erosion(data_float, footprint=selem)
|
||
|
||
elif operation == 'opening':
|
||
# 开运算: 先腐蚀后膨胀
|
||
eroded = ndimage.grey_erosion(data_float, footprint=selem)
|
||
result = ndimage.grey_dilation(eroded, footprint=selem)
|
||
|
||
elif operation == 'closing':
|
||
# 闭运算: 先膨胀后腐蚀
|
||
dilated = ndimage.grey_dilation(data_float, footprint=selem)
|
||
result = ndimage.grey_erosion(dilated, footprint=selem)
|
||
|
||
elif operation == 'gradient':
|
||
# 形态学梯度: 膨胀 - 腐蚀
|
||
dilated = ndimage.grey_dilation(data_float, footprint=selem)
|
||
eroded = ndimage.grey_erosion(data_float, footprint=selem)
|
||
result = dilated - eroded
|
||
|
||
elif operation == 'tophat':
|
||
# 顶帽变换: 原图 - 开运算
|
||
eroded = ndimage.grey_erosion(data_float, footprint=selem)
|
||
opened = ndimage.grey_dilation(eroded, footprint=selem)
|
||
result = data_float - opened
|
||
|
||
elif operation == 'bottomhat':
|
||
# 底帽变换: 闭运算 - 原图
|
||
dilated = ndimage.grey_dilation(data_float, footprint=selem)
|
||
closed = ndimage.grey_erosion(dilated, footprint=selem)
|
||
result = closed - data_float
|
||
|
||
elif operation == 'reconstruction':
|
||
# 形态学重建(基于膨胀的重建)
|
||
# 使用标记图像(这里使用腐蚀后的图像作为标记)
|
||
marker = ndimage.grey_erosion(data_float, footprint=selem)
|
||
result = self._morphological_reconstruction(marker, data_float, selem)
|
||
|
||
else:
|
||
raise ValueError(f"不支持的形态学操作: {operation}")
|
||
|
||
# 转换为原始数据类型
|
||
if operation in ['gradient', 'tophat', 'bottomhat']:
|
||
# 这些操作可能产生负值,保持浮点类型
|
||
return result
|
||
else:
|
||
return self._convert_to_original_type(result, band_data.dtype)
|
||
|
||
def _create_structuring_element(self, shape: str, size: int) -> np.ndarray:
|
||
"""创建结构元素"""
|
||
if shape == 'disk':
|
||
# 创建圆形结构元素
|
||
radius = size // 2
|
||
y, x = np.ogrid[-radius:radius+1, -radius:radius+1]
|
||
mask = x*x + y*y <= radius*radius
|
||
selem = np.zeros((2*radius+1, 2*radius+1), dtype=bool)
|
||
selem[mask] = True
|
||
elif shape == 'square':
|
||
# 创建方形结构元素
|
||
selem = np.ones((size, size), dtype=bool)
|
||
elif shape == 'rectangle':
|
||
# 创建矩形结构元素
|
||
selem = np.ones((size, size*2), dtype=bool)
|
||
elif shape == 'diamond':
|
||
# 创建菱形结构元素
|
||
selem = morphology.diamond(size)
|
||
else:
|
||
raise ValueError(f"不支持的结构元素形状: {shape}")
|
||
|
||
return selem
|
||
|
||
def _morphological_reconstruction(self, marker: np.ndarray, mask: np.ndarray,
|
||
selem: np.ndarray) -> np.ndarray:
|
||
"""
|
||
形态学重建(基于膨胀)
|
||
|
||
Parameters:
|
||
-----------
|
||
marker : np.ndarray
|
||
标记图像
|
||
mask : np.ndarray
|
||
掩模图像
|
||
selem : np.ndarray
|
||
结构元素
|
||
|
||
Returns:
|
||
--------
|
||
reconstructed : np.ndarray
|
||
重建后的图像
|
||
"""
|
||
# 确保标记图像不超过掩模图像
|
||
marker = np.minimum(marker, mask)
|
||
|
||
# 迭代膨胀直到收敛
|
||
prev_marker = np.zeros_like(marker)
|
||
while not np.array_equal(marker, prev_marker):
|
||
prev_marker = marker.copy()
|
||
# 条件膨胀:在掩模限制下膨胀
|
||
dilated = ndimage.grey_dilation(marker, footprint=selem)
|
||
marker = np.minimum(dilated, mask)
|
||
|
||
return marker
|
||
|
||
def _convert_to_original_type(self, data: np.ndarray, original_dtype: np.dtype) -> np.ndarray:
|
||
"""将数据转换回原始数据类型"""
|
||
if np.issubdtype(original_dtype, np.integer):
|
||
# 对于整数类型,进行裁剪和取整
|
||
data = np.clip(data, np.iinfo(original_dtype).min, np.iinfo(original_dtype).max)
|
||
return data.astype(original_dtype)
|
||
else:
|
||
# 对于浮点类型,直接转换
|
||
return data.astype(original_dtype)
|
||
|
||
def apply_to_multiple_bands(self, band_indices: List[int], operation: str,
|
||
se_shape: str = 'disk', se_size: int = 3) -> Dict[int, np.ndarray]:
|
||
"""
|
||
对多个波段应用形态学操作
|
||
|
||
Parameters:
|
||
-----------
|
||
band_indices : List[int]
|
||
波段索引列表
|
||
operation : str
|
||
形态学操作类型
|
||
se_shape : str
|
||
结构元素形状
|
||
se_size : int
|
||
结构元素大小
|
||
|
||
Returns:
|
||
--------
|
||
results : Dict[int, np.ndarray]
|
||
每个波段的处理结果
|
||
"""
|
||
results = {}
|
||
|
||
for band_idx in band_indices:
|
||
print(f"处理波段 {band_idx}...")
|
||
band_data = self.extract_band(band_idx)
|
||
# extract_band 现在应该总是返回2D数据,但为了安全起见检查一下
|
||
if len(band_data.shape) != 2:
|
||
raise ValueError(f"波段 {band_idx} 数据不是2D的,形状: {band_data.shape}")
|
||
result = self.apply_morphology_operation(band_data, operation, se_shape, se_size)
|
||
results[band_idx] = result
|
||
|
||
return results
|
||
|
||
def save_as_envi(self, data: np.ndarray, output_path: str,
|
||
description: str = "形态学处理结果") -> None:
|
||
"""
|
||
保存为ENVI格式的dat和hdr文件
|
||
|
||
Parameters:
|
||
-----------
|
||
data : np.ndarray
|
||
要保存的数据(单波段2D数组)
|
||
output_path : str
|
||
输出文件路径(不含扩展名)
|
||
description : str
|
||
图像描述
|
||
"""
|
||
try:
|
||
# 确保输出目录存在
|
||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||
|
||
# 确保输出文件扩展名为.dat
|
||
if not output_path.lower().endswith('.dat'):
|
||
dat_path = output_path + '.dat'
|
||
else:
|
||
dat_path = output_path
|
||
|
||
# 保存为二进制dat文件
|
||
data.tofile(dat_path)
|
||
|
||
# 创建对应的hdr文件
|
||
hdr_path = dat_path.rsplit('.', 1)[0] + '.hdr'
|
||
self._create_morphology_hdr_file(hdr_path, data.shape, description, data.dtype)
|
||
|
||
print(f"ENVI格式已保存:")
|
||
print(f" 数据文件: {dat_path}")
|
||
print(f" 头文件: {hdr_path}")
|
||
print(f" 数据形状: {data.shape}, 数据类型: {data.dtype}")
|
||
|
||
except Exception as e:
|
||
raise IOError(f"保存ENVI文件失败: {e}")
|
||
|
||
def _create_morphology_hdr_file(self, hdr_path: str, data_shape: Tuple,
|
||
description: str, data_dtype: np.dtype = None) -> None:
|
||
"""
|
||
创建形态学处理结果的ENVI头文件
|
||
|
||
Parameters:
|
||
-----------
|
||
hdr_path : str
|
||
头文件路径
|
||
data_shape : Tuple
|
||
数据形状
|
||
description : str
|
||
图像描述
|
||
data_dtype : np.dtype, optional
|
||
数据类型,如果为None则使用默认值
|
||
"""
|
||
try:
|
||
# 从原始头文件获取基本信息(如果有的话)
|
||
if self.header is not None:
|
||
# 复制原始头文件的关键信息
|
||
samples = self.header.get('samples', data_shape[1] if len(data_shape) > 1 else data_shape[0])
|
||
lines = self.header.get('lines', data_shape[0] if len(data_shape) > 1 else 1)
|
||
bands = 1 # 形态学处理结果是单波段
|
||
interleave = 'bsq'
|
||
|
||
# 根据数据类型确定ENVI数据类型
|
||
if data_dtype is not None:
|
||
dtype = data_dtype
|
||
else:
|
||
dtype = np.float32 # 默认类型
|
||
|
||
if np.issubdtype(dtype, np.float32):
|
||
data_type = 4 # float32
|
||
elif np.issubdtype(dtype, np.float64):
|
||
data_type = 5 # float64
|
||
elif np.issubdtype(dtype, np.int32):
|
||
data_type = 3 # int32
|
||
elif np.issubdtype(dtype, np.int16):
|
||
data_type = 2 # int16
|
||
elif np.issubdtype(dtype, np.uint16):
|
||
data_type = 12 # uint16
|
||
else:
|
||
data_type = 4 # 默认float32
|
||
|
||
byte_order = self.header.get('byte order', 0)
|
||
wavelength_units = self.header.get('wavelength units', 'nm')
|
||
else:
|
||
# 默认值(适用于2D形态学处理结果)
|
||
if len(data_shape) == 2:
|
||
lines, samples = data_shape
|
||
else:
|
||
lines = data_shape[0]
|
||
samples = data_shape[1] if len(data_shape) > 1 else data_shape[0]
|
||
bands = 1
|
||
interleave = 'bsq'
|
||
data_type = 4 # float32(形态学处理通常涉及浮点数)
|
||
byte_order = 0
|
||
wavelength_units = 'Unknown'
|
||
|
||
# 写入hdr文件
|
||
with open(hdr_path, 'w') as f:
|
||
f.write("ENVI\n")
|
||
f.write("description = {\n")
|
||
f.write(f" {description}\n")
|
||
f.write(" 单波段形态学处理结果\n")
|
||
f.write("}\n")
|
||
f.write(f"samples = {samples}\n")
|
||
f.write(f"lines = {lines}\n")
|
||
f.write(f"bands = {bands}\n")
|
||
f.write(f"header offset = 0\n")
|
||
f.write(f"file type = ENVI Standard\n")
|
||
f.write(f"data type = {data_type}\n")
|
||
f.write(f"interleave = {interleave}\n")
|
||
f.write(f"byte order = {byte_order}\n")
|
||
|
||
# 如果有波长信息,添加虚拟波长
|
||
if self.header and 'wavelength' in self.header:
|
||
f.write("wavelength = {\n")
|
||
f.write(" 形态学处理结果\n")
|
||
f.write("}\n")
|
||
f.write(f"wavelength units = {wavelength_units}\n")
|
||
|
||
except Exception as e:
|
||
raise IOError(f"创建HDR文件失败: {e}")
|
||
|
||
def visualize_results(self, original_band: np.ndarray,
|
||
processed_band: np.ndarray,
|
||
operation_name: str,
|
||
save_path: Optional[str] = None) -> None:
|
||
"""
|
||
可视化原始波段和处理后的波段
|
||
|
||
Parameters:
|
||
-----------
|
||
original_band : np.ndarray
|
||
原始波段数据
|
||
processed_band : np.ndarray
|
||
处理后的波段数据
|
||
operation_name : str
|
||
操作名称
|
||
save_path : str, optional
|
||
保存图像路径
|
||
"""
|
||
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
||
|
||
# 原始图像
|
||
im1 = axes[0].imshow(original_band, cmap='gray')
|
||
axes[0].set_title('原始波段')
|
||
axes[0].axis('off')
|
||
plt.colorbar(im1, ax=axes[0], shrink=0.8)
|
||
|
||
# 处理后的图像
|
||
im2 = axes[1].imshow(processed_band, cmap='gray')
|
||
axes[1].set_title(f'{operation_name} 结果')
|
||
axes[1].axis('off')
|
||
plt.colorbar(im2, ax=axes[1], shrink=0.8)
|
||
|
||
# 差异图像
|
||
diff = processed_band.astype(np.float32) - original_band.astype(np.float32)
|
||
im3 = axes[2].imshow(diff, cmap='RdBu_r', vmin=-np.abs(diff).max(), vmax=np.abs(diff).max())
|
||
axes[2].set_title('差异 (处理后 - 原始)')
|
||
axes[2].axis('off')
|
||
plt.colorbar(im3, ax=axes[2], shrink=0.8)
|
||
|
||
plt.suptitle(f'形态学操作: {operation_name}', fontsize=16)
|
||
plt.tight_layout()
|
||
|
||
if save_path:
|
||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||
print(f"可视化结果已保存到: {save_path}")
|
||
|
||
plt.show()
|
||
|
||
|
||
def main():
|
||
"""主函数:高光谱图像形态学处理"""
|
||
import argparse
|
||
|
||
# 支持的形态学操作列表
|
||
morph_operations = [
|
||
'dilation', 'erosion', 'opening', 'closing',
|
||
'gradient', 'tophat', 'bottomhat', 'reconstruction'
|
||
]
|
||
|
||
# 支持的结构元素形状
|
||
se_shapes = ['disk', 'square', 'rectangle', 'diamond']
|
||
|
||
parser = argparse.ArgumentParser(description='高光谱图像形态学处理工具')
|
||
parser.add_argument('input_file', help='输入高光谱图像文件路径')
|
||
parser.add_argument('--format', '-f', default='auto',
|
||
choices=['auto', 'bil', 'bip', 'bsq', 'dat'],
|
||
help='图像格式 (默认: auto)')
|
||
parser.add_argument('--band', '-b', type=int, default=0,
|
||
help='要处理的波段索引 (默认: 0)')
|
||
parser.add_argument('--bands', '-B', type=str, default=None,
|
||
help='多个波段索引,用逗号分隔,如 "0,10,20"')
|
||
parser.add_argument('--operation', '-o', default='dilation',
|
||
choices=morph_operations,
|
||
help=f'形态学操作类型 (默认: dilation)')
|
||
parser.add_argument('--se_shape', '-s', default='disk',
|
||
choices=se_shapes,
|
||
help='结构元素形状 (默认: disk)')
|
||
parser.add_argument('--se_size', '-S', type=int, default=3,
|
||
help='结构元素大小 (默认: 3)')
|
||
parser.add_argument('--output_dir', '-d', default='output',
|
||
help='输出目录 (默认: output)')
|
||
parser.add_argument('--output_name', '-n', default='morphology_result',
|
||
help='输出文件名 (不含扩展名) (默认: morphology_result)')
|
||
parser.add_argument('--visualize', '-v', action='store_true',
|
||
help='是否生成可视化结果')
|
||
|
||
args = parser.parse_args()
|
||
|
||
try:
|
||
# 初始化处理器
|
||
processor = HyperspectralMorphologyProcessor()
|
||
|
||
# 确定格式
|
||
format_type = None if args.format == 'auto' else args.format
|
||
|
||
# 加载图像
|
||
print(f"加载图像: {args.input_file}")
|
||
data, header = processor.load_hyperspectral_image(args.input_file, format_type)
|
||
|
||
# 确定要处理的波段
|
||
if args.bands:
|
||
band_indices = [int(b.strip()) for b in args.bands.split(',')]
|
||
print(f"处理波段: {band_indices}")
|
||
single_band = False
|
||
else:
|
||
band_indices = [args.band]
|
||
print(f"处理波段: {args.band}")
|
||
single_band = True
|
||
|
||
# 应用形态学操作
|
||
if single_band:
|
||
# 单波段处理
|
||
band_data = processor.extract_band(args.band)
|
||
print(f"主函数中波段数据形状: {band_data.shape}")
|
||
result = processor.apply_morphology_operation(
|
||
band_data, args.operation, args.se_shape, args.se_size
|
||
)
|
||
|
||
# 保存结果
|
||
output_path = os.path.join(args.output_dir, f"{args.output_name}_band{args.band}")
|
||
processor.save_as_envi(result, output_path,
|
||
f"{args.operation.capitalize()} 处理结果 - 波段 {args.band}")
|
||
|
||
# 可视化
|
||
if args.visualize:
|
||
vis_path = os.path.join(args.output_dir, f"visualization_band{args.band}.png")
|
||
processor.visualize_results(band_data, result, args.operation, vis_path)
|
||
|
||
# 打印统计信息
|
||
print(f"\n=== 统计信息 ===")
|
||
print(f"原始波段 {args.band}: min={band_data.min():.4f}, max={band_data.max():.4f}, "
|
||
f"mean={band_data.mean():.4f}")
|
||
print(f"处理后: min={result.min():.4f}, max={result.max():.4f}, "
|
||
f"mean={result.mean():.4f}")
|
||
|
||
else:
|
||
# 多波段处理
|
||
results = processor.apply_to_multiple_bands(
|
||
band_indices, args.operation, args.se_shape, args.se_size
|
||
)
|
||
|
||
# 保存每个波段的结果
|
||
for band_idx, result in results.items():
|
||
output_path = os.path.join(args.output_dir, f"{args.output_name}_band{band_idx}")
|
||
processor.save_as_envi(result, output_path,
|
||
f"{args.operation.capitalize()} 处理结果 - 波段 {band_idx}")
|
||
|
||
# 如果只有少量波段,可以创建组合可视化
|
||
if len(band_indices) <= 4 and args.visualize:
|
||
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||
axes = axes.flatten()
|
||
|
||
for idx, band_idx in enumerate(band_indices[:4]):
|
||
original_band = processor.extract_band(band_idx)
|
||
processed_band = results[band_idx]
|
||
|
||
axes[idx].imshow(processed_band, cmap='gray')
|
||
axes[idx].set_title(f'波段 {band_idx} - {args.operation}')
|
||
axes[idx].axis('off')
|
||
|
||
plt.suptitle(f'多波段形态学处理: {args.operation}', fontsize=16)
|
||
plt.tight_layout()
|
||
|
||
vis_path = os.path.join(args.output_dir, "multi_band_visualization.png")
|
||
plt.savefig(vis_path, dpi=300, bbox_inches='tight')
|
||
print(f"多波段可视化结果已保存到: {vis_path}")
|
||
plt.show()
|
||
|
||
print(f"\n✓ 形态学处理完成!")
|
||
print(f"输出目录: {args.output_dir}")
|
||
|
||
return 0
|
||
|
||
except Exception as e:
|
||
print(f"✗ 处理失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return 1
|
||
|
||
|
||
# 使用示例函数
|
||
def run_example():
|
||
"""运行示例"""
|
||
# 示例1:对单个波段进行膨胀操作
|
||
processor = HyperspectralMorphologyProcessor()
|
||
|
||
# 加载图像(请替换为您的实际文件路径)
|
||
input_file = "path/to/your/hyperspectral_image.hdr" # 或 .bil, .bip, .bsq, .dat
|
||
data, header = processor.load_hyperspectral_image(input_file, format_type='bip')
|
||
|
||
# 提取波段(例如,提取第50个波段)
|
||
band_idx = 50
|
||
band_data = processor.extract_band(band_idx)
|
||
|
||
# 应用不同的形态学操作
|
||
operations = ['dilation', 'erosion', 'opening', 'closing',
|
||
'gradient', 'tophat', 'bottomhat', 'reconstruction']
|
||
|
||
results = {}
|
||
for operation in operations:
|
||
print(f"\n应用 {operation} 操作...")
|
||
try:
|
||
result = processor.apply_morphology_operation(
|
||
band_data, operation, se_shape='disk', se_size=3
|
||
)
|
||
results[operation] = result
|
||
|
||
# 保存结果
|
||
output_path = f"output/{operation}_band{band_idx}"
|
||
processor.save_as_envi(result, output_path,
|
||
f"{operation.capitalize()} 处理结果")
|
||
|
||
# 可视化
|
||
processor.visualize_results(band_data, result, operation,
|
||
save_path=f"output/visualization_{operation}.png")
|
||
|
||
except Exception as e:
|
||
print(f"操作 {operation} 失败: {e}")
|
||
|
||
return results
|
||
|
||
|
||
@dataclass
|
||
class MorphologicalFilterConfig:
|
||
"""形态学滤波配置类"""
|
||
input_path: str
|
||
operation: str
|
||
band_index: int
|
||
kernel_size: int
|
||
output_dir: str
|
||
se_shape: str = 'disk'
|
||
se_size: int = 3
|
||
format_type: str = 'auto'
|
||
|
||
|
||
if __name__ == '__main__':
|
||
exit(main())
|