增加模块;增加主调用命令

This commit is contained in:
2026-01-07 16:36:47 +08:00
commit 2d4b170a45
109 changed files with 55763 additions and 0 deletions

View File

@ -0,0 +1,751 @@
import numpy as np
import cv2
import os
from typing import Tuple, Optional, Dict, Union
from scipy import ndimage
from dataclasses import dataclass
import warnings
warnings.filterwarnings('ignore')
try:
import spectral as spy
from spectral import envi
HAS_SPECTRAL = True
except ImportError:
HAS_SPECTRAL = False
print("警告: 未安装spectral库只能处理RGB图像")
class HyperspectralImageFilter:
    """Spatial smoothing filters for ENVI hyperspectral cubes and ordinary raster images.

    Usage examples::

        # Create a filter instance
        filter_obj = HyperspectralImageFilter()

        # Process an image - mean filter
        dat_path, hdr_path = filter_obj.process_image(
            input_path='input.hdr',
            output_path='output_mean',
            filter_type='mean',
            kernel_size=5
        )

        # Process an image - Gaussian filter
        dat_path, hdr_path = filter_obj.process_image(
            input_path='input.hdr',
            output_path='output_gaussian',
            filter_type='gaussian',
            kernel_size=3,
            sigma=2.0
        )

        # Process an image - bilateral filter
        dat_path, hdr_path = filter_obj.process_image(
            input_path='input.hdr',
            output_path='output_bilateral',
            filter_type='bilateral',
            kernel_size=7,
            sigma_color=50.0,
            sigma_space=50.0
        )

        # Or drive the individual steps yourself
        data, header = filter_obj.load_image('input.hdr')
        filtered_data = filter_obj.apply_filter('mean', kernel_size=5)
        dat_path, hdr_path = filter_obj.save_envi('output', filtered_data, 'mean', header)
    """

    def __init__(self):
        """Create a filter with no image loaded; a later load fills every field."""
        # Pixel array, metadata dict and per-band wavelengths of the loaded image.
        self.data = self.header = self.wavelengths = None
        # True only while a full multi-band ENVI cube is loaded.
        self.is_hyperspectral = False
        # Band chosen at load time, and the dtype / (min, max) of the data as loaded.
        self._selected_band_index = None
        self._original_dtype = None
        self._original_range = None
