增加模块;增加主调用命令
This commit is contained in:
751
fliter_method/Smooth_filter.py
Normal file
751
fliter_method/Smooth_filter.py
Normal file
@ -0,0 +1,751 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
from typing import Tuple, Optional, Dict, Union
|
||||
from scipy import ndimage
|
||||
from dataclasses import dataclass
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
try:
|
||||
import spectral as spy
|
||||
from spectral import envi
|
||||
HAS_SPECTRAL = True
|
||||
except ImportError:
|
||||
HAS_SPECTRAL = False
|
||||
print("警告: 未安装spectral库,只能处理RGB图像")
|
||||
|
||||
|
||||
class HyperspectralImageFilter:
|
||||
"""高光谱图像平滑滤波器
|
||||
|
||||
使用示例:
|
||||
# 创建滤波器实例
|
||||
filter_obj = HyperspectralImageFilter()
|
||||
|
||||
# 处理图像 - 均值滤波
|
||||
dat_path, hdr_path = filter_obj.process_image(
|
||||
input_path='input.hdr',
|
||||
output_path='output_mean',
|
||||
filter_type='mean',
|
||||
kernel_size=5
|
||||
)
|
||||
|
||||
# 处理图像 - 高斯滤波
|
||||
dat_path, hdr_path = filter_obj.process_image(
|
||||
input_path='input.hdr',
|
||||
output_path='output_gaussian',
|
||||
filter_type='gaussian',
|
||||
kernel_size=3,
|
||||
sigma=2.0
|
||||
)
|
||||
|
||||
# 处理图像 - 双边滤波
|
||||
dat_path, hdr_path = filter_obj.process_image(
|
||||
input_path='input.hdr',
|
||||
output_path='output_bilateral',
|
||||
filter_type='bilateral',
|
||||
kernel_size=7,
|
||||
sigma_color=50.0,
|
||||
sigma_space=50.0
|
||||
)
|
||||
|
||||
# 或者直接使用apply_filter方法
|
||||
data, header = filter_obj.load_image('input.hdr')
|
||||
filtered_data = filter_obj.apply_filter('mean', kernel_size=5)
|
||||
dat_path, hdr_path = filter_obj.save_envi('output', filtered_data, 'mean', header)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
self.header = None
|
||||
self.wavelengths = None
|
||||
self.is_hyperspectral = False
|
||||
self._selected_band_index = None # 记录选择的波段索引
|
||||
self._original_dtype = None # 记录原始数据类型
|
||||
self._original_range = None # 记录原始数据范围
|
||||
|
||||
def load_image(self, image_path: str, band_index: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
|
||||
"""
|
||||
加载图像文件
|
||||
|
||||
Parameters:
|
||||
image_path: 图像文件路径 (.hdr/.dat 或 .jpg/.png/.tif)
|
||||
band_index: 对于高光谱图像,指定要处理的波段索引 (None表示使用所有波段)
|
||||
|
||||
Returns:
|
||||
处理后的图像数据和元数据
|
||||
"""
|
||||
try:
|
||||
# 检查文件扩展名
|
||||
file_ext = os.path.splitext(image_path)[1].lower()
|
||||
|
||||
if file_ext in ['.hdr']:
|
||||
# ENVI高光谱图像
|
||||
if not HAS_SPECTRAL:
|
||||
raise ImportError("需要安装spectral库来处理ENVI格式文件")
|
||||
|
||||
return self._load_envi_image(image_path, band_index)
|
||||
|
||||
elif file_ext in ['.jpg', '.jpeg', '.png', '.tif', '.tiff']:
|
||||
# RGB图像
|
||||
return self._load_rgb_image(image_path)
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的文件格式: {file_ext}")
|
||||
|
||||
except Exception as e:
|
||||
raise IOError(f"加载图像失败: {e}")
|
||||
|
||||
def _load_envi_image(self, hdr_path: str, band_index: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
|
||||
"""加载ENVI格式高光谱图像"""
|
||||
try:
|
||||
# 读取ENVI文件
|
||||
img = envi.open(hdr_path)
|
||||
data = img.load()
|
||||
header = dict(img.metadata)
|
||||
|
||||
# 提取波长信息
|
||||
wavelengths = None
|
||||
if 'wavelength' in header:
|
||||
try:
|
||||
wavelengths = np.array([float(w) for w in header['wavelength']])
|
||||
except:
|
||||
wavelengths = None
|
||||
|
||||
# 如果指定了波段,选择特定波段
|
||||
if band_index is not None:
|
||||
if band_index < 0 or band_index >= data.shape[2]:
|
||||
raise ValueError(f"波段索引 {band_index} 超出范围 [0, {data.shape[2]-1}]")
|
||||
|
||||
print(f"原始数据形状: {data.shape}")
|
||||
# 记录选择的波段索引,用于头文件生成
|
||||
self._selected_band_index = band_index
|
||||
|
||||
data = data[:, :, band_index]
|
||||
print(f"选择波段 {band_index} 后的数据形状: {data.shape}")
|
||||
|
||||
# 如果结果是形状 (H, W, 1),压缩为 (H, W)
|
||||
if len(data.shape) == 3 and data.shape[2] == 1:
|
||||
data = np.squeeze(data, axis=2) # 压缩最后一个维度
|
||||
print(f"压缩单波段数据为2D: {data.shape}")
|
||||
|
||||
print(f"选择波段 {band_index} 进行处理")
|
||||
if wavelengths is not None:
|
||||
print(f"波长: {wavelengths[band_index]:.1f} nm")
|
||||
# 选择了特定波段,数据变为2D
|
||||
self.is_hyperspectral = False
|
||||
else:
|
||||
print("处理所有波段")
|
||||
self.is_hyperspectral = True
|
||||
self._selected_band_index = None
|
||||
|
||||
# 记录原始数据类型和范围
|
||||
self._original_dtype = data.dtype
|
||||
if data.size > 0: # 确保数据不为空
|
||||
self._original_range = (data.min(), data.max())
|
||||
print(f"原始数据范围: {self._original_range[0]} - {self._original_range[1]}")
|
||||
|
||||
self.data = data
|
||||
self.header = header
|
||||
self.wavelengths = wavelengths
|
||||
|
||||
print(f"成功加载ENVI图像: 形状={data.shape}, 数据类型={data.dtype}")
|
||||
if wavelengths is not None and band_index is not None:
|
||||
print(f"波长范围: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")
|
||||
|
||||
return data, header
|
||||
|
||||
except Exception as e:
|
||||
raise IOError(f"加载ENVI图像失败: {e}")
|
||||
|
||||
def _load_rgb_image(self, image_path: str) -> Tuple[np.ndarray, Dict]:
|
||||
"""加载RGB图像"""
|
||||
try:
|
||||
# 使用OpenCV读取图像
|
||||
data = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if data is None:
|
||||
raise ValueError(f"无法读取图像文件: {image_path}")
|
||||
|
||||
# 转换为RGB格式 (OpenCV默认是BGR)
|
||||
if len(data.shape) == 3 and data.shape[2] == 3:
|
||||
data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 记录原始数据类型和范围
|
||||
self._original_dtype = data.dtype
|
||||
if data.size > 0:
|
||||
self._original_range = (data.min(), data.max())
|
||||
print(f"原始数据范围: {self._original_range[0]} - {self._original_range[1]}")
|
||||
|
||||
# 创建基本的元数据
|
||||
header = {
|
||||
'samples': data.shape[1],
|
||||
'lines': data.shape[0],
|
||||
'bands': data.shape[2] if len(data.shape) == 3 else 1,
|
||||
'data_type': self._get_envi_data_type(data.dtype),
|
||||
'interleave': 'bsq',
|
||||
'byte_order': 0
|
||||
}
|
||||
|
||||
self.data = data
|
||||
self.header = header
|
||||
self.wavelengths = None
|
||||
self.is_hyperspectral = False
|
||||
|
||||
print(f"成功加载RGB图像: 形状={data.shape}, 数据类型={data.dtype}")
|
||||
|
||||
return data, header
|
||||
|
||||
except Exception as e:
|
||||
raise IOError(f"加载RGB图像失败: {e}")
|
||||
|
||||
def _get_envi_data_type(self, dtype: np.dtype) -> int:
|
||||
"""获取ENVI数据类型代码"""
|
||||
if dtype == np.uint8:
|
||||
return 1
|
||||
elif dtype == np.int16:
|
||||
return 2
|
||||
elif dtype == np.int32:
|
||||
return 3
|
||||
elif dtype == np.float32:
|
||||
return 4
|
||||
elif dtype == np.float64:
|
||||
return 5
|
||||
elif dtype == np.uint16:
|
||||
return 12
|
||||
elif dtype == np.uint32:
|
||||
return 13
|
||||
else:
|
||||
return 4 # 默认float32
|
||||
|
||||
def apply_filter(self, filter_type: str, **kwargs) -> np.ndarray:
|
||||
"""
|
||||
应用指定的滤波器
|
||||
|
||||
Parameters:
|
||||
filter_type: 滤波器类型 ('mean', 'median', 'gaussian', 'bilateral')
|
||||
**kwargs: 各滤波器的超参数
|
||||
- kernel_size: 内核大小 (奇数,默认3)
|
||||
- sigma: 高斯滤波的标准差 (默认1.0)
|
||||
- sigma_color: 双边滤波的颜色空间标准差 (默认75.0)
|
||||
- sigma_space: 双边滤波的空间标准差 (默认75.0)
|
||||
|
||||
Returns:
|
||||
滤波后的图像数据
|
||||
"""
|
||||
if self.data is None:
|
||||
raise ValueError("请先加载图像数据")
|
||||
|
||||
# 获取参数,设置默认值
|
||||
kernel_size = kwargs.get('kernel_size', 3)
|
||||
sigma = kwargs.get('sigma', 1.0)
|
||||
sigma_color = kwargs.get('sigma_color', 75.0)
|
||||
sigma_space = kwargs.get('sigma_space', 75.0)
|
||||
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1 # 确保为奇数
|
||||
|
||||
print(f"应用{filter_type}滤波器,参数: kernel_size={kernel_size}")
|
||||
|
||||
if filter_type.lower() == 'mean':
|
||||
return self._mean_filter(kernel_size)
|
||||
elif filter_type.lower() == 'median':
|
||||
return self._median_filter(kernel_size)
|
||||
elif filter_type.lower() == 'gaussian':
|
||||
print(f" sigma={sigma}")
|
||||
return self._gaussian_filter(kernel_size, sigma)
|
||||
elif filter_type.lower() == 'bilateral':
|
||||
print(f" sigma_color={sigma_color}, sigma_space={sigma_space}")
|
||||
return self._bilateral_filter(kernel_size, sigma_color, sigma_space)
|
||||
else:
|
||||
raise ValueError(f"不支持的滤波器类型: {filter_type}")
|
||||
|
||||
def _mean_filter(self, kernel_size: int) -> np.ndarray:
|
||||
"""均值滤波"""
|
||||
try:
|
||||
print(f"均值滤波开始 - 数据形状: {self.data.shape}, is_hyperspectral: {self.is_hyperspectral}")
|
||||
|
||||
# 创建均值内核
|
||||
kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) / (kernel_size * kernel_size)
|
||||
|
||||
# 检查数据维度
|
||||
if len(self.data.shape) == 3 and self.data.shape[2] > 1:
|
||||
# 多波段高光谱图像 - 对每个波段分别处理
|
||||
filtered_data = np.zeros_like(self.data, dtype=np.float32)
|
||||
for band in range(self.data.shape[2]):
|
||||
filtered_data[:, :, band] = ndimage.convolve(self.data[:, :, band].astype(np.float32), kernel)
|
||||
print(f"均值滤波完成 - 处理了 {self.data.shape[2]} 个波段")
|
||||
elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
|
||||
# 单波段3D数据 (H, W, 1) - 压缩并处理
|
||||
data_2d = np.squeeze(self.data, axis=2).astype(np.float32)
|
||||
filtered_data = ndimage.convolve(data_2d, kernel)
|
||||
# 保持输出为3D形状以保持一致性
|
||||
filtered_data = filtered_data[:, :, np.newaxis]
|
||||
print(f"均值滤波完成 - 单波段3D数据 (H,W,1) -> (H,W,1)")
|
||||
else:
|
||||
# 2D图像或单波段图像
|
||||
print(f"应用2D均值滤波到形状 {self.data.shape} 的数据")
|
||||
filtered_data = ndimage.convolve(self.data.astype(np.float32), kernel)
|
||||
print(f"2D均值滤波完成 - 输出形状: {filtered_data.shape}")
|
||||
|
||||
return filtered_data
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"均值滤波失败: {e}")
|
||||
|
||||
def _median_filter(self, kernel_size: int) -> np.ndarray:
|
||||
"""中值滤波"""
|
||||
try:
|
||||
# 检查数据维度
|
||||
if len(self.data.shape) == 3 and self.data.shape[2] > 1:
|
||||
# 多波段高光谱图像 - 对每个波段分别处理
|
||||
filtered_data = np.zeros_like(self.data, dtype=self.data.dtype)
|
||||
for band in range(self.data.shape[2]):
|
||||
filtered_data[:, :, band] = ndimage.median_filter(self.data[:, :, band], size=kernel_size)
|
||||
print(f"中值滤波完成 - 处理了 {self.data.shape[2]} 个波段")
|
||||
elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
|
||||
# 单波段3D数据 (H, W, 1) - 压缩并处理
|
||||
data_2d = np.squeeze(self.data, axis=2)
|
||||
filtered_data = ndimage.median_filter(data_2d, size=kernel_size)
|
||||
# 保持输出为3D形状以保持一致性
|
||||
filtered_data = filtered_data[:, :, np.newaxis]
|
||||
print(f"中值滤波完成 - 单波段3D数据 (H,W,1) -> (H,W,1)")
|
||||
else:
|
||||
# 2D图像或单波段图像
|
||||
filtered_data = ndimage.median_filter(self.data, size=kernel_size)
|
||||
|
||||
return filtered_data
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"中值滤波失败: {e}")
|
||||
|
||||
def _gaussian_filter(self, kernel_size: int, sigma: float) -> np.ndarray:
|
||||
"""高斯滤波"""
|
||||
try:
|
||||
# 检查数据维度
|
||||
if len(self.data.shape) == 3 and self.data.shape[2] > 1:
|
||||
# 多波段高光谱图像 - 对每个波段分别处理
|
||||
filtered_data = np.zeros_like(self.data, dtype=np.float32)
|
||||
for band in range(self.data.shape[2]):
|
||||
filtered_data[:, :, band] = ndimage.gaussian_filter(
|
||||
self.data[:, :, band].astype(np.float32), sigma=sigma
|
||||
)
|
||||
print(f"高斯滤波完成 (sigma={sigma}) - 处理了 {self.data.shape[2]} 个波段")
|
||||
elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
|
||||
# 单波段3D数据 (H, W, 1) - 压缩并处理
|
||||
data_2d = np.squeeze(self.data, axis=2).astype(np.float32)
|
||||
filtered_data = ndimage.gaussian_filter(data_2d, sigma=sigma)
|
||||
# 保持输出为3D形状以保持一致性
|
||||
filtered_data = filtered_data[:, :, np.newaxis]
|
||||
print(f"高斯滤波完成 (sigma={sigma}) - 单波段3D数据 (H,W,1) -> (H,W,1)")
|
||||
else:
|
||||
# 2D图像或单波段图像
|
||||
filtered_data = ndimage.gaussian_filter(self.data.astype(np.float32), sigma=sigma)
|
||||
print(f"高斯滤波完成 (sigma={sigma})")
|
||||
|
||||
return filtered_data
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"高斯滤波失败: {e}")
|
||||
|
||||
def _bilateral_filter(self, kernel_size: int, sigma_color: float, sigma_space: float) -> np.ndarray:
|
||||
"""双边滤波"""
|
||||
try:
|
||||
# 双边滤波主要用于RGB图像或单波段图像
|
||||
# 对于高光谱数据,我们可以对每个波段分别应用双边滤波
|
||||
|
||||
if len(self.data.shape) == 3 and self.data.shape[2] > 1 and self.is_hyperspectral:
|
||||
# 多波段高光谱图像 - 对每个波段分别处理
|
||||
filtered_data = np.zeros_like(self.data, dtype=np.float32)
|
||||
for band in range(self.data.shape[2]):
|
||||
band_data = self.data[:, :, band].astype(np.float32)
|
||||
# 归一化到0-255范围进行双边滤波
|
||||
if band_data.max() > band_data.min():
|
||||
band_norm = ((band_data - band_data.min()) / (band_data.max() - band_data.min()) * 255).astype(np.uint8)
|
||||
filtered_norm = cv2.bilateralFilter(band_norm, kernel_size, sigma_color, sigma_space)
|
||||
# 恢复原始范围
|
||||
filtered_data[:, :, band] = filtered_norm.astype(np.float32) / 255 * (band_data.max() - band_data.min()) + band_data.min()
|
||||
else:
|
||||
filtered_data[:, :, band] = band_data
|
||||
|
||||
print(f"双边滤波完成 - 处理了 {self.data.shape[2]} 个波段")
|
||||
|
||||
elif len(self.data.shape) == 3 and self.data.shape[2] == 3:
|
||||
# RGB图像 - 直接使用OpenCV的双边滤波
|
||||
filtered_data = cv2.bilateralFilter(self.data, kernel_size, sigma_color, sigma_space)
|
||||
print("RGB双边滤波完成")
|
||||
|
||||
elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
|
||||
# 单波段3D数据 (H, W, 1) - 压缩并处理
|
||||
data_2d = np.squeeze(self.data, axis=2).astype(np.float32)
|
||||
if data_2d.max() > data_2d.min():
|
||||
data_norm = ((data_2d - data_2d.min()) / (data_2d.max() - data_2d.min()) * 255).astype(np.uint8)
|
||||
filtered_norm = cv2.bilateralFilter(data_norm, kernel_size, sigma_color, sigma_space)
|
||||
filtered_data = filtered_norm.astype(np.float32) / 255 * (data_2d.max() - data_2d.min()) + data_2d.min()
|
||||
# 保持输出为3D形状以保持一致性
|
||||
filtered_data = filtered_data[:, :, np.newaxis]
|
||||
else:
|
||||
filtered_data = self.data.copy()
|
||||
print("单波段双边滤波完成 (3D -> 3D)")
|
||||
|
||||
else:
|
||||
# 2D单波段图像 - 转换为uint8进行处理
|
||||
data_float = self.data.astype(np.float32)
|
||||
if data_float.max() > data_float.min():
|
||||
data_norm = ((data_float - data_float.min()) / (data_float.max() - data_float.min()) * 255).astype(np.uint8)
|
||||
filtered_norm = cv2.bilateralFilter(data_norm, kernel_size, sigma_color, sigma_space)
|
||||
filtered_data = filtered_norm.astype(np.float32) / 255 * (data_float.max() - data_float.min()) + data_float.min()
|
||||
else:
|
||||
filtered_data = data_float
|
||||
print("单波段双边滤波完成")
|
||||
|
||||
return filtered_data
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"双边滤波失败: {e}")
|
||||
|
||||
def save_envi(self, output_path: str, filtered_data: np.ndarray, filter_type: str,
|
||||
original_header: Dict = None) -> Tuple[str, str]:
|
||||
"""
|
||||
保存滤波结果为ENVI格式
|
||||
|
||||
Parameters:
|
||||
output_path: 输出文件路径(不含扩展名)
|
||||
filtered_data: 滤波后的数据
|
||||
filter_type: 滤波器类型,用于文件名
|
||||
original_header: 原始图像的头文件信息
|
||||
|
||||
Returns:
|
||||
数据文件路径和头文件路径
|
||||
"""
|
||||
try:
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# 生成文件名
|
||||
base_name = os.path.basename(output_path)
|
||||
dat_path = f"{output_path}_{filter_type}.dat"
|
||||
hdr_path = f"{output_path}_{filter_type}.hdr"
|
||||
|
||||
# 保存数据文件 - 保持与原始数据一致的类型和范围
|
||||
print(f"滤波前数据范围: {filtered_data.min():.2f} - {filtered_data.max():.2f}")
|
||||
print(f"原始数据类型: {self._original_dtype}, 范围: {self._original_range}")
|
||||
|
||||
# 根据原始数据类型决定保存格式
|
||||
if self._original_dtype is not None:
|
||||
# 如果原始数据是整数类型,尝试保持整数格式
|
||||
if np.issubdtype(self._original_dtype, np.integer):
|
||||
# 对于整数类型的原始数据,检查滤波后的数据是否仍在合理范围内
|
||||
if self._original_range is not None:
|
||||
orig_min, orig_max = self._original_range
|
||||
filtered_min, filtered_max = filtered_data.min(), filtered_data.max()
|
||||
|
||||
# 如果滤波后的数据范围仍在原始范围附近,使用原始数据类型
|
||||
if (filtered_min >= orig_min - 100) and (filtered_max <= orig_max + 100):
|
||||
print(f"保持原始整数数据类型: {self._original_dtype}")
|
||||
filtered_data = np.clip(filtered_data, orig_min, orig_max).astype(self._original_dtype)
|
||||
else:
|
||||
print(f"数据范围变化较大,使用float32保存")
|
||||
filtered_data = filtered_data.astype(np.float32)
|
||||
else:
|
||||
filtered_data = filtered_data.astype(np.float32)
|
||||
else:
|
||||
# 原始数据就是浮点数,保持float32
|
||||
filtered_data = filtered_data.astype(np.float32)
|
||||
else:
|
||||
# 默认使用float32
|
||||
filtered_data = filtered_data.astype(np.float32)
|
||||
|
||||
print(f"保存数据类型: {filtered_data.dtype}, 范围: {filtered_data.min():.2f} - {filtered_data.max():.2f}")
|
||||
|
||||
# 确保数据格式适合ENVI标准
|
||||
if len(filtered_data.shape) == 2:
|
||||
# 2D数据 (H, W) -> 3D数据 (H, W, 1)
|
||||
filtered_data = filtered_data[:, :, np.newaxis]
|
||||
print(f"转换2D数据为3D格式以符合ENVI标准: {filtered_data.shape}")
|
||||
|
||||
# ENVI使用BSQ格式存储
|
||||
filtered_data.tofile(dat_path)
|
||||
print(f"数据文件已保存: {dat_path}")
|
||||
|
||||
# 生成头文件
|
||||
self._save_envi_header(hdr_path, filtered_data, filter_type, original_header)
|
||||
print(f"头文件已保存: {hdr_path}")
|
||||
|
||||
return dat_path, hdr_path
|
||||
|
||||
except Exception as e:
|
||||
raise IOError(f"保存ENVI文件失败: {e}")
|
||||
|
||||
def _save_envi_header(self, hdr_path: str, data: np.ndarray, filter_type: str,
|
||||
original_header: Dict = None):
|
||||
"""保存ENVI标准格式头文件"""
|
||||
try:
|
||||
from datetime import datetime
|
||||
|
||||
with open(hdr_path, 'w', encoding='utf-8') as f:
|
||||
f.write("ENVI\n")
|
||||
|
||||
# 描述信息 - 包含处理时间和描述
|
||||
current_time = datetime.now().strftime("%a %b %d %H:%M:%S %Y")
|
||||
f.write("description = {\n")
|
||||
f.write(f" Convolution Result [{current_time}]\n")
|
||||
f.write("}\n")
|
||||
|
||||
# 图像尺寸
|
||||
if len(data.shape) == 3:
|
||||
samples, lines, bands = data.shape[1], data.shape[0], data.shape[2]
|
||||
else:
|
||||
samples, lines, bands = data.shape[1], data.shape[0], 1
|
||||
|
||||
f.write(f"samples = {samples}\n")
|
||||
f.write(f"lines = {lines}\n")
|
||||
f.write(f"bands = {bands}\n")
|
||||
f.write("header offset = 0\n")
|
||||
f.write("file type = ENVI Standard\n")
|
||||
|
||||
# 数据类型 - 使用实际保存的数据类型
|
||||
if hasattr(self, '_original_dtype') and self._original_dtype is not None:
|
||||
# 优先使用原始数据类型(如果范围合适)
|
||||
data_type = self._get_envi_data_type(self._original_dtype)
|
||||
else:
|
||||
data_type = self._get_envi_data_type(data.dtype)
|
||||
f.write(f"data type = {data_type}\n")
|
||||
|
||||
# 交织格式
|
||||
f.write("interleave = bsq\n")
|
||||
|
||||
# 传感器类型
|
||||
if original_header and 'sensor type' in original_header:
|
||||
f.write(f"sensor type = {original_header['sensor type']}\n")
|
||||
else:
|
||||
f.write("sensor type = Unknown\n")
|
||||
|
||||
# 字节顺序
|
||||
f.write("byte order = 0\n")
|
||||
|
||||
# 波长信息
|
||||
if self.wavelengths is not None and len(self.wavelengths) > 0:
|
||||
f.write("wavelength units = Nanometers\n")
|
||||
if bands == 1 and hasattr(self, '_selected_band_index') and self._selected_band_index is not None:
|
||||
# 单波段情况 - 使用选定波段的波长
|
||||
band_idx = min(self._selected_band_index, len(self.wavelengths) - 1)
|
||||
f.write(f"wavelength = {{\n")
|
||||
f.write(f" {self.wavelengths[band_idx]:.6f}\n")
|
||||
f.write("}\n")
|
||||
else:
|
||||
# 多波段情况 - 写入所有波长
|
||||
f.write("wavelength = {\n")
|
||||
wavelength_str = ", ".join([f"{w:.6f}" for w in self.wavelengths[:bands]])
|
||||
f.write(f" {wavelength_str}\n")
|
||||
f.write("}\n")
|
||||
|
||||
# 反射率比例因子
|
||||
f.write("reflectance scale factor = 10000.000000\n")
|
||||
|
||||
# 从原始头文件复制其他重要参数
|
||||
if original_header:
|
||||
# 增益信息
|
||||
if 'gain' in original_header:
|
||||
f.write(f"gain = {original_header['gain']}\n")
|
||||
|
||||
# 分辨率信息
|
||||
binning_keys = ['sample binning', 'spectral binning', 'line binning']
|
||||
for key in binning_keys:
|
||||
if key in original_header:
|
||||
f.write(f"{key} = {original_header[key]}\n")
|
||||
|
||||
# 快门和帧率信息
|
||||
if 'shutter' in original_header:
|
||||
f.write(f"shutter = {original_header['shutter']}\n")
|
||||
if 'framerate' in original_header:
|
||||
f.write(f"framerate = {original_header['framerate']}\n")
|
||||
|
||||
# 设备序列号
|
||||
if 'imager serial number' in original_header:
|
||||
f.write(f"imager serial number = {original_header['imager serial number']}\n")
|
||||
|
||||
# 旋转矩阵
|
||||
if 'rotation' in original_header:
|
||||
rotation = original_header['rotation']
|
||||
f.write(f"rotation = {rotation}\n")
|
||||
|
||||
# 标签
|
||||
if 'label' in original_header:
|
||||
f.write(f"label = {original_header['label']}\n")
|
||||
|
||||
# 处理历史
|
||||
if 'history' in original_header:
|
||||
f.write(f"history = {original_header['history']}\n")
|
||||
|
||||
except Exception as e:
|
||||
raise IOError(f"保存头文件失败: {e}")
|
||||
|
||||
def process_image(self, input_path: str, output_path: str, filter_type: str,
|
||||
band_index: Optional[int] = None, **kwargs) -> Tuple[str, str]:
|
||||
"""
|
||||
完整的图像处理流程
|
||||
|
||||
Parameters:
|
||||
input_path: 输入图像路径
|
||||
output_path: 输出路径(不含扩展名)
|
||||
filter_type: 滤波器类型
|
||||
band_index: 波段索引(可选)
|
||||
**kwargs: 各滤波器的超参数
|
||||
- kernel_size: 内核大小 (奇数,默认3)
|
||||
- sigma: 高斯滤波的标准差 (默认1.0)
|
||||
- sigma_color: 双边滤波的颜色空间标准差 (默认75.0)
|
||||
- sigma_space: 双边滤波的空间标准差 (默认75.0)
|
||||
|
||||
Returns:
|
||||
数据文件路径和头文件路径
|
||||
"""
|
||||
print("=" * 50)
|
||||
print("高光谱图像平滑滤波器")
|
||||
print("=" * 50)
|
||||
|
||||
# 加载图像
|
||||
print(f"加载图像: {input_path}")
|
||||
data, header = self.load_image(input_path, band_index)
|
||||
|
||||
# 应用滤波器
|
||||
print(f"应用{filter_type}滤波...")
|
||||
filtered_data = self.apply_filter(filter_type, **kwargs)
|
||||
|
||||
# 保存结果
|
||||
print(f"保存结果到: {output_path}")
|
||||
dat_path, hdr_path = self.save_envi(output_path, filtered_data, filter_type, header)
|
||||
|
||||
print("=" * 50)
|
||||
print("处理完成!")
|
||||
print(f"数据文件: {dat_path}")
|
||||
print(f"头文件: {hdr_path}")
|
||||
print("=" * 50)
|
||||
|
||||
return dat_path, hdr_path
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数:命令行接口"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='高光谱图像平滑滤波工具',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
使用示例:
|
||||
1. 中值滤波:
|
||||
python Smooth_filter.py input.hdr -f median -b 50 -k 5 -o output.dat
|
||||
|
||||
2. 高斯滤波:
|
||||
python Smooth_filter.py input.hdr -f gaussian -b 30 -k 3 -s 1.5 -o output.dat
|
||||
|
||||
3. 双边滤波:
|
||||
python Smooth_filter.py input.hdr -f bilateral -b 25 -k 5 -sc 75 -ss 75 -o output.dat
|
||||
|
||||
4. 均值滤波:
|
||||
python Smooth_filter.py input.hdr -f mean -b 40 -k 3 -o output.dat
|
||||
|
||||
支持的滤波类型:
|
||||
mean: 均值滤波
|
||||
median: 中值滤波
|
||||
gaussian: 高斯滤波
|
||||
bilateral: 双边滤波
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument('input_path', help='输入高光谱图像文件路径')
|
||||
parser.add_argument('-f', '--filter-type', default='gaussian',
|
||||
choices=['mean', 'median', 'gaussian', 'bilateral'],
|
||||
help='滤波类型 (默认: gaussian)')
|
||||
parser.add_argument('-b', '--band-index', type=int, default=0,
|
||||
help='处理的波段索引 (默认: 0)')
|
||||
parser.add_argument('-o', '--output', required=True,
|
||||
help='输出文件路径')
|
||||
parser.add_argument('-k', '--kernel-size', type=int, default=3,
|
||||
help='卷积核大小,必须是奇数 (默认: 3)')
|
||||
|
||||
# 高斯滤波参数
|
||||
parser.add_argument('-s', '--sigma', type=float, default=1.0,
|
||||
help='高斯标准差,用于gaussian滤波 (默认: 1.0)')
|
||||
|
||||
# 双边滤波参数
|
||||
parser.add_argument('-sc', '--sigma-color', type=float, default=75.0,
|
||||
help='颜色空间标准差,用于bilateral滤波 (默认: 75.0)')
|
||||
parser.add_argument('-ss', '--sigma-space', type=float, default=75.0,
|
||||
help='坐标空间标准差,用于bilateral滤波 (默认: 75.0)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
print("=" * 60)
|
||||
print("高光谱图像平滑滤波工具")
|
||||
print("=" * 60)
|
||||
print(f"输入文件: {args.input_path}")
|
||||
print(f"滤波类型: {args.filter_type}")
|
||||
print(f"波段索引: {args.band_index}")
|
||||
print(f"卷积核大小: {args.kernel_size}")
|
||||
print(f"输出文件: {args.output}")
|
||||
|
||||
# 根据滤波类型显示额外参数
|
||||
if args.filter_type == 'gaussian':
|
||||
print(f"高斯标准差: {args.sigma}")
|
||||
elif args.filter_type == 'bilateral':
|
||||
print(f"颜色标准差: {args.sigma_color}")
|
||||
print(f"空间标准差: {args.sigma_space}")
|
||||
print()
|
||||
|
||||
# 构建参数字典
|
||||
kwargs = {'kernel_size': args.kernel_size}
|
||||
|
||||
if args.filter_type == 'gaussian':
|
||||
kwargs['sigma'] = args.sigma
|
||||
elif args.filter_type == 'bilateral':
|
||||
kwargs['sigma_color'] = args.sigma_color
|
||||
kwargs['sigma_space'] = args.sigma_space
|
||||
|
||||
# 创建滤波器实例并处理图像
|
||||
filter_obj = HyperspectralImageFilter()
|
||||
|
||||
filter_obj.process_image(
|
||||
input_path=args.input_path,
|
||||
output_path=args.output,
|
||||
filter_type=args.filter_type,
|
||||
band_index=args.band_index,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("滤波处理完成!")
|
||||
print("=" * 60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 处理失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class SmoothFilterConfig:
|
||||
"""平滑滤波配置类"""
|
||||
input_path: str
|
||||
filter_type: str
|
||||
band_index: int
|
||||
kernel_size: int
|
||||
output_dir: str
|
||||
sigma: float = 1.0
|
||||
sigma_color: float = 75.0
|
||||
sigma_space: float = 75.0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # 运行原始调试
|
||||
# test_data_preservation() # 测试数据格式保持
|
||||
|
||||
745
fliter_method/morphological_fliter.py
Normal file
745
fliter_method/morphological_fliter.py
Normal file
@ -0,0 +1,745 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import warnings
|
||||
from typing import Tuple, List, Dict, Optional, Union
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
import spectral as spy
|
||||
from spectral import envi
|
||||
import matplotlib.pyplot as plt
|
||||
from skimage import morphology
|
||||
from scipy import ndimage
|
||||
import struct
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
class HyperspectralMorphologyProcessor:
|
||||
"""高光谱图像形态学处理类"""
|
||||
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
self.header = None
|
||||
self.wavelengths = None
|
||||
self.shape = None
|
||||
self.dtype = None
|
||||
|
||||
def load_hyperspectral_image(self, file_path: str, format_type: str = None) -> Tuple[np.ndarray, Dict]:
|
||||
"""
|
||||
加载高光谱图像(支持bil、bip、bsq、dat格式)
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
file_path : str
|
||||
图像文件路径(可以是.hdr、.bil、.bip、.bsq、.dat)
|
||||
format_type : str, optional
|
||||
格式类型('bil'、'bip'、'bsq'、'dat'),如果为None则自动检测
|
||||
|
||||
Returns:
|
||||
--------
|
||||
data : np.ndarray
|
||||
高光谱数据立方体 (rows, cols, bands)
|
||||
header : Dict
|
||||
头文件信息
|
||||
"""
|
||||
try:
|
||||
# 自动检测格式
|
||||
if format_type is None:
|
||||
format_type = self._detect_format(file_path)
|
||||
|
||||
# 如果是.hdr文件,直接使用spectral加载
|
||||
if file_path.lower().endswith('.hdr'):
|
||||
img = envi.open(file_path)
|
||||
data = img.load()
|
||||
header = dict(img.metadata)
|
||||
|
||||
# 如果是其他格式,尝试查找对应的.hdr文件
|
||||
else:
|
||||
# 查找可能的hdr文件
|
||||
hdr_candidates = [
|
||||
file_path.rsplit('.', 1)[0] + '.hdr',
|
||||
file_path + '.hdr',
|
||||
file_path[:-4] + '.hdr' if len(file_path) > 4 else None
|
||||
]
|
||||
|
||||
hdr_file = None
|
||||
for candidate in hdr_candidates:
|
||||
if candidate and os.path.exists(candidate):
|
||||
hdr_file = candidate
|
||||
break
|
||||
|
||||
if hdr_file is None:
|
||||
raise FileNotFoundError(f"未找到头文件(.hdr)用于 {file_path}")
|
||||
|
||||
# 加载图像
|
||||
img = envi.open(hdr_file)
|
||||
data = img.load()
|
||||
header = dict(img.metadata)
|
||||
|
||||
# 根据格式调整数据布局
|
||||
if format_type.lower() == 'bil':
|
||||
# BIL: 波段交错按行
|
||||
data = self._convert_to_bil(data)
|
||||
elif format_type.lower() == 'bip':
|
||||
# BIP: 波段交错按像素 (spectral默认)
|
||||
pass # spectral默认就是BIP
|
||||
elif format_type.lower() == 'bsq':
|
||||
# BSQ: 波段顺序
|
||||
data = self._convert_to_bsq(data)
|
||||
|
||||
self.data = data
|
||||
self.header = header
|
||||
self.shape = data.shape
|
||||
self.dtype = data.dtype
|
||||
|
||||
# 提取波长信息
|
||||
if 'wavelength' in header:
|
||||
self.wavelengths = np.array([float(w) for w in header['wavelength']])
|
||||
else:
|
||||
self.wavelengths = None
|
||||
|
||||
print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}, 格式={format_type}")
|
||||
if self.wavelengths is not None:
|
||||
print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
|
||||
|
||||
return data, header
|
||||
|
||||
except Exception as e:
|
||||
raise IOError(f"加载图像失败: {e}")
|
||||
|
||||
def _detect_format(self, file_path: str) -> str:
|
||||
"""自动检测图像格式"""
|
||||
ext = Path(file_path).suffix.lower()
|
||||
|
||||
if ext == '.hdr':
|
||||
# 读取hdr文件内容检测数据格式
|
||||
with open(file_path, 'r') as f:
|
||||
content = f.read()
|
||||
if 'interleave = bil' in content.lower():
|
||||
return 'bil'
|
||||
elif 'interleave = bsq' in content.lower():
|
||||
return 'bsq'
|
||||
elif 'interleave = bip' in content.lower():
|
||||
return 'bip'
|
||||
else:
|
||||
# 默认假设为BIP
|
||||
return 'bip'
|
||||
elif ext in ['.bil', '.bip', '.bsq', '.dat']:
|
||||
# 从扩展名判断
|
||||
return ext[1:] # 去掉点号
|
||||
else:
|
||||
# 默认假设为BIP
|
||||
return 'bip'
|
||||
|
||||
def _convert_to_bil(self, data: np.ndarray) -> np.ndarray:
|
||||
"""将数据转换为BIL格式"""
|
||||
rows, cols, bands = data.shape
|
||||
bil_data = np.zeros((rows, bands, cols), dtype=data.dtype)
|
||||
for i in range(rows):
|
||||
for b in range(bands):
|
||||
bil_data[i, b, :] = data[i, :, b]
|
||||
return bil_data
|
||||
|
||||
def _convert_to_bsq(self, data: np.ndarray) -> np.ndarray:
|
||||
"""将数据转换为BSQ格式"""
|
||||
rows, cols, bands = data.shape
|
||||
bsq_data = np.zeros((bands, rows, cols), dtype=data.dtype)
|
||||
for b in range(bands):
|
||||
bsq_data[b, :, :] = data[:, :, b]
|
||||
return bsq_data
|
||||
|
||||
def extract_band(self, band_index: int) -> np.ndarray:
|
||||
"""
|
||||
提取指定波段
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
band_index : int
|
||||
波段索引(从0开始)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
band_data : np.ndarray
|
||||
单波段图像 (rows, cols)
|
||||
"""
|
||||
if self.data is None:
|
||||
raise ValueError("请先加载图像数据")
|
||||
|
||||
if band_index < 0 or band_index >= self.shape[2]:
|
||||
raise ValueError(f"波段索引 {band_index} 超出范围 [0, {self.shape[2]-1}]")
|
||||
|
||||
# 根据数据布局提取波段
|
||||
if len(self.data.shape) == 3:
|
||||
# 标准形状 (rows, cols, bands)
|
||||
# 使用squeeze()来确保移除单维度
|
||||
band_data = np.squeeze(self.data[:, :, band_index])
|
||||
print(f"提取前形状: {self.data.shape}, band_index: {band_index}")
|
||||
print(f"提取后初始形状: {band_data.shape}")
|
||||
|
||||
# 确保最终是2D
|
||||
if len(band_data.shape) != 2:
|
||||
raise ValueError(f"无法提取2D波段数据,最终形状: {band_data.shape}")
|
||||
|
||||
elif len(self.data.shape) == 2:
|
||||
# 已经是单波段
|
||||
band_data = self.data
|
||||
else:
|
||||
raise ValueError(f"不支持的数据形状: {self.data.shape}")
|
||||
|
||||
print(f"提取波段 {band_index}: 最终形状={band_data.shape}")
|
||||
|
||||
return band_data
|
||||
|
||||
def apply_morphology_operation(self, band_data: np.ndarray,
|
||||
operation: str = 'dilation',
|
||||
se_shape: str = 'disk',
|
||||
se_size: int = 3,
|
||||
**kwargs) -> np.ndarray:
|
||||
"""
|
||||
应用形态学操作
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
band_data : np.ndarray
|
||||
单波段图像数据 (2D)
|
||||
operation : str
|
||||
形态学操作类型:
|
||||
- 'dilation': 膨胀
|
||||
- 'erosion': 腐蚀
|
||||
- 'opening': 开运算
|
||||
- 'closing': 闭运算
|
||||
- 'gradient': 形态学梯度 (膨胀 - 腐蚀)
|
||||
- 'tophat': 顶帽变换 (原图 - 开运算)
|
||||
- 'bottomhat': 底帽变换 (闭运算 - 原图)
|
||||
- 'reconstruction': 形态学重建
|
||||
se_shape : str
|
||||
结构元素形状: 'disk', 'square', 'rectangle', 'diamond'
|
||||
se_size : int
|
||||
结构元素大小
|
||||
|
||||
Returns:
|
||||
--------
|
||||
result : np.ndarray
|
||||
处理后的图像
|
||||
"""
|
||||
# 确保输入数据是2D的
|
||||
if len(band_data.shape) != 2:
|
||||
raise ValueError(f"输入数据必须是2D的,当前形状: {band_data.shape}")
|
||||
|
||||
# 创建结构元素
|
||||
selem = self._create_structuring_element(se_shape, se_size)
|
||||
|
||||
# 转换为浮点数以保证精度
|
||||
data_float = band_data.astype(np.float32)
|
||||
|
||||
# 应用形态学操作
|
||||
if operation == 'dilation':
|
||||
result = ndimage.grey_dilation(data_float, footprint=selem)
|
||||
|
||||
elif operation == 'erosion':
|
||||
result = ndimage.grey_erosion(data_float, footprint=selem)
|
||||
|
||||
elif operation == 'opening':
|
||||
# 开运算: 先腐蚀后膨胀
|
||||
eroded = ndimage.grey_erosion(data_float, footprint=selem)
|
||||
result = ndimage.grey_dilation(eroded, footprint=selem)
|
||||
|
||||
elif operation == 'closing':
|
||||
# 闭运算: 先膨胀后腐蚀
|
||||
dilated = ndimage.grey_dilation(data_float, footprint=selem)
|
||||
result = ndimage.grey_erosion(dilated, footprint=selem)
|
||||
|
||||
elif operation == 'gradient':
|
||||
# 形态学梯度: 膨胀 - 腐蚀
|
||||
dilated = ndimage.grey_dilation(data_float, footprint=selem)
|
||||
eroded = ndimage.grey_erosion(data_float, footprint=selem)
|
||||
result = dilated - eroded
|
||||
|
||||
elif operation == 'tophat':
|
||||
# 顶帽变换: 原图 - 开运算
|
||||
eroded = ndimage.grey_erosion(data_float, footprint=selem)
|
||||
opened = ndimage.grey_dilation(eroded, footprint=selem)
|
||||
result = data_float - opened
|
||||
|
||||
elif operation == 'bottomhat':
|
||||
# 底帽变换: 闭运算 - 原图
|
||||
dilated = ndimage.grey_dilation(data_float, footprint=selem)
|
||||
closed = ndimage.grey_erosion(dilated, footprint=selem)
|
||||
result = closed - data_float
|
||||
|
||||
elif operation == 'reconstruction':
|
||||
# 形态学重建(基于膨胀的重建)
|
||||
# 使用标记图像(这里使用腐蚀后的图像作为标记)
|
||||
marker = ndimage.grey_erosion(data_float, footprint=selem)
|
||||
result = self._morphological_reconstruction(marker, data_float, selem)
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的形态学操作: {operation}")
|
||||
|
||||
# 转换为原始数据类型
|
||||
if operation in ['gradient', 'tophat', 'bottomhat']:
|
||||
# 这些操作可能产生负值,保持浮点类型
|
||||
return result
|
||||
else:
|
||||
return self._convert_to_original_type(result, band_data.dtype)
|
||||
|
||||
def _create_structuring_element(self, shape: str, size: int) -> np.ndarray:
|
||||
"""创建结构元素"""
|
||||
if shape == 'disk':
|
||||
# 创建圆形结构元素
|
||||
radius = size // 2
|
||||
y, x = np.ogrid[-radius:radius+1, -radius:radius+1]
|
||||
mask = x*x + y*y <= radius*radius
|
||||
selem = np.zeros((2*radius+1, 2*radius+1), dtype=bool)
|
||||
selem[mask] = True
|
||||
elif shape == 'square':
|
||||
# 创建方形结构元素
|
||||
selem = np.ones((size, size), dtype=bool)
|
||||
elif shape == 'rectangle':
|
||||
# 创建矩形结构元素
|
||||
selem = np.ones((size, size*2), dtype=bool)
|
||||
elif shape == 'diamond':
|
||||
# 创建菱形结构元素
|
||||
selem = morphology.diamond(size)
|
||||
else:
|
||||
raise ValueError(f"不支持的结构元素形状: {shape}")
|
||||
|
||||
return selem
|
||||
|
||||
def _morphological_reconstruction(self, marker: np.ndarray, mask: np.ndarray,
|
||||
selem: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
形态学重建(基于膨胀)
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
marker : np.ndarray
|
||||
标记图像
|
||||
mask : np.ndarray
|
||||
掩模图像
|
||||
selem : np.ndarray
|
||||
结构元素
|
||||
|
||||
Returns:
|
||||
--------
|
||||
reconstructed : np.ndarray
|
||||
重建后的图像
|
||||
"""
|
||||
# 确保标记图像不超过掩模图像
|
||||
marker = np.minimum(marker, mask)
|
||||
|
||||
# 迭代膨胀直到收敛
|
||||
prev_marker = np.zeros_like(marker)
|
||||
while not np.array_equal(marker, prev_marker):
|
||||
prev_marker = marker.copy()
|
||||
# 条件膨胀:在掩模限制下膨胀
|
||||
dilated = ndimage.grey_dilation(marker, footprint=selem)
|
||||
marker = np.minimum(dilated, mask)
|
||||
|
||||
return marker
|
||||
|
||||
def _convert_to_original_type(self, data: np.ndarray, original_dtype: np.dtype) -> np.ndarray:
|
||||
"""将数据转换回原始数据类型"""
|
||||
if np.issubdtype(original_dtype, np.integer):
|
||||
# 对于整数类型,进行裁剪和取整
|
||||
data = np.clip(data, np.iinfo(original_dtype).min, np.iinfo(original_dtype).max)
|
||||
return data.astype(original_dtype)
|
||||
else:
|
||||
# 对于浮点类型,直接转换
|
||||
return data.astype(original_dtype)
|
||||
|
||||
def apply_to_multiple_bands(self, band_indices: List[int], operation: str,
|
||||
se_shape: str = 'disk', se_size: int = 3) -> Dict[int, np.ndarray]:
|
||||
"""
|
||||
对多个波段应用形态学操作
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
band_indices : List[int]
|
||||
波段索引列表
|
||||
operation : str
|
||||
形态学操作类型
|
||||
se_shape : str
|
||||
结构元素形状
|
||||
se_size : int
|
||||
结构元素大小
|
||||
|
||||
Returns:
|
||||
--------
|
||||
results : Dict[int, np.ndarray]
|
||||
每个波段的处理结果
|
||||
"""
|
||||
results = {}
|
||||
|
||||
for band_idx in band_indices:
|
||||
print(f"处理波段 {band_idx}...")
|
||||
band_data = self.extract_band(band_idx)
|
||||
# extract_band 现在应该总是返回2D数据,但为了安全起见检查一下
|
||||
if len(band_data.shape) != 2:
|
||||
raise ValueError(f"波段 {band_idx} 数据不是2D的,形状: {band_data.shape}")
|
||||
result = self.apply_morphology_operation(band_data, operation, se_shape, se_size)
|
||||
results[band_idx] = result
|
||||
|
||||
return results
|
||||
|
||||
def save_as_envi(self, data: np.ndarray, output_path: str,
|
||||
description: str = "形态学处理结果") -> None:
|
||||
"""
|
||||
保存为ENVI格式的dat和hdr文件
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
data : np.ndarray
|
||||
要保存的数据(单波段2D数组)
|
||||
output_path : str
|
||||
输出文件路径(不含扩展名)
|
||||
description : str
|
||||
图像描述
|
||||
"""
|
||||
try:
|
||||
# 确保输出目录存在
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
# 确保输出文件扩展名为.dat
|
||||
if not output_path.lower().endswith('.dat'):
|
||||
dat_path = output_path + '.dat'
|
||||
else:
|
||||
dat_path = output_path
|
||||
|
||||
# 保存为二进制dat文件
|
||||
data.tofile(dat_path)
|
||||
|
||||
# 创建对应的hdr文件
|
||||
hdr_path = dat_path.rsplit('.', 1)[0] + '.hdr'
|
||||
self._create_morphology_hdr_file(hdr_path, data.shape, description, data.dtype)
|
||||
|
||||
print(f"ENVI格式已保存:")
|
||||
print(f" 数据文件: {dat_path}")
|
||||
print(f" 头文件: {hdr_path}")
|
||||
print(f" 数据形状: {data.shape}, 数据类型: {data.dtype}")
|
||||
|
||||
except Exception as e:
|
||||
raise IOError(f"保存ENVI文件失败: {e}")
|
||||
|
||||
def _create_morphology_hdr_file(self, hdr_path: str, data_shape: Tuple,
|
||||
description: str, data_dtype: np.dtype = None) -> None:
|
||||
"""
|
||||
创建形态学处理结果的ENVI头文件
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
hdr_path : str
|
||||
头文件路径
|
||||
data_shape : Tuple
|
||||
数据形状
|
||||
description : str
|
||||
图像描述
|
||||
data_dtype : np.dtype, optional
|
||||
数据类型,如果为None则使用默认值
|
||||
"""
|
||||
try:
|
||||
# 从原始头文件获取基本信息(如果有的话)
|
||||
if self.header is not None:
|
||||
# 复制原始头文件的关键信息
|
||||
samples = self.header.get('samples', data_shape[1] if len(data_shape) > 1 else data_shape[0])
|
||||
lines = self.header.get('lines', data_shape[0] if len(data_shape) > 1 else 1)
|
||||
bands = 1 # 形态学处理结果是单波段
|
||||
interleave = 'bsq'
|
||||
|
||||
# 根据数据类型确定ENVI数据类型
|
||||
if data_dtype is not None:
|
||||
dtype = data_dtype
|
||||
else:
|
||||
dtype = np.float32 # 默认类型
|
||||
|
||||
if np.issubdtype(dtype, np.float32):
|
||||
data_type = 4 # float32
|
||||
elif np.issubdtype(dtype, np.float64):
|
||||
data_type = 5 # float64
|
||||
elif np.issubdtype(dtype, np.int32):
|
||||
data_type = 3 # int32
|
||||
elif np.issubdtype(dtype, np.int16):
|
||||
data_type = 2 # int16
|
||||
elif np.issubdtype(dtype, np.uint16):
|
||||
data_type = 12 # uint16
|
||||
else:
|
||||
data_type = 4 # 默认float32
|
||||
|
||||
byte_order = self.header.get('byte order', 0)
|
||||
wavelength_units = self.header.get('wavelength units', 'nm')
|
||||
else:
|
||||
# 默认值(适用于2D形态学处理结果)
|
||||
if len(data_shape) == 2:
|
||||
lines, samples = data_shape
|
||||
else:
|
||||
lines = data_shape[0]
|
||||
samples = data_shape[1] if len(data_shape) > 1 else data_shape[0]
|
||||
bands = 1
|
||||
interleave = 'bsq'
|
||||
data_type = 4 # float32(形态学处理通常涉及浮点数)
|
||||
byte_order = 0
|
||||
wavelength_units = 'Unknown'
|
||||
|
||||
# 写入hdr文件
|
||||
with open(hdr_path, 'w') as f:
|
||||
f.write("ENVI\n")
|
||||
f.write("description = {\n")
|
||||
f.write(f" {description}\n")
|
||||
f.write(" 单波段形态学处理结果\n")
|
||||
f.write("}\n")
|
||||
f.write(f"samples = {samples}\n")
|
||||
f.write(f"lines = {lines}\n")
|
||||
f.write(f"bands = {bands}\n")
|
||||
f.write(f"header offset = 0\n")
|
||||
f.write(f"file type = ENVI Standard\n")
|
||||
f.write(f"data type = {data_type}\n")
|
||||
f.write(f"interleave = {interleave}\n")
|
||||
f.write(f"byte order = {byte_order}\n")
|
||||
|
||||
# 如果有波长信息,添加虚拟波长
|
||||
if self.header and 'wavelength' in self.header:
|
||||
f.write("wavelength = {\n")
|
||||
f.write(" 形态学处理结果\n")
|
||||
f.write("}\n")
|
||||
f.write(f"wavelength units = {wavelength_units}\n")
|
||||
|
||||
except Exception as e:
|
||||
raise IOError(f"创建HDR文件失败: {e}")
|
||||
|
||||
def visualize_results(self, original_band: np.ndarray,
|
||||
processed_band: np.ndarray,
|
||||
operation_name: str,
|
||||
save_path: Optional[str] = None) -> None:
|
||||
"""
|
||||
可视化原始波段和处理后的波段
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
original_band : np.ndarray
|
||||
原始波段数据
|
||||
processed_band : np.ndarray
|
||||
处理后的波段数据
|
||||
operation_name : str
|
||||
操作名称
|
||||
save_path : str, optional
|
||||
保存图像路径
|
||||
"""
|
||||
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
||||
|
||||
# 原始图像
|
||||
im1 = axes[0].imshow(original_band, cmap='gray')
|
||||
axes[0].set_title('原始波段')
|
||||
axes[0].axis('off')
|
||||
plt.colorbar(im1, ax=axes[0], shrink=0.8)
|
||||
|
||||
# 处理后的图像
|
||||
im2 = axes[1].imshow(processed_band, cmap='gray')
|
||||
axes[1].set_title(f'{operation_name} 结果')
|
||||
axes[1].axis('off')
|
||||
plt.colorbar(im2, ax=axes[1], shrink=0.8)
|
||||
|
||||
# 差异图像
|
||||
diff = processed_band.astype(np.float32) - original_band.astype(np.float32)
|
||||
im3 = axes[2].imshow(diff, cmap='RdBu_r', vmin=-np.abs(diff).max(), vmax=np.abs(diff).max())
|
||||
axes[2].set_title('差异 (处理后 - 原始)')
|
||||
axes[2].axis('off')
|
||||
plt.colorbar(im3, ax=axes[2], shrink=0.8)
|
||||
|
||||
plt.suptitle(f'形态学操作: {operation_name}', fontsize=16)
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"可视化结果已保存到: {save_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数:高光谱图像形态学处理"""
|
||||
import argparse
|
||||
|
||||
# 支持的形态学操作列表
|
||||
morph_operations = [
|
||||
'dilation', 'erosion', 'opening', 'closing',
|
||||
'gradient', 'tophat', 'bottomhat', 'reconstruction'
|
||||
]
|
||||
|
||||
# 支持的结构元素形状
|
||||
se_shapes = ['disk', 'square', 'rectangle', 'diamond']
|
||||
|
||||
parser = argparse.ArgumentParser(description='高光谱图像形态学处理工具')
|
||||
parser.add_argument('input_file', help='输入高光谱图像文件路径')
|
||||
parser.add_argument('--format', '-f', default='auto',
|
||||
choices=['auto', 'bil', 'bip', 'bsq', 'dat'],
|
||||
help='图像格式 (默认: auto)')
|
||||
parser.add_argument('--band', '-b', type=int, default=0,
|
||||
help='要处理的波段索引 (默认: 0)')
|
||||
parser.add_argument('--bands', '-B', type=str, default=None,
|
||||
help='多个波段索引,用逗号分隔,如 "0,10,20"')
|
||||
parser.add_argument('--operation', '-o', default='dilation',
|
||||
choices=morph_operations,
|
||||
help=f'形态学操作类型 (默认: dilation)')
|
||||
parser.add_argument('--se_shape', '-s', default='disk',
|
||||
choices=se_shapes,
|
||||
help='结构元素形状 (默认: disk)')
|
||||
parser.add_argument('--se_size', '-S', type=int, default=3,
|
||||
help='结构元素大小 (默认: 3)')
|
||||
parser.add_argument('--output_dir', '-d', default='output',
|
||||
help='输出目录 (默认: output)')
|
||||
parser.add_argument('--output_name', '-n', default='morphology_result',
|
||||
help='输出文件名 (不含扩展名) (默认: morphology_result)')
|
||||
parser.add_argument('--visualize', '-v', action='store_true',
|
||||
help='是否生成可视化结果')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# 初始化处理器
|
||||
processor = HyperspectralMorphologyProcessor()
|
||||
|
||||
# 确定格式
|
||||
format_type = None if args.format == 'auto' else args.format
|
||||
|
||||
# 加载图像
|
||||
print(f"加载图像: {args.input_file}")
|
||||
data, header = processor.load_hyperspectral_image(args.input_file, format_type)
|
||||
|
||||
# 确定要处理的波段
|
||||
if args.bands:
|
||||
band_indices = [int(b.strip()) for b in args.bands.split(',')]
|
||||
print(f"处理波段: {band_indices}")
|
||||
single_band = False
|
||||
else:
|
||||
band_indices = [args.band]
|
||||
print(f"处理波段: {args.band}")
|
||||
single_band = True
|
||||
|
||||
# 应用形态学操作
|
||||
if single_band:
|
||||
# 单波段处理
|
||||
band_data = processor.extract_band(args.band)
|
||||
print(f"主函数中波段数据形状: {band_data.shape}")
|
||||
result = processor.apply_morphology_operation(
|
||||
band_data, args.operation, args.se_shape, args.se_size
|
||||
)
|
||||
|
||||
# 保存结果
|
||||
output_path = os.path.join(args.output_dir, f"{args.output_name}_band{args.band}")
|
||||
processor.save_as_envi(result, output_path,
|
||||
f"{args.operation.capitalize()} 处理结果 - 波段 {args.band}")
|
||||
|
||||
# 可视化
|
||||
if args.visualize:
|
||||
vis_path = os.path.join(args.output_dir, f"visualization_band{args.band}.png")
|
||||
processor.visualize_results(band_data, result, args.operation, vis_path)
|
||||
|
||||
# 打印统计信息
|
||||
print(f"\n=== 统计信息 ===")
|
||||
print(f"原始波段 {args.band}: min={band_data.min():.4f}, max={band_data.max():.4f}, "
|
||||
f"mean={band_data.mean():.4f}")
|
||||
print(f"处理后: min={result.min():.4f}, max={result.max():.4f}, "
|
||||
f"mean={result.mean():.4f}")
|
||||
|
||||
else:
|
||||
# 多波段处理
|
||||
results = processor.apply_to_multiple_bands(
|
||||
band_indices, args.operation, args.se_shape, args.se_size
|
||||
)
|
||||
|
||||
# 保存每个波段的结果
|
||||
for band_idx, result in results.items():
|
||||
output_path = os.path.join(args.output_dir, f"{args.output_name}_band{band_idx}")
|
||||
processor.save_as_envi(result, output_path,
|
||||
f"{args.operation.capitalize()} 处理结果 - 波段 {band_idx}")
|
||||
|
||||
# 如果只有少量波段,可以创建组合可视化
|
||||
if len(band_indices) <= 4 and args.visualize:
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||
axes = axes.flatten()
|
||||
|
||||
for idx, band_idx in enumerate(band_indices[:4]):
|
||||
original_band = processor.extract_band(band_idx)
|
||||
processed_band = results[band_idx]
|
||||
|
||||
axes[idx].imshow(processed_band, cmap='gray')
|
||||
axes[idx].set_title(f'波段 {band_idx} - {args.operation}')
|
||||
axes[idx].axis('off')
|
||||
|
||||
plt.suptitle(f'多波段形态学处理: {args.operation}', fontsize=16)
|
||||
plt.tight_layout()
|
||||
|
||||
vis_path = os.path.join(args.output_dir, "multi_band_visualization.png")
|
||||
plt.savefig(vis_path, dpi=300, bbox_inches='tight')
|
||||
print(f"多波段可视化结果已保存到: {vis_path}")
|
||||
plt.show()
|
||||
|
||||
print(f"\n✓ 形态学处理完成!")
|
||||
print(f"输出目录: {args.output_dir}")
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 处理失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
|
||||
# 使用示例函数
|
||||
def run_example():
|
||||
"""运行示例"""
|
||||
# 示例1:对单个波段进行膨胀操作
|
||||
processor = HyperspectralMorphologyProcessor()
|
||||
|
||||
# 加载图像(请替换为您的实际文件路径)
|
||||
input_file = "path/to/your/hyperspectral_image.hdr" # 或 .bil, .bip, .bsq, .dat
|
||||
data, header = processor.load_hyperspectral_image(input_file, format_type='bip')
|
||||
|
||||
# 提取波段(例如,提取第50个波段)
|
||||
band_idx = 50
|
||||
band_data = processor.extract_band(band_idx)
|
||||
|
||||
# 应用不同的形态学操作
|
||||
operations = ['dilation', 'erosion', 'opening', 'closing',
|
||||
'gradient', 'tophat', 'bottomhat', 'reconstruction']
|
||||
|
||||
results = {}
|
||||
for operation in operations:
|
||||
print(f"\n应用 {operation} 操作...")
|
||||
try:
|
||||
result = processor.apply_morphology_operation(
|
||||
band_data, operation, se_shape='disk', se_size=3
|
||||
)
|
||||
results[operation] = result
|
||||
|
||||
# 保存结果
|
||||
output_path = f"output/{operation}_band{band_idx}"
|
||||
processor.save_as_envi(result, output_path,
|
||||
f"{operation.capitalize()} 处理结果")
|
||||
|
||||
# 可视化
|
||||
processor.visualize_results(band_data, result, operation,
|
||||
save_path=f"output/visualization_{operation}.png")
|
||||
|
||||
except Exception as e:
|
||||
print(f"操作 {operation} 失败: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@dataclass
|
||||
class MorphologicalFilterConfig:
|
||||
"""形态学滤波配置类"""
|
||||
input_path: str
|
||||
operation: str
|
||||
band_index: int
|
||||
kernel_size: int
|
||||
output_dir: str
|
||||
se_shape: str = 'disk'
|
||||
se_size: int = 3
|
||||
format_type: str = 'auto'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
exit(main())
|
||||
Reference in New Issue
Block a user