def load_image(self, image_path: str, band_index: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
    """Load an image, choosing the reader from the file extension.

    Parameters:
        image_path: Path to an ENVI header (.hdr) or a raster image
            (.jpg/.jpeg/.png/.tif/.tiff).
        band_index: For ENVI input, the single band to keep; None keeps all
            bands. Ignored for raster images.

    Returns:
        The (data, header) tuple produced by the selected reader.

    Raises:
        IOError: For every failure, including an unsupported extension or a
            missing ``spectral`` package; the original message is embedded.
    """
    raster_suffixes = ('.jpg', '.jpeg', '.png', '.tif', '.tiff')
    try:
        suffix = os.path.splitext(image_path)[1].lower()

        if suffix == '.hdr':
            # ENVI cubes can only be read through the optional spectral package.
            if not HAS_SPECTRAL:
                raise ImportError("需要安装spectral库来处理ENVI格式文件")
            return self._load_envi_image(image_path, band_index)

        if suffix in raster_suffixes:
            return self._load_rgb_image(image_path)

        raise ValueError(f"不支持的文件格式: {suffix}")
    except Exception as e:
        raise IOError(f"加载图像失败: {e}")
def _load_envi_image(self, hdr_path: str, band_index: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
    """Load an ENVI image through ``spectral`` and record its state on the instance.

    Parameters:
        hdr_path: Path of the ENVI header file.
        band_index: Band to extract (0-based along the third axis); None keeps
            the full cube.

    Returns:
        (data, header): the loaded array (2-D when a band was selected) and a
        dict copy of the header metadata.

    Raises:
        IOError: On any failure, including an out-of-range ``band_index``.
    """
    try:
        img = envi.open(hdr_path)
        data = img.load()
        header = dict(img.metadata)

        # Wavelengths are optional; an unparsable list is treated as absent.
        wavelengths = None
        if 'wavelength' in header:
            try:
                wavelengths = np.array([float(w) for w in header['wavelength']])
            except (TypeError, ValueError):
                wavelengths = None

        if band_index is not None:
            if band_index < 0 or band_index >= data.shape[2]:
                raise ValueError(f"波段索引 {band_index} 超出范围 [0, {data.shape[2]-1}]")
            print(f"原始数据形状: {data.shape}")
            # Remembered so the output header can name the selected wavelength.
            self._selected_band_index = band_index
            data = data[:, :, band_index]
            print(f"选择波段 {band_index} 后的数据形状: {data.shape}")
            # The slice may come back as (H, W, 1); collapse it to (H, W).
            if len(data.shape) == 3 and data.shape[2] == 1:
                data = np.squeeze(data, axis=2)
                print(f"压缩单波段数据为2D: {data.shape}")
            print(f"选择波段 {band_index} 进行处理")
            # Guard the lookup: the header may list fewer wavelengths than bands.
            if wavelengths is not None and band_index < len(wavelengths):
                print(f"波长: {wavelengths[band_index]:.1f} nm")
            self.is_hyperspectral = False
        else:
            print("处理所有波段")
            self.is_hyperspectral = True
            self._selected_band_index = None

        # Original dtype/range drive the output dtype decision when saving.
        self._original_dtype = data.dtype
        if data.size > 0:
            self._original_range = (data.min(), data.max())
            print(f"原始数据范围: {self._original_range[0]} - {self._original_range[1]}")
        else:
            # Do not keep the range of a previously loaded image.
            self._original_range = None

        self.data = data
        self.header = header
        self.wavelengths = wavelengths
        print(f"成功加载ENVI图像: 形状={data.shape}, 数据类型={data.dtype}")
        # The spectral range is reported for full cubes; a selected band has
        # already had its own wavelength printed above.
        if wavelengths is not None and len(wavelengths) > 0 and band_index is None:
            print(f"波长范围: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")
        return data, header
    except Exception as e:
        raise IOError(f"加载ENVI图像失败: {e}")
def _load_rgb_image(self, image_path: str) -> Tuple[np.ndarray, Dict]:
    """Load a raster image with OpenCV and record its state on the instance.

    Colour images are reordered from OpenCV's BGR(A) channel order to RGB(A);
    the bit depth of the file is preserved (``IMREAD_UNCHANGED``).

    Parameters:
        image_path: Path of the image file.

    Returns:
        (data, header): the pixel array and a minimal ENVI-style metadata dict.

    Raises:
        IOError: If the file cannot be read or any later step fails.
    """
    try:
        data = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if data is None:
            raise ValueError(f"无法读取图像文件: {image_path}")

        # OpenCV returns BGR / BGRA; images with an alpha channel need the
        # same reordering as plain colour images.
        if len(data.shape) == 3:
            if data.shape[2] == 3:
                data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
            elif data.shape[2] == 4:
                data = cv2.cvtColor(data, cv2.COLOR_BGRA2RGBA)

        # Original dtype/range drive the output dtype decision when saving.
        self._original_dtype = data.dtype
        if data.size > 0:
            self._original_range = (data.min(), data.max())
            print(f"原始数据范围: {self._original_range[0]} - {self._original_range[1]}")
        else:
            self._original_range = None

        header = {
            'samples': data.shape[1],
            'lines': data.shape[0],
            'bands': data.shape[2] if len(data.shape) == 3 else 1,
            'data_type': self._get_envi_data_type(data.dtype),
            'interleave': 'bsq',
            'byte_order': 0
        }

        self.data = data
        self.header = header
        self.wavelengths = None
        self.is_hyperspectral = False
        # A band selection from an earlier ENVI load does not apply here.
        self._selected_band_index = None
        print(f"成功加载RGB图像: 形状={data.shape}, 数据类型={data.dtype}")
        return data, header
    except Exception as e:
        raise IOError(f"加载RGB图像失败: {e}")
def _get_envi_data_type(self, dtype: np.dtype) -> int:
"""获取ENVI数据类型代码"""
if dtype == np.uint8:
return 1
elif dtype == np.int16:
return 2
elif dtype == np.int32:
return 3
elif dtype == np.float32:
return 4
elif dtype == np.float64:
return 5
elif dtype == np.uint16:
return 12
elif dtype == np.uint32:
return 13
else:
return 4 # 默认float32
def apply_filter(self, filter_type: str, **kwargs) -> np.ndarray:
    """Run one of the smoothing filters on the currently loaded image.

    Parameters:
        filter_type: 'mean', 'median', 'gaussian' or 'bilateral'
            (case-insensitive).
        **kwargs: Filter hyper-parameters:
            - kernel_size: kernel width (default 3; an even value is bumped
              to the next odd one)
            - sigma: Gaussian standard deviation (default 1.0)
            - sigma_color: bilateral colour-space sigma (default 75.0)
            - sigma_space: bilateral coordinate-space sigma (default 75.0)

    Returns:
        The filtered image data.

    Raises:
        ValueError: If no image is loaded or ``filter_type`` is unknown.
    """
    if self.data is None:
        raise ValueError("请先加载图像数据")

    kernel_size = kwargs.get('kernel_size', 3)
    sigma = kwargs.get('sigma', 1.0)
    sigma_color = kwargs.get('sigma_color', 75.0)
    sigma_space = kwargs.get('sigma_space', 75.0)

    # The filters expect an odd kernel width.
    if kernel_size % 2 == 0:
        kernel_size = kernel_size + 1

    print(f"应用{filter_type}滤波器,参数: kernel_size={kernel_size}")
    wanted = filter_type.lower()

    if wanted == 'mean':
        return self._mean_filter(kernel_size)
    if wanted == 'median':
        return self._median_filter(kernel_size)
    if wanted == 'gaussian':
        print(f" sigma={sigma}")
        return self._gaussian_filter(kernel_size, sigma)
    if wanted == 'bilateral':
        print(f" sigma_color={sigma_color}, sigma_space={sigma_space}")
        return self._bilateral_filter(kernel_size, sigma_color, sigma_space)
    raise ValueError(f"不支持的滤波器类型: {filter_type}")
def _mean_filter(self, kernel_size: int) -> np.ndarray:
"""均值滤波"""
try:
print(f"均值滤波开始 - 数据形状: {self.data.shape}, is_hyperspectral: {self.is_hyperspectral}")
# 创建均值内核
kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) / (kernel_size * kernel_size)
# 检查数据维度
if len(self.data.shape) == 3 and self.data.shape[2] > 1:
# 多波段高光谱图像 - 对每个波段分别处理
filtered_data = np.zeros_like(self.data, dtype=np.float32)
for band in range(self.data.shape[2]):
filtered_data[:, :, band] = ndimage.convolve(self.data[:, :, band].astype(np.float32), kernel)
print(f"均值滤波完成 - 处理了 {self.data.shape[2]} 个波段")
elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
# 单波段3D数据 (H, W, 1) - 压缩并处理
data_2d = np.squeeze(self.data, axis=2).astype(np.float32)
filtered_data = ndimage.convolve(data_2d, kernel)
# 保持输出为3D形状以保持一致性
filtered_data = filtered_data[:, :, np.newaxis]
print(f"均值滤波完成 - 单波段3D数据 (H,W,1) -> (H,W,1)")
else:
# 2D图像或单波段图像
print(f"应用2D均值滤波到形状 {self.data.shape} 的数据")
filtered_data = ndimage.convolve(self.data.astype(np.float32), kernel)
print(f"2D均值滤波完成 - 输出形状: {filtered_data.shape}")
return filtered_data
except Exception as e:
raise RuntimeError(f"均值滤波失败: {e}")
def _median_filter(self, kernel_size: int) -> np.ndarray:
"""中值滤波"""
try:
# 检查数据维度
if len(self.data.shape) == 3 and self.data.shape[2] > 1:
# 多波段高光谱图像 - 对每个波段分别处理
filtered_data = np.zeros_like(self.data, dtype=self.data.dtype)
for band in range(self.data.shape[2]):
filtered_data[:, :, band] = ndimage.median_filter(self.data[:, :, band], size=kernel_size)
print(f"中值滤波完成 - 处理了 {self.data.shape[2]} 个波段")
elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
# 单波段3D数据 (H, W, 1) - 压缩并处理
data_2d = np.squeeze(self.data, axis=2)
filtered_data = ndimage.median_filter(data_2d, size=kernel_size)
# 保持输出为3D形状以保持一致性
filtered_data = filtered_data[:, :, np.newaxis]
print(f"中值滤波完成 - 单波段3D数据 (H,W,1) -> (H,W,1)")
else:
# 2D图像或单波段图像
filtered_data = ndimage.median_filter(self.data, size=kernel_size)
return filtered_data
except Exception as e:
raise RuntimeError(f"中值滤波失败: {e}")
def _gaussian_filter(self, kernel_size: int, sigma: float) -> np.ndarray:
"""高斯滤波"""
try:
# 检查数据维度
if len(self.data.shape) == 3 and self.data.shape[2] > 1:
# 多波段高光谱图像 - 对每个波段分别处理
filtered_data = np.zeros_like(self.data, dtype=np.float32)
for band in range(self.data.shape[2]):
filtered_data[:, :, band] = ndimage.gaussian_filter(
self.data[:, :, band].astype(np.float32), sigma=sigma
)
print(f"高斯滤波完成 (sigma={sigma}) - 处理了 {self.data.shape[2]} 个波段")
elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
# 单波段3D数据 (H, W, 1) - 压缩并处理
data_2d = np.squeeze(self.data, axis=2).astype(np.float32)
filtered_data = ndimage.gaussian_filter(data_2d, sigma=sigma)
# 保持输出为3D形状以保持一致性
filtered_data = filtered_data[:, :, np.newaxis]
print(f"高斯滤波完成 (sigma={sigma}) - 单波段3D数据 (H,W,1) -> (H,W,1)")
else:
# 2D图像或单波段图像
filtered_data = ndimage.gaussian_filter(self.data.astype(np.float32), sigma=sigma)
print(f"高斯滤波完成 (sigma={sigma})")
return filtered_data
except Exception as e:
raise RuntimeError(f"高斯滤波失败: {e}")
def _bilateral_filter(self, kernel_size: int, sigma_color: float, sigma_space: float) -> np.ndarray:
    """Edge-preserving bilateral filter.

    8-bit or float32 three-channel images that are not hyperspectral are
    filtered jointly by OpenCV. Everything else is filtered one band at a
    time: each band is rescaled to 0-255 (uint8), filtered, and mapped back
    to its original value range.

    Parameters:
        kernel_size: Pixel neighbourhood diameter passed to OpenCV.
        sigma_color: Filter sigma in colour (intensity) space.
        sigma_space: Filter sigma in coordinate space.

    Returns:
        The filtered data with the same shape as the input. Band-wise results
        are float32; the joint RGB path keeps the input dtype.

    Raises:
        RuntimeError: If filtering fails for any reason.
    """

    def filter_band(plane):
        """Bilateral-filter one 2-D band through an 8-bit rescale."""
        plane = plane.astype(np.float32)
        low, high = plane.min(), plane.max()
        if not high > low:
            # Constant (or all-NaN) band: nothing to smooth.
            return plane
        span = high - low
        scaled = ((plane - low) / span * 255).astype(np.uint8)
        smoothed = cv2.bilateralFilter(scaled, kernel_size, sigma_color, sigma_space)
        return smoothed.astype(np.float32) / 255 * span + low

    try:
        image = self.data
        depth = image.shape[2] if len(image.shape) == 3 else None
        # cv2.bilateralFilter only accepts 8-bit or float32 input with 1 or 3 channels.
        joint_rgb = (
            depth == 3
            and not self.is_hyperspectral
            and image.dtype in (np.uint8, np.float32)
        )

        if joint_rgb:
            filtered_data = cv2.bilateralFilter(image, kernel_size, sigma_color, sigma_space)
            print("RGB双边滤波完成")
        elif depth is not None and depth > 1:
            # Any other multi-band data (cubes, RGBA, 16-bit RGB): per band.
            filtered_data = np.zeros_like(image, dtype=np.float32)
            for band in range(depth):
                filtered_data[:, :, band] = filter_band(image[:, :, band])
            print(f"双边滤波完成 - 处理了 {self.data.shape[2]} 个波段")
        elif depth == 1:
            # (H, W, 1): filter the plane and restore the band axis.
            filtered_data = filter_band(np.squeeze(image, axis=2))[:, :, np.newaxis]
            print("单波段双边滤波完成 (3D -> 3D)")
        else:
            filtered_data = filter_band(image)
            print("单波段双边滤波完成")
        return filtered_data
    except Exception as e:
        raise RuntimeError(f"双边滤波失败: {e}")
def save_envi(self, output_path: str, filtered_data: np.ndarray, filter_type: str,
              original_header: Dict = None) -> Tuple[str, str]:
    """Write a filter result as an ENVI ``.dat`` / ``.hdr`` pair.

    The data is stored band-sequentially (BSQ) to match the header written by
    ``_save_envi_header``. If the source image was an integer type and the
    filtered values stay within 100 counts of the original range, the result
    is rounded and stored in the original dtype; otherwise float32 is used.

    Parameters:
        output_path: Output path without extension; ``_<filter_type>.dat`` and
            ``_<filter_type>.hdr`` are appended.
        filtered_data: Filtered array, (H, W) or (H, W, bands).
        filter_type: Filter name, used in the file names.
        original_header: Metadata of the source image, forwarded to the
            header writer.

    Returns:
        (dat_path, hdr_path)

    Raises:
        IOError: If anything goes wrong while converting or writing.
    """
    try:
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        dat_path = f"{output_path}_{filter_type}.dat"
        hdr_path = f"{output_path}_{filter_type}.hdr"

        print(f"滤波前数据范围: {filtered_data.min():.2f} - {filtered_data.max():.2f}")
        print(f"原始数据类型: {self._original_dtype}, 范围: {self._original_range}")

        source_dtype = self._original_dtype
        integer_source = source_dtype is not None and np.issubdtype(source_dtype, np.integer)
        if integer_source and self._original_range is not None:
            # Compare as Python floats: subtracting 100 from an unsigned NumPy
            # scalar can wrap around instead of going negative.
            orig_min, orig_max = (float(v) for v in self._original_range)
            filtered_min = float(filtered_data.min())
            filtered_max = float(filtered_data.max())
            if filtered_min >= orig_min - 100 and filtered_max <= orig_max + 100:
                print(f"保持原始整数数据类型: {self._original_dtype}")
                # Round to nearest rather than truncating towards zero.
                clipped = np.clip(filtered_data, orig_min, orig_max)
                filtered_data = np.rint(clipped).astype(source_dtype)
            else:
                print("数据范围变化较大使用float32保存")
                filtered_data = filtered_data.astype(np.float32)
        else:
            filtered_data = filtered_data.astype(np.float32)

        print(f"保存数据类型: {filtered_data.dtype}, 范围: {filtered_data.min():.2f} - {filtered_data.max():.2f}")

        if len(filtered_data.shape) == 2:
            # The header writer expects (lines, samples, bands).
            filtered_data = filtered_data[:, :, np.newaxis]
            print(f"转换2D数据为3D格式以符合ENVI标准: {filtered_data.shape}")

        # The header declares BSQ, so move bands to the slowest-varying axis
        # before dumping; writing (H, W, B) directly would produce BIP.
        np.ascontiguousarray(np.transpose(filtered_data, (2, 0, 1))).tofile(dat_path)
        print(f"数据文件已保存: {dat_path}")

        self._save_envi_header(hdr_path, filtered_data, filter_type, original_header)
        print(f"头文件已保存: {hdr_path}")
        return dat_path, hdr_path
    except Exception as e:
        raise IOError(f"保存ENVI文件失败: {e}")
def _save_envi_header(self, hdr_path: str, data: np.ndarray, filter_type: str,
original_header: Dict = None):
"""保存ENVI标准格式头文件"""
try:
from datetime import datetime
with open(hdr_path, 'w', encoding='utf-8') as f:
f.write("ENVI\n")
# 描述信息 - 包含处理时间和描述
current_time = datetime.now().strftime("%a %b %d %H:%M:%S %Y")
f.write("description = {\n")
f.write(f" Convolution Result [{current_time}]\n")
f.write("}\n")
# 图像尺寸
if len(data.shape) == 3:
samples, lines, bands = data.shape[1], data.shape[0], data.shape[2]
else:
samples, lines, bands = data.shape[1], data.shape[0], 1
f.write(f"samples = {samples}\n")
f.write(f"lines = {lines}\n")
f.write(f"bands = {bands}\n")
f.write("header offset = 0\n")
f.write("file type = ENVI Standard\n")
# 数据类型 - 使用实际保存的数据类型
if hasattr(self, '_original_dtype') and self._original_dtype is not None:
# 优先使用原始数据类型(如果范围合适)
data_type = self._get_envi_data_type(self._original_dtype)
else:
data_type = self._get_envi_data_type(data.dtype)
f.write(f"data type = {data_type}\n")
# 交织格式
f.write("interleave = bsq\n")
# 传感器类型
if original_header and 'sensor type' in original_header:
f.write(f"sensor type = {original_header['sensor type']}\n")
else:
f.write("sensor type = Unknown\n")
# 字节顺序
f.write("byte order = 0\n")
# 波长信息
if self.wavelengths is not None and len(self.wavelengths) > 0:
f.write("wavelength units = Nanometers\n")
if bands == 1 and hasattr(self, '_selected_band_index') and self._selected_band_index is not None:
# 单波段情况 - 使用选定波段的波长
band_idx = min(self._selected_band_index, len(self.wavelengths) - 1)
f.write(f"wavelength = {{\n")
f.write(f" {self.wavelengths[band_idx]:.6f}\n")
f.write("}\n")
else:
# 多波段情况 - 写入所有波长
f.write("wavelength = {\n")
wavelength_str = ", ".join([f"{w:.6f}" for w in self.wavelengths[:bands]])
f.write(f" {wavelength_str}\n")
f.write("}\n")
# 反射率比例因子
f.write("reflectance scale factor = 10000.000000\n")
# 从原始头文件复制其他重要参数
if original_header:
# 增益信息
if 'gain' in original_header:
f.write(f"gain = {original_header['gain']}\n")
# 分辨率信息
binning_keys = ['sample binning', 'spectral binning', 'line binning']
for key in binning_keys:
if key in original_header:
f.write(f"{key} = {original_header[key]}\n")
# 快门和帧率信息
if 'shutter' in original_header:
f.write(f"shutter = {original_header['shutter']}\n")
if 'framerate' in original_header:
f.write(f"framerate = {original_header['framerate']}\n")
# 设备序列号
if 'imager serial number' in original_header:
f.write(f"imager serial number = {original_header['imager serial number']}\n")
# 旋转矩阵
if 'rotation' in original_header:
rotation = original_header['rotation']
f.write(f"rotation = {rotation}\n")
# 标签
if 'label' in original_header:
f.write(f"label = {original_header['label']}\n")
# 处理历史
if 'history' in original_header:
f.write(f"history = {original_header['history']}\n")
except Exception as e:
raise IOError(f"保存头文件失败: {e}")
def process_image(self, input_path: str, output_path: str, filter_type: str,
                  band_index: Optional[int] = None, **kwargs) -> Tuple[str, str]:
    """Run the complete pipeline: load, filter, save.

    Parameters:
        input_path: Input image path.
        output_path: Output path without extension.
        filter_type: Filter type.
        band_index: Optional band index forwarded to ``load_image``.
        **kwargs: Filter hyper-parameters forwarded to ``apply_filter``:
            - kernel_size: kernel width (odd, default 3)
            - sigma: Gaussian standard deviation (default 1.0)
            - sigma_color: bilateral colour-space sigma (default 75.0)
            - sigma_space: bilateral coordinate-space sigma (default 75.0)

    Returns:
        (dat_path, hdr_path) as returned by ``save_envi``.
    """
    rule = "=" * 50

    print(rule)
    print("高光谱图像平滑滤波器")
    print(rule)

    print(f"加载图像: {input_path}")
    _, header = self.load_image(input_path, band_index)

    print(f"应用{filter_type}滤波...")
    smoothed = self.apply_filter(filter_type, **kwargs)

    print(f"保存结果到: {output_path}")
    dat_path, hdr_path = self.save_envi(output_path, smoothed, filter_type, header)

    for line in (rule, "处理完成!", f"数据文件: {dat_path}", f"头文件: {hdr_path}", rule):
        print(line)
    return dat_path, hdr_path
def main(argv: Optional[list] = None):
    """Command-line entry point for the smoothing filter.

    Parameters:
        argv: Argument list to parse instead of ``sys.argv[1:]``. The default
            (None) keeps the original behaviour; passing a list lets other
            code and tests drive the CLI directly.

    Returns:
        0 on success, 1 if processing raised an exception. Argument errors
        still exit through argparse.
    """
    import argparse

    epilog = "\n".join([
        "",
        "使用示例:",
        "1. 中值滤波:",
        "python Smooth_filter.py input.hdr -f median -b 50 -k 5 -o output.dat",
        "2. 高斯滤波:",
        "python Smooth_filter.py input.hdr -f gaussian -b 30 -k 3 -s 1.5 -o output.dat",
        "3. 双边滤波:",
        "python Smooth_filter.py input.hdr -f bilateral -b 25 -k 5 -sc 75 -ss 75 -o output.dat",
        "4. 均值滤波:",
        "python Smooth_filter.py input.hdr -f mean -b 40 -k 3 -o output.dat",
        "支持的滤波类型:",
        "mean: 均值滤波",
        "median: 中值滤波",
        "gaussian: 高斯滤波",
        "bilateral: 双边滤波",
        "",
    ])
    parser = argparse.ArgumentParser(
        description='高光谱图像平滑滤波工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog,
    )
    parser.add_argument('input_path', help='输入高光谱图像文件路径')
    parser.add_argument('-f', '--filter-type', default='gaussian',
                        choices=['mean', 'median', 'gaussian', 'bilateral'],
                        help='滤波类型 (默认: gaussian)')
    parser.add_argument('-b', '--band-index', type=int, default=0,
                        help='处理的波段索引 (默认: 0)')
    parser.add_argument('-o', '--output', required=True,
                        help='输出文件路径')
    parser.add_argument('-k', '--kernel-size', type=int, default=3,
                        help='卷积核大小,必须是奇数 (默认: 3)')
    # Gaussian filter parameter
    parser.add_argument('-s', '--sigma', type=float, default=1.0,
                        help='高斯标准差用于gaussian滤波 (默认: 1.0)')
    # Bilateral filter parameters
    parser.add_argument('-sc', '--sigma-color', type=float, default=75.0,
                        help='颜色空间标准差用于bilateral滤波 (默认: 75.0)')
    parser.add_argument('-ss', '--sigma-space', type=float, default=75.0,
                        help='坐标空间标准差用于bilateral滤波 (默认: 75.0)')
    args = parser.parse_args(argv)

    try:
        print("=" * 60)
        print("高光谱图像平滑滤波工具")
        print("=" * 60)
        print(f"输入文件: {args.input_path}")
        print(f"滤波类型: {args.filter_type}")
        print(f"波段索引: {args.band_index}")
        print(f"卷积核大小: {args.kernel_size}")
        print(f"输出文件: {args.output}")

        # Only the parameters relevant to the chosen filter are shown and forwarded.
        kwargs = {'kernel_size': args.kernel_size}
        if args.filter_type == 'gaussian':
            print(f"高斯标准差: {args.sigma}")
            kwargs['sigma'] = args.sigma
        elif args.filter_type == 'bilateral':
            print(f"颜色标准差: {args.sigma_color}")
            print(f"空间标准差: {args.sigma_space}")
            kwargs['sigma_color'] = args.sigma_color
            kwargs['sigma_space'] = args.sigma_space
        print()

        filter_obj = HyperspectralImageFilter()
        filter_obj.process_image(
            input_path=args.input_path,
            output_path=args.output,
            filter_type=args.filter_type,
            band_index=args.band_index,
            **kwargs
        )
        print("\n" + "=" * 60)
        print("滤波处理完成!")
        print("=" * 60)
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the return code.
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
@dataclass
class SmoothFilterConfig:
    """Parameter bundle describing one smoothing-filter run."""

    # Required settings (no defaults).
    input_path: str    # path of the image to process
    filter_type: str   # name of the filter to apply
    band_index: int    # band to process
    kernel_size: int   # kernel width
    output_dir: str    # where results are written

    # Optional filter-specific parameters.
    sigma: float = 1.0          # Gaussian standard deviation
    sigma_color: float = 75.0   # bilateral colour-space sigma
    sigma_space: float = 75.0   # bilateral coordinate-space sigma
if __name__ == "__main__":
    # Propagate main()'s status (0 = success, 1 = failure) as the process exit
    # code so shells and schedulers can detect a failed run.
    raise SystemExit(main())

View File

@ -0,0 +1,745 @@
import numpy as np
import os
import warnings
from typing import Tuple, List, Dict, Optional, Union
from pathlib import Path
from dataclasses import dataclass
import spectral as spy
from spectral import envi
import matplotlib.pyplot as plt
from skimage import morphology
from scipy import ndimage
import struct
warnings.filterwarnings('ignore')
class HyperspectralMorphologyProcessor:
    """Morphological processing of hyperspectral image bands."""

    def __init__(self):
        """Start with no image loaded; loading an image fills these fields."""
        # data / header: loaded array and its metadata dict
        # wavelengths:   per-band wavelengths parsed from the header, if any
        # shape / dtype: shape and dtype of the loaded array
        for attribute in ('data', 'header', 'wavelengths', 'shape', 'dtype'):
            setattr(self, attribute, None)
def load_hyperspectral_image(self, file_path: str, format_type: str = None) -> Tuple[np.ndarray, Dict]:
    """
    Load a hyperspectral image stored as BIL, BIP, BSQ or DAT.

    The array is always kept in (rows, cols, bands) order, which is what
    ``extract_band`` indexes. The on-disk interleave is only reported.
    Earlier code transposed the array to match the interleave, which moved
    the band axis away from position 2.

    Parameters:
    -----------
    file_path : str
        Image path (.hdr, .bil, .bip, .bsq or .dat). For a data file the
        matching .hdr is looked up next to it.
    format_type : str, optional
        'bil', 'bip', 'bsq' or 'dat'; None or 'auto' detects it.

    Returns:
    --------
    data : np.ndarray
        Hyperspectral cube (rows, cols, bands)
    header : Dict
        Header metadata

    Raises:
    -------
    IOError
        If the header cannot be found or the image cannot be read.
    """
    try:
        if format_type is None or str(format_type).lower() == 'auto':
            format_type = self._detect_format(file_path)

        if file_path.lower().endswith('.hdr'):
            hdr_file = file_path
        else:
            # Accept both "name.hdr" and "name.dat.hdr" naming conventions.
            candidates = [
                os.path.splitext(file_path)[0] + '.hdr',
                file_path + '.hdr',
            ]
            hdr_file = next((c for c in candidates if os.path.exists(c)), None)
            if hdr_file is None:
                raise FileNotFoundError(f"未找到头文件(.hdr)用于 {file_path}")

        img = envi.open(hdr_file)
        # NOTE(review): relies on spectral's load() returning (rows, cols,
        # bands) whatever the file's interleave is - confirm against SPy docs.
        data = img.load()
        header = dict(img.metadata)

        self.data = data
        self.header = header
        self.shape = data.shape
        self.dtype = data.dtype

        # Wavelengths are optional; an unparsable list is treated as absent.
        self.wavelengths = None
        if 'wavelength' in header:
            try:
                self.wavelengths = np.array([float(w) for w in header['wavelength']])
            except (TypeError, ValueError):
                self.wavelengths = None

        print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}, 格式={format_type}")
        if self.wavelengths is not None and len(self.wavelengths) > 0:
            print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
        return data, header
    except Exception as e:
        raise IOError(f"加载图像失败: {e}")
def _detect_format(self, file_path: str) -> str:
"""自动检测图像格式"""
ext = Path(file_path).suffix.lower()
if ext == '.hdr':
# 读取hdr文件内容检测数据格式
with open(file_path, 'r') as f:
content = f.read()
if 'interleave = bil' in content.lower():
return 'bil'
elif 'interleave = bsq' in content.lower():
return 'bsq'
elif 'interleave = bip' in content.lower():
return 'bip'
else:
# 默认假设为BIP
return 'bip'
elif ext in ['.bil', '.bip', '.bsq', '.dat']:
# 从扩展名判断
return ext[1:] # 去掉点号
else:
# 默认假设为BIP
return 'bip'
def _convert_to_bil(self, data: np.ndarray) -> np.ndarray:
"""将数据转换为BIL格式"""
rows, cols, bands = data.shape
bil_data = np.zeros((rows, bands, cols), dtype=data.dtype)
for i in range(rows):
for b in range(bands):
bil_data[i, b, :] = data[i, :, b]
return bil_data
def _convert_to_bsq(self, data: np.ndarray) -> np.ndarray:
"""将数据转换为BSQ格式"""
rows, cols, bands = data.shape
bsq_data = np.zeros((bands, rows, cols), dtype=data.dtype)
for b in range(bands):
bsq_data[b, :, :] = data[:, :, b]
return bsq_data
def extract_band(self, band_index: int) -> np.ndarray:
    """
    Extract one band from the loaded image as a 2-D array.

    Parameters:
    -----------
    band_index : int
        Band index, starting at 0. For data that is already 2-D the only
        valid index is 0.

    Returns:
    --------
    band_data : np.ndarray
        Single-band image (rows, cols)

    Raises:
    -------
    ValueError
        If no image is loaded, the data is neither 2-D nor 3-D, or the index
        is out of range.
    """
    if self.data is None:
        raise ValueError("请先加载图像数据")

    # Determine the layout first: indexing shape[2] on 2-D data used to raise
    # IndexError before the single-band branch could be reached.
    data_shape = self.data.shape
    if len(data_shape) == 3:
        rows, cols, bands = data_shape
    elif len(data_shape) == 2:
        rows, cols = data_shape
        bands = 1
    else:
        raise ValueError(f"不支持的数据形状: {self.data.shape}")

    if band_index < 0 or band_index >= bands:
        raise ValueError(f"波段索引 {band_index} 超出范围 [0, {bands-1}]")

    if len(data_shape) == 3:
        print(f"提取前形状: {self.data.shape}, band_index: {band_index}")
        # reshape (not squeeze) drops only a trailing band axis, so images
        # with a single row or column stay 2-D.
        band_data = np.asarray(self.data[:, :, band_index]).reshape(rows, cols)
        print(f"提取后初始形状: {band_data.shape}")
    else:
        # Already a single band.
        band_data = self.data

    print(f"提取波段 {band_index}: 最终形状={band_data.shape}")
    return band_data
def apply_morphology_operation(self, band_data: np.ndarray,
                               operation: str = 'dilation',
                               se_shape: str = 'disk',
                               se_size: int = 3,
                               **kwargs) -> np.ndarray:
    """
    Apply a grey-scale morphological operation to one band.

    Parameters:
    -----------
    band_data : np.ndarray
        Single-band image data (2D)
    operation : str
        Operation to perform:
        - 'dilation': dilation
        - 'erosion': erosion
        - 'opening': erosion followed by dilation
        - 'closing': dilation followed by erosion
        - 'gradient': dilation minus erosion
        - 'tophat': input minus opening
        - 'bottomhat': closing minus input
        - 'reconstruction': reconstruction by dilation, seeded with the erosion
    se_shape : str
        Structuring element shape: 'disk', 'square', 'rectangle', 'diamond'
    se_size : int
        Structuring element size
    **kwargs :
        Accepted but not used.

    Returns:
    --------
    result : np.ndarray
        Processed image. 'gradient', 'tophat' and 'bottomhat' return float32;
        the other operations are converted back to the input dtype.

    Raises:
    -------
    ValueError
        If the input is not 2-D or the operation is unknown.
    """
    if len(band_data.shape) != 2:
        raise ValueError(f"输入数据必须是2D的当前形状: {band_data.shape}")

    footprint = self._create_structuring_element(se_shape, se_size)
    # All arithmetic is done in float32.
    image = band_data.astype(np.float32)

    def dilate(arr):
        return ndimage.grey_dilation(arr, footprint=footprint)

    def erode(arr):
        return ndimage.grey_erosion(arr, footprint=footprint)

    # Difference-type operations: returned as float32 without conversion.
    if operation == 'gradient':
        expanded = dilate(image)
        shrunk = erode(image)
        return expanded - shrunk
    if operation == 'tophat':
        return image - dilate(erode(image))
    if operation == 'bottomhat':
        return erode(dilate(image)) - image

    # Remaining operations are converted back to the input dtype.
    if operation == 'dilation':
        outcome = dilate(image)
    elif operation == 'erosion':
        outcome = erode(image)
    elif operation == 'opening':
        outcome = dilate(erode(image))
    elif operation == 'closing':
        outcome = erode(dilate(image))
    elif operation == 'reconstruction':
        # The eroded image is the marker; the input itself is the mask.
        seed = erode(image)
        outcome = self._morphological_reconstruction(seed, image, footprint)
    else:
        raise ValueError(f"不支持的形态学操作: {operation}")
    return self._convert_to_original_type(outcome, band_data.dtype)
def _create_structuring_element(self, shape: str, size: int) -> np.ndarray:
"""创建结构元素"""
if shape == 'disk':
# 创建圆形结构元素
radius = size // 2
y, x = np.ogrid[-radius:radius+1, -radius:radius+1]
mask = x*x + y*y <= radius*radius
selem = np.zeros((2*radius+1, 2*radius+1), dtype=bool)
selem[mask] = True
elif shape == 'square':
# 创建方形结构元素
selem = np.ones((size, size), dtype=bool)
elif shape == 'rectangle':
# 创建矩形结构元素
selem = np.ones((size, size*2), dtype=bool)
elif shape == 'diamond':
# 创建菱形结构元素
selem = morphology.diamond(size)
else:
raise ValueError(f"不支持的结构元素形状: {shape}")
return selem
def _morphological_reconstruction(self, marker: np.ndarray, mask: np.ndarray,
selem: np.ndarray) -> np.ndarray:
"""
形态学重建(基于膨胀)
Parameters:
-----------
marker : np.ndarray
标记图像
mask : np.ndarray
掩模图像
selem : np.ndarray
结构元素
Returns:
--------
reconstructed : np.ndarray
重建后的图像
"""
# 确保标记图像不超过掩模图像
marker = np.minimum(marker, mask)
# 迭代膨胀直到收敛
prev_marker = np.zeros_like(marker)
while not np.array_equal(marker, prev_marker):
prev_marker = marker.copy()
# 条件膨胀:在掩模限制下膨胀
dilated = ndimage.grey_dilation(marker, footprint=selem)
marker = np.minimum(dilated, mask)
return marker
def _convert_to_original_type(self, data: np.ndarray, original_dtype: np.dtype) -> np.ndarray:
"""将数据转换回原始数据类型"""
if np.issubdtype(original_dtype, np.integer):
# 对于整数类型,进行裁剪和取整
data = np.clip(data, np.iinfo(original_dtype).min, np.iinfo(original_dtype).max)
return data.astype(original_dtype)
else:
# 对于浮点类型,直接转换
return data.astype(original_dtype)
def apply_to_multiple_bands(self, band_indices: List[int], operation: str,
                            se_shape: str = 'disk', se_size: int = 3) -> Dict[int, np.ndarray]:
    """
    Apply one morphological operation to several bands.

    Parameters:
    -----------
    band_indices : List[int]
        Indices of the bands to process
    operation : str
        Morphological operation type
    se_shape : str
        Structuring element shape
    se_size : int
        Structuring element size

    Returns:
    --------
    results : Dict[int, np.ndarray]
        Processed image for each band index

    Raises:
    -------
    ValueError
        If an extracted band is not 2-D.
    """
    processed = {}
    for index in band_indices:
        print(f"处理波段 {index}...")
        plane = self.extract_band(index)
        # Defensive check: the morphology routine only accepts 2-D input.
        if len(plane.shape) != 2:
            raise ValueError(f"波段 {index} 数据不是2D的形状: {plane.shape}")
        processed[index] = self.apply_morphology_operation(plane, operation, se_shape, se_size)
    return processed
def save_as_envi(self, data: np.ndarray, output_path: str,
                 description: str = "形态学处理结果") -> None:
    """
    Save a result as an ENVI .dat / .hdr pair.

    Parameters:
    -----------
    data : np.ndarray
        Data to save (single-band 2-D array)
    output_path : str
        Output path; '.dat' is appended unless already present. May be a bare
        file name, in which case the current directory is used.
    description : str
        Image description written into the header

    Raises:
    -------
    IOError
        If the files cannot be written.
    """
    try:
        # Create the directory only if the path has one: os.makedirs('')
        # raises for a bare file name.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        if output_path.lower().endswith('.dat'):
            dat_path = output_path
        else:
            dat_path = output_path + '.dat'

        # Raw binary dump of the array.
        data.tofile(dat_path)

        # The header sits next to the data file under the same base name.
        hdr_path = os.path.splitext(dat_path)[0] + '.hdr'
        self._create_morphology_hdr_file(hdr_path, data.shape, description, data.dtype)

        print("ENVI格式已保存:")
        print(f" 数据文件: {dat_path}")
        print(f" 头文件: {hdr_path}")
        print(f" 数据形状: {data.shape}, 数据类型: {data.dtype}")
    except Exception as e:
        raise IOError(f"保存ENVI文件失败: {e}")
def _create_morphology_hdr_file(self, hdr_path: str, data_shape: Tuple,
description: str, data_dtype: np.dtype = None) -> None:
"""
创建形态学处理结果的ENVI头文件
Parameters:
-----------
hdr_path : str
头文件路径
data_shape : Tuple
数据形状
description : str
图像描述
data_dtype : np.dtype, optional
数据类型如果为None则使用默认值
"""
try:
# 从原始头文件获取基本信息(如果有的话)
if self.header is not None:
# 复制原始头文件的关键信息
samples = self.header.get('samples', data_shape[1] if len(data_shape) > 1 else data_shape[0])
lines = self.header.get('lines', data_shape[0] if len(data_shape) > 1 else 1)
bands = 1 # 形态学处理结果是单波段
interleave = 'bsq'
# 根据数据类型确定ENVI数据类型
if data_dtype is not None:
dtype = data_dtype
else:
dtype = np.float32 # 默认类型
if np.issubdtype(dtype, np.float32):
data_type = 4 # float32
elif np.issubdtype(dtype, np.float64):
data_type = 5 # float64
elif np.issubdtype(dtype, np.int32):
data_type = 3 # int32
elif np.issubdtype(dtype, np.int16):
data_type = 2 # int16
elif np.issubdtype(dtype, np.uint16):
data_type = 12 # uint16
else:
data_type = 4 # 默认float32
byte_order = self.header.get('byte order', 0)
wavelength_units = self.header.get('wavelength units', 'nm')
else:
# 默认值适用于2D形态学处理结果
if len(data_shape) == 2:
lines, samples = data_shape
else:
lines = data_shape[0]
samples = data_shape[1] if len(data_shape) > 1 else data_shape[0]
bands = 1
interleave = 'bsq'
data_type = 4 # float32形态学处理通常涉及浮点数
byte_order = 0
wavelength_units = 'Unknown'
# 写入hdr文件
with open(hdr_path, 'w') as f:
f.write("ENVI\n")
f.write("description = {\n")
f.write(f" {description}\n")
f.write(" 单波段形态学处理结果\n")
f.write("}\n")
f.write(f"samples = {samples}\n")
f.write(f"lines = {lines}\n")
f.write(f"bands = {bands}\n")
f.write(f"header offset = 0\n")
f.write(f"file type = ENVI Standard\n")
f.write(f"data type = {data_type}\n")
f.write(f"interleave = {interleave}\n")
f.write(f"byte order = {byte_order}\n")
# 如果有波长信息,添加虚拟波长
if self.header and 'wavelength' in self.header:
f.write("wavelength = {\n")
f.write(" 形态学处理结果\n")
f.write("}\n")
f.write(f"wavelength units = {wavelength_units}\n")
except Exception as e:
raise IOError(f"创建HDR文件失败: {e}")
def visualize_results(self, original_band: np.ndarray,
                      processed_band: np.ndarray,
                      operation_name: str,
                      save_path: Optional[str] = None) -> None:
    """
    Plot the original band, the processed band and their difference side by side.

    Parameters:
    -----------
    original_band : np.ndarray
        Band data before processing.
    processed_band : np.ndarray
        Band data after processing. It is subtracted element-wise from
        ``original_band`` (both cast to float32), so the two arrays must be
        broadcast-compatible.
    operation_name : str
        Operation name; used in the panel title and the figure title.
    save_path : str, optional
        When given, the figure is also written to this path at 300 dpi.
        Missing parent directories are created first.
    """
    import os  # local import: only needed to prepare the output directory

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        # Original band
        im1 = axes[0].imshow(original_band, cmap='gray')
        axes[0].set_title('原始波段')
        axes[0].axis('off')
        plt.colorbar(im1, ax=axes[0], shrink=0.8)

        # Processed band
        im2 = axes[1].imshow(processed_band, cmap='gray')
        axes[1].set_title(f'{operation_name} 结果')
        axes[1].axis('off')
        plt.colorbar(im2, ax=axes[1], shrink=0.8)

        # Difference image on a colour scale that is symmetric around zero.
        diff = processed_band.astype(np.float32) - original_band.astype(np.float32)
        limit = float(np.nanmax(np.abs(diff))) if diff.size else 0.0
        if not np.isfinite(limit) or limit == 0.0:
            # A zero or non-finite limit would give vmin == vmax (or NaN limits),
            # which collapses the normalisation; fall back to a unit range so a
            # zero difference is drawn at the neutral centre of the colormap.
            limit = 1.0
        im3 = axes[2].imshow(diff, cmap='RdBu_r', vmin=-limit, vmax=limit)
        axes[2].set_title('差异 (处理后 - 原始)')
        axes[2].axis('off')
        plt.colorbar(im3, ax=axes[2], shrink=0.8)

        plt.suptitle(f'形态学操作: {operation_name}', fontsize=16)
        plt.tight_layout()

        if save_path:
            target_dir = os.path.dirname(save_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"可视化结果已保存到: {save_path}")

        plt.show()
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        # In interactive mode show() does not block, so the window is left
        # open for the user instead of being closed immediately.
        if not plt.isinteractive():
            plt.close(fig)
def main(argv=None):
    """Command-line entry point for hyperspectral morphological processing.

    Parses the command line, loads the image, applies the requested
    morphological operation to a single band (``--band``) or to several
    bands (``--bands``), saves every result through ``save_as_envi`` and
    optionally renders figures.

    Parameters:
    -----------
    argv : list of str, optional
        Arguments to parse instead of the process command line. ``None``
        (the default) makes argparse read ``sys.argv[1:]``, which is the
        behaviour of the original zero-argument call.

    Returns:
    --------
    int
        0 on success; 1 when any exception is raised during processing
        (the message and traceback are printed). Invalid command-line
        arguments are still reported by argparse itself via ``SystemExit``.
    """
    import argparse

    # Supported morphological operations
    morph_operations = [
        'dilation', 'erosion', 'opening', 'closing',
        'gradient', 'tophat', 'bottomhat', 'reconstruction'
    ]
    # Supported structuring-element shapes
    se_shapes = ['disk', 'square', 'rectangle', 'diamond']

    parser = argparse.ArgumentParser(description='高光谱图像形态学处理工具')
    parser.add_argument('input_file', help='输入高光谱图像文件路径')
    parser.add_argument('--format', '-f', default='auto',
                        choices=['auto', 'bil', 'bip', 'bsq', 'dat'],
                        help='图像格式 (默认: auto)')
    parser.add_argument('--band', '-b', type=int, default=0,
                        help='要处理的波段索引 (默认: 0)')
    parser.add_argument('--bands', '-B', type=str, default=None,
                        help='多个波段索引,用逗号分隔,如 "0,10,20"')
    parser.add_argument('--operation', '-o', default='dilation',
                        choices=morph_operations,
                        help='形态学操作类型 (默认: dilation)')
    parser.add_argument('--se_shape', '-s', default='disk',
                        choices=se_shapes,
                        help='结构元素形状 (默认: disk)')
    parser.add_argument('--se_size', '-S', type=int, default=3,
                        help='结构元素大小 (默认: 3)')
    parser.add_argument('--output_dir', '-d', default='output',
                        help='输出目录 (默认: output)')
    parser.add_argument('--output_name', '-n', default='morphology_result',
                        help='输出文件名 (不含扩展名) (默认: morphology_result)')
    parser.add_argument('--visualize', '-v', action='store_true',
                        help='是否生成可视化结果')
    args = parser.parse_args(argv)

    try:
        # Resolve the band selection before touching the image so a malformed
        # --bands value fails fast instead of after a potentially slow load.
        if args.bands:
            # Blank tokens (e.g. a trailing comma in "0,10,") are ignored.
            tokens = [tok.strip() for tok in args.bands.split(',')]
            band_indices = [int(tok) for tok in tokens if tok]
            if not band_indices:
                raise ValueError(f"--bands 未包含有效的波段索引: {args.bands!r}")
            single_band = False
        else:
            band_indices = [args.band]
            single_band = True

        # Create the processor
        processor = HyperspectralMorphologyProcessor()

        # 'auto' is passed to the loader as None
        format_type = None if args.format == 'auto' else args.format

        # Load the image; the processor is used for all later access, so the
        # returned values are not needed here.
        print(f"加载图像: {args.input_file}")
        processor.load_hyperspectral_image(args.input_file, format_type)

        if single_band:
            print(f"处理波段: {args.band}")
        else:
            print(f"处理波段: {band_indices}")

        # Make sure the output directory exists before anything is written.
        if args.output_dir:
            os.makedirs(args.output_dir, exist_ok=True)

        if single_band:
            # Single-band processing
            band_data = processor.extract_band(args.band)
            print(f"主函数中波段数据形状: {band_data.shape}")
            result = processor.apply_morphology_operation(
                band_data, args.operation, args.se_shape, args.se_size
            )

            # Save the result
            output_path = os.path.join(args.output_dir, f"{args.output_name}_band{args.band}")
            processor.save_as_envi(result, output_path,
                                   f"{args.operation.capitalize()} 处理结果 - 波段 {args.band}")

            # Visualisation
            if args.visualize:
                vis_path = os.path.join(args.output_dir, f"visualization_band{args.band}.png")
                processor.visualize_results(band_data, result, args.operation, vis_path)

            # Summary statistics
            print("\n=== 统计信息 ===")
            print(f"原始波段 {args.band}: min={band_data.min():.4f}, max={band_data.max():.4f}, "
                  f"mean={band_data.mean():.4f}")
            print(f"处理后: min={result.min():.4f}, max={result.max():.4f}, "
                  f"mean={result.mean():.4f}")
        else:
            # Multi-band processing
            results = processor.apply_to_multiple_bands(
                band_indices, args.operation, args.se_shape, args.se_size
            )

            # Save the result of every band
            for band_idx, result in results.items():
                output_path = os.path.join(args.output_dir, f"{args.output_name}_band{band_idx}")
                processor.save_as_envi(result, output_path,
                                       f"{args.operation.capitalize()} 处理结果 - 波段 {band_idx}")

            # A combined figure is only produced for a small number of bands.
            if args.visualize and len(band_indices) <= 4:
                # Only plot bands that actually produced a result.
                # NOTE(review): assumes `results` is keyed by band index, as
                # the save loop above does; confirm against
                # apply_to_multiple_bands.
                shown = [idx for idx in band_indices if idx in results]
                fig, axes = plt.subplots(2, 2, figsize=(12, 10))
                try:
                    axes = axes.flatten()
                    for ax, band_idx in zip(axes, shown):
                        ax.imshow(results[band_idx], cmap='gray')
                        ax.set_title(f'波段 {band_idx} - {args.operation}')
                        ax.axis('off')
                    # Blank out grid cells that have no band to show.
                    for ax in axes[len(shown):]:
                        ax.axis('off')

                    plt.suptitle(f'多波段形态学处理: {args.operation}', fontsize=16)
                    plt.tight_layout()
                    vis_path = os.path.join(args.output_dir, "multi_band_visualization.png")
                    plt.savefig(vis_path, dpi=300, bbox_inches='tight')
                    print(f"多波段可视化结果已保存到: {vis_path}")
                    plt.show()
                finally:
                    # Release the figure unless an interactive session still
                    # needs the (non-blocking) window.
                    if not plt.isinteractive():
                        plt.close(fig)

        print("\n✓ 形态学处理完成!")
        print(f"输出目录: {args.output_dir}")
        return 0
    except Exception as e:
        # Top-level boundary: report the failure and map it to exit code 1.
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
# 使用示例函数
def run_example():
    """Run every supported morphological operation on one band as a demo.

    Loads the image at a placeholder path (replace it with a real file),
    extracts band 50 and, for each operation, applies it with a size-3 disk
    structuring element, saves the result through ``save_as_envi`` and calls
    ``visualize_results``. A failure in one operation is printed and the
    loop continues with the next one.

    Returns:
    --------
    dict
        Maps operation name to its result, for every operation whose
        ``apply_morphology_operation`` call succeeded.
    """
    # Placeholder path - replace with an actual .hdr/.bil/.bip/.bsq/.dat file.
    source_path = "path/to/your/hyperspectral_image.hdr"
    target_band = 50
    op_names = ('dilation', 'erosion', 'opening', 'closing',
                'gradient', 'tophat', 'bottomhat', 'reconstruction')

    morph = HyperspectralMorphologyProcessor()
    _cube, _meta = morph.load_hyperspectral_image(source_path, format_type='bip')
    band = morph.extract_band(target_band)

    outputs = {}
    for op_name in op_names:
        print(f"\n应用 {op_name} 操作...")
        try:
            processed = morph.apply_morphology_operation(
                band, op_name, se_shape='disk', se_size=3
            )
            # Record the result before saving so a save/plot failure does not lose it.
            outputs[op_name] = processed
            morph.save_as_envi(
                processed,
                f"output/{op_name}_band{target_band}",
                f"{op_name.capitalize()} 处理结果",
            )
            morph.visualize_results(
                band, processed, op_name,
                save_path=f"output/visualization_{op_name}.png",
            )
        except Exception as err:
            print(f"操作 {op_name} 失败: {err}")
    return outputs
@dataclass
class MorphologicalFilterConfig:
    """Configuration record for a morphological filtering run.

    The first five fields are required and are accepted positionally in the
    order declared; the last three are optional and carry defaults.
    """

    # --- required settings ---
    input_path: str    # presumably the path of the image to process
    operation: str     # presumably a morphological operation name - TODO confirm accepted values
    band_index: int    # presumably the index of the band to process
    kernel_size: int   # NOTE(review): relationship to `se_size` is not visible here - confirm
    output_dir: str    # presumably the directory that receives the results

    # --- optional settings ---
    se_shape: str = 'disk'       # structuring-element shape
    se_size: int = 3             # structuring-element size
    format_type: str = 'auto'    # image format selector
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    # SystemExit is raised directly rather than calling the `exit` helper:
    # that name is injected by the `site` module for interactive use and is
    # not available when the interpreter runs without `site` (e.g. `python -S`).
    raise SystemExit(main())