第一次提交

This commit is contained in:
2025-12-15 17:41:27 +08:00
commit 317d27d775
17 changed files with 7833 additions and 0 deletions

8
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

View File

@ -0,0 +1,35 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="22">
<item index="0" class="java.lang.String" itemvalue="catboost" />
<item index="1" class="java.lang.String" itemvalue="opencv-python" />
<item index="2" class="java.lang.String" itemvalue="xgboost" />
<item index="3" class="java.lang.String" itemvalue="opencv-python-headless" />
<item index="4" class="java.lang.String" itemvalue="zstandard" />
<item index="5" class="java.lang.String" itemvalue="scikit-image" />
<item index="6" class="java.lang.String" itemvalue="scipy" />
<item index="7" class="java.lang.String" itemvalue="tensorflow_gpu" />
<item index="8" class="java.lang.String" itemvalue="h5py" />
<item index="9" class="java.lang.String" itemvalue="matplotlib" />
<item index="10" class="java.lang.String" itemvalue="numpy" />
<item index="11" class="java.lang.String" itemvalue="opencv_python" />
<item index="12" class="java.lang.String" itemvalue="Pillow" />
<item index="13" class="java.lang.String" itemvalue="tensorflow" />
<item index="14" class="java.lang.String" itemvalue="jupyter" />
<item index="15" class="java.lang.String" itemvalue="ipykernel" />
<item index="16" class="java.lang.String" itemvalue="pandas" />
<item index="17" class="java.lang.String" itemvalue="Werkzeug" />
<item index="18" class="java.lang.String" itemvalue="cellpose" />
<item index="19" class="java.lang.String" itemvalue="torchvision" />
<item index="20" class="java.lang.String" itemvalue="Flask" />
<item index="21" class="java.lang.String" itemvalue="fiona" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

6
.idea/misc.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="insect" />
</component>
</project>

7
.idea/single_classsfication.iml generated Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<module version="4">
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

209
README.md Normal file
View File

@ -0,0 +1,209 @@
# 高光谱图像分类与分析工具包
[![Python Version](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/)
[![License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)
一个全面的高光谱图像处理、分类和光谱分析工具包,具有 GUI 就绪架构。
## 🌟 功能特性
### 🔬 光谱分析
- **光谱指数计算**: 计算各种植被、矿物和水体指数
- **自定义公式支持**: 定义和计算自定义光谱指数
- **多格式支持**: 处理CSV、ENVI.hdr/.dat和其他光谱数据格式
### 🤖 机器学习模型
- **传统模型**: 线性回归、岭回归、LASSO、弹性网络、贝叶斯岭回归
- **集成方法**: 随机森林、梯度提升、XGBoost、LightGBM、AdaBoost
- **神经网络**: MLP、LSTM、GRU支持TensorFlow/PyTorch
- **专业模型**: 支持向量回归、高斯过程、KNN
### 🖼️ 图像处理
- **空间滤波**: 均值、中值、高斯和双边滤波
- **多波段支持**: 处理单个波段或整个高光谱数据立方体
- **格式保持**: 保持原始数据类型和数值范围
### 🎯 GUI就绪架构
- **配置驱动**: 集中化参数管理和验证
- **模块化设计**: 业务逻辑与参数输入完全解耦
- **错误处理**: 全面的验证和用户友好的错误消息
## 📋 系统要求
### 系统配置
- Python 3.8 或更高版本
- 4GB+ RAM(大型数据集推荐使用 8GB+)
- GPU 支持(可选,用于深度学习模型)
### 依赖包
核心依赖包(完整列表见 `requirements.txt`):
```
numpy>=1.21.0
pandas>=1.3.0
scikit-learn>=1.0.0
matplotlib>=3.5.0
xgboost>=1.5.0
lightgbm>=3.3.0
```
可选依赖包:
- `tensorflow` 或 `torch` 用于深度学习模型
- `spectral` 用于ENVI文件处理
- `opencv-python` 用于图像滤波
## 🚀 安装指南
### 1. 克隆仓库
```bash
git clone http://git.iris-rs.cn/zhanghuilai/HSI.git
cd HSI
```
### 2. 创建虚拟环境(推荐)
```bash
python -m venv hyperspectral_env
source hyperspectral_env/bin/activate # Windows: hyperspectral_env\Scripts\activate
```
### 3. 安装依赖包
```bash
pip install -r requirements.txt
```
### 4. 根据需要安装可选依赖包
```bash
# 深度学习支持
pip install tensorflow # 或 torch
# ENVI文件处理
pip install spectral
# 开发环境
pip install pytest jupyter
```
## 📖 使用方法
### 基础用法
```python
from rgression_method.regression import RegressionAnalyzer, RegressionConfig
from spectral_index import HyperspectralIndexCalculator, SpectralIndexConfig
# 1. 回归分析配置驱动
config = RegressionConfig.create_default(
csv_path="your_data.csv",
label_column="target"
)
config.models.model_names = ['linear', 'xgboost', 'lightgbm']
analyzer = RegressionAnalyzer(config)
results = analyzer.run_analysis_from_config()
# 2. 光谱指数计算
index_config = SpectralIndexConfig.create_default()
calculator = HyperspectralIndexCalculator(index_config)
results = calculator.run_analysis_from_config("hyperspectral_data.hdr")
```
### 命令行使用
```bash
# 光谱指数计算
python spectral_index.py input.hdr -i NDVI EVI -o results
# 图像滤波
python fliter/Smooth_filter.py input.hdr -f mean -k 3 -b 20 -o filtered_output
```
### 高级配置
```python
# 自定义回归配置
config = RegressionConfig()
config.data.csv_path = "data.csv"
config.data.label_column = "chlorophyll"
config.data.spectrum_columns = "10:200" # 波长范围
config.models.tune_hyperparams = True
config.models.model_names = ['ridge', 'lasso', 'xgboost']
config.output.save_models = True
analyzer = RegressionAnalyzer(config)
results = analyzer.run_analysis_from_config()
```
## 📁 项目结构
```
hyperspectral-toolkit/
├── rgression_method/
│ ├── regression.py # 主要回归分析模块
│ └── __init__.py
├── fliter/
│ ├── Smooth_filter.py # 图像滤波模块
│ └── morphological_fliter.py # 形态学滤波
├── spectral_index.py # 光谱指数计算器
├── classfication.py # 分类工具
├── cluster.py # 聚类算法
├── requirements.txt # 依赖包列表
├── README.md # 本文件
└── examples/ # 使用示例
```
## 🔧 配置系统
该工具包使用为 GUI 集成设计的分层配置系统:
### 数据配置
```python
@dataclass
class DataConfig:
csv_path: str = ""
label_column: Union[str, int] = ""
spectrum_columns: Optional[Union[str, List]] = None
test_size: float = 0.2
scale_method: str = 'standard'
```
### 模型配置
```python
@dataclass
class ModelConfig:
model_names: Optional[List[str]] = None # None = 所有模型
tune_hyperparams: bool = True
tuning_method: str = 'grid'
cv_folds: int = 5
```
### 训练配置
```python
@dataclass
class TrainingConfig:
epochs: int = 100
batch_size: int = 32
learning_rate: float = 0.001
```
## 🎨 可用模型
### 回归模型
- **线性模型**: 线性回归、岭回归、LASSO、弹性网络、贝叶斯岭回归
- **集成模型**: 随机森林、梯度提升、AdaBoost
- **提升模型**: XGBoost、LightGBM
- **神经网络**: MLP、LSTM、GRU
- **专业模型**: SVR、高斯过程、KNN
### 光谱指数
- **植被指数**: NDVI、EVI、ARVI、SAVI 等
- **水体指数**: NDWI、MNDWI 等
- **矿物指数**: 各种矿物特定指数
- **自定义指数**: 用户定义的公式
## 📊 输出格式
- **CSV**: 包含统计信息的表格结果
- **ENVI**: 标准高光谱格式(.dat + .hdr)
- **图像**: PNG图表和可视化
- **模型**: Pickle格式的sklearn模型用于部署

Binary file not shown.

File diff suppressed because it is too large Load Diff

1012
cluster_method/cluster.py Normal file

File diff suppressed because it is too large Load Diff

73
data/roi.xml Normal file
View File

@ -0,0 +1,73 @@
<?xml version="1.0" encoding="UTF-8"?>
<RegionsOfInterest version="1.1">
<Region name="ROI #1" color="255,0,0">
<GeometryDef>
<CoordSysStr>none</CoordSysStr>
<Polygon>
<Exterior>
<LinearRing>
<Coordinates>
293.49999851 419.74999976 319.75 419.74999976 319.75 442.25 293.49999851 442.25 293.49999851 419.74999976
</Coordinates>
</LinearRing>
</Exterior>
</Polygon>
</GeometryDef>
</Region>
<Region name="ROI #2" color="0,128,0">
<GeometryDef>
<CoordSysStr>none</CoordSysStr>
<Polygon>
<Exterior>
<LinearRing>
<Coordinates>
684.74999356 430.99999962 708.5 430.99999962 708.5 448.5 684.74999356 448.5 684.74999356 430.99999962
</Coordinates>
</LinearRing>
</Exterior>
</Polygon>
</GeometryDef>
</Region>
<Region name="ROI #3" color="0,0,255">
<GeometryDef>
<CoordSysStr>none</CoordSysStr>
<Polygon>
<Exterior>
<LinearRing>
<Coordinates>
319.74999818 844.74999438 346 844.74999438 346 862.25 319.74999818 862.25 319.74999818 844.74999438
</Coordinates>
</LinearRing>
</Exterior>
</Polygon>
</GeometryDef>
</Region>
<Region name="ROI #4" color="255,255,0">
<GeometryDef>
<CoordSysStr>none</CoordSysStr>
<Polygon>
<Exterior>
<LinearRing>
<Coordinates>
672.24999372 854.74999425 697.25 854.74999425 697.25 873.5 672.24999372 873.5 672.24999372 854.74999425
</Coordinates>
</LinearRing>
</Exterior>
</Polygon>
</GeometryDef>
</Region>
<Region name="ROI #5" color="0,255,255">
<GeometryDef>
<CoordSysStr>none</CoordSysStr>
<Polygon>
<Exterior>
<LinearRing>
<Coordinates>
573.49999497 800.99999494 589.75 800.99999494 589.75 822.25 573.49999497 822.25 573.49999497 800.99999494
</Coordinates>
</LinearRing>
</Exterior>
</Polygon>
</GeometryDef>
</Region>
</RegionsOfInterest>

View File

@ -0,0 +1,678 @@
import numpy as np
import cv2
import os
from typing import Tuple, Optional, Dict, Union
from scipy import ndimage
import warnings
warnings.filterwarnings('ignore')
try:
import spectral as spy
from spectral import envi
HAS_SPECTRAL = True
except ImportError:
HAS_SPECTRAL = False
print("警告: 未安装spectral库只能处理RGB图像")
class HyperspectralImageFilter:
"""高光谱图像平滑滤波器
使用示例:
# 创建滤波器实例
filter_obj = HyperspectralImageFilter()
# 处理图像 - 均值滤波
dat_path, hdr_path = filter_obj.process_image(
input_path='input.hdr',
output_path='output_mean',
filter_type='mean',
kernel_size=5
)
# 处理图像 - 高斯滤波
dat_path, hdr_path = filter_obj.process_image(
input_path='input.hdr',
output_path='output_gaussian',
filter_type='gaussian',
kernel_size=3,
sigma=2.0
)
# 处理图像 - 双边滤波
dat_path, hdr_path = filter_obj.process_image(
input_path='input.hdr',
output_path='output_bilateral',
filter_type='bilateral',
kernel_size=7,
sigma_color=50.0,
sigma_space=50.0
)
# 或者直接使用apply_filter方法
data, header = filter_obj.load_image('input.hdr')
filtered_data = filter_obj.apply_filter('mean', kernel_size=5)
dat_path, hdr_path = filter_obj.save_envi('output', filtered_data, 'mean', header)
"""
def __init__(self):
self.data = None
self.header = None
self.wavelengths = None
self.is_hyperspectral = False
self._selected_band_index = None # 记录选择的波段索引
self._original_dtype = None # 记录原始数据类型
self._original_range = None # 记录原始数据范围
def load_image(self, image_path: str, band_index: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
"""
加载图像文件
Parameters:
image_path: 图像文件路径 (.hdr/.dat 或 .jpg/.png/.tif)
band_index: 对于高光谱图像,指定要处理的波段索引 (None表示使用所有波段)
Returns:
处理后的图像数据和元数据
"""
try:
# 检查文件扩展名
file_ext = os.path.splitext(image_path)[1].lower()
if file_ext in ['.hdr']:
# ENVI高光谱图像
if not HAS_SPECTRAL:
raise ImportError("需要安装spectral库来处理ENVI格式文件")
return self._load_envi_image(image_path, band_index)
elif file_ext in ['.jpg', '.jpeg', '.png', '.tif', '.tiff']:
# RGB图像
return self._load_rgb_image(image_path)
else:
raise ValueError(f"不支持的文件格式: {file_ext}")
except Exception as e:
raise IOError(f"加载图像失败: {e}")
def _load_envi_image(self, hdr_path: str, band_index: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
    """
    Load an ENVI image through ``spectral.envi`` and cache it on the instance.

    Parameters:
        hdr_path: path to the ENVI ``.hdr`` file
        band_index: band to keep; None keeps every band
    Returns:
        The loaded data (one band or the full cube) and the header metadata
    Raises:
        IOError: for any failure, including an out-of-range ``band_index``
    """
    try:
        img = envi.open(hdr_path)
        data = img.load()
        header = dict(img.metadata)

        # Wavelengths are optional; an unparsable list is treated as absent.
        wavelengths = None
        if 'wavelength' in header:
            try:
                wavelengths = np.array([float(w) for w in header['wavelength']])
            except (TypeError, ValueError):
                wavelengths = None

        if band_index is not None:
            if band_index < 0 or band_index >= data.shape[2]:
                raise ValueError(f"波段索引 {band_index} 超出范围 [0, {data.shape[2]-1}]")
            print(f"原始数据形状: {data.shape}")
            # Remembered so the header writer can emit the matching wavelength.
            self._selected_band_index = band_index
            data = data[:, :, band_index]
            print(f"选择波段 {band_index} 后的数据形状: {data.shape}")
            # The slice may still be (H, W, 1); collapse it to (H, W).
            if len(data.shape) == 3 and data.shape[2] == 1:
                data = np.squeeze(data, axis=2)
                print(f"压缩单波段数据为2D: {data.shape}")
            print(f"选择波段 {band_index} 进行处理")
            # The header may list fewer wavelengths than bands; do not let
            # a purely informational print abort the load.
            if wavelengths is not None and band_index < len(wavelengths):
                print(f"波长: {wavelengths[band_index]:.1f} nm")
            self.is_hyperspectral = False
        else:
            print("处理所有波段")
            self.is_hyperspectral = True
            self._selected_band_index = None

        # Record dtype/range so save_envi can try to preserve them.
        self._original_dtype = data.dtype
        self._original_range = None  # never keep a range from a previous load
        if data.size > 0:
            self._original_range = (data.min(), data.max())
            print(f"原始数据范围: {self._original_range[0]} - {self._original_range[1]}")

        self.data = data
        self.header = header
        self.wavelengths = wavelengths
        print(f"成功加载ENVI图像: 形状={data.shape}, 数据类型={data.dtype}")
        if wavelengths is not None and len(wavelengths) > 0 and band_index is not None:
            print(f"波长范围: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")
        return data, header
    except Exception as e:
        raise IOError(f"加载ENVI图像失败: {e}") from e
def _load_rgb_image(self, image_path: str) -> Tuple[np.ndarray, Dict]:
    """
    Load an ordinary raster image with OpenCV and cache it on the instance.

    Three-channel images are converted from BGR to RGB and four-channel
    images from BGRA to RGBA, so channel order is consistent either way.

    Parameters:
        image_path: path of the image file
    Returns:
        The pixel array and a minimal ENVI-style metadata dictionary
    Raises:
        IOError: if the file cannot be read or any later step fails
    """
    try:
        data = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if data is None:
            raise ValueError(f"无法读取图像文件: {image_path}")

        # OpenCV decodes colour images in BGR(A) order.
        if len(data.shape) == 3:
            if data.shape[2] == 3:
                data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
            elif data.shape[2] == 4:
                # IMREAD_UNCHANGED keeps the alpha channel; swap B and R here too.
                data = cv2.cvtColor(data, cv2.COLOR_BGRA2RGBA)

        # Record dtype/range so save_envi can try to preserve them.
        self._original_dtype = data.dtype
        self._original_range = None  # never keep a range from a previous load
        if data.size > 0:
            self._original_range = (data.min(), data.max())
            print(f"原始数据范围: {self._original_range[0]} - {self._original_range[1]}")

        header = {
            'samples': data.shape[1],
            'lines': data.shape[0],
            'bands': data.shape[2] if len(data.shape) == 3 else 1,
            'data_type': self._get_envi_data_type(data.dtype),
            'interleave': 'bsq',
            'byte_order': 0
        }
        self.data = data
        self.header = header
        self.wavelengths = None
        self.is_hyperspectral = False
        # A band selection from an earlier ENVI load does not apply here.
        self._selected_band_index = None
        print(f"成功加载RGB图像: 形状={data.shape}, 数据类型={data.dtype}")
        return data, header
    except Exception as e:
        raise IOError(f"加载RGB图像失败: {e}") from e
def _get_envi_data_type(self, dtype: np.dtype) -> int:
"""获取ENVI数据类型代码"""
if dtype == np.uint8:
return 1
elif dtype == np.int16:
return 2
elif dtype == np.int32:
return 3
elif dtype == np.float32:
return 4
elif dtype == np.float64:
return 5
elif dtype == np.uint16:
return 12
elif dtype == np.uint32:
return 13
else:
return 4 # 默认float32
def apply_filter(self, filter_type: str, **kwargs) -> np.ndarray:
"""
应用指定的滤波器
Parameters:
filter_type: 滤波器类型 ('mean', 'median', 'gaussian', 'bilateral')
**kwargs: 各滤波器的超参数
- kernel_size: 内核大小 (奇数默认3)
- sigma: 高斯滤波的标准差 (默认1.0)
- sigma_color: 双边滤波的颜色空间标准差 (默认75.0)
- sigma_space: 双边滤波的空间标准差 (默认75.0)
Returns:
滤波后的图像数据
"""
if self.data is None:
raise ValueError("请先加载图像数据")
# 获取参数,设置默认值
kernel_size = kwargs.get('kernel_size', 3)
sigma = kwargs.get('sigma', 1.0)
sigma_color = kwargs.get('sigma_color', 75.0)
sigma_space = kwargs.get('sigma_space', 75.0)
if kernel_size % 2 == 0:
kernel_size += 1 # 确保为奇数
print(f"应用{filter_type}滤波器,参数: kernel_size={kernel_size}")
if filter_type.lower() == 'mean':
return self._mean_filter(kernel_size)
elif filter_type.lower() == 'median':
return self._median_filter(kernel_size)
elif filter_type.lower() == 'gaussian':
print(f" sigma={sigma}")
return self._gaussian_filter(kernel_size, sigma)
elif filter_type.lower() == 'bilateral':
print(f" sigma_color={sigma_color}, sigma_space={sigma_space}")
return self._bilateral_filter(kernel_size, sigma_color, sigma_space)
else:
raise ValueError(f"不支持的滤波器类型: {filter_type}")
def _mean_filter(self, kernel_size: int) -> np.ndarray:
"""均值滤波"""
try:
print(f"均值滤波开始 - 数据形状: {self.data.shape}, is_hyperspectral: {self.is_hyperspectral}")
# 创建均值内核
kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) / (kernel_size * kernel_size)
# 检查数据维度
if len(self.data.shape) == 3 and self.data.shape[2] > 1:
# 多波段高光谱图像 - 对每个波段分别处理
filtered_data = np.zeros_like(self.data, dtype=np.float32)
for band in range(self.data.shape[2]):
filtered_data[:, :, band] = ndimage.convolve(self.data[:, :, band].astype(np.float32), kernel)
print(f"均值滤波完成 - 处理了 {self.data.shape[2]} 个波段")
elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
# 单波段3D数据 (H, W, 1) - 压缩并处理
data_2d = np.squeeze(self.data, axis=2).astype(np.float32)
filtered_data = ndimage.convolve(data_2d, kernel)
# 保持输出为3D形状以保持一致性
filtered_data = filtered_data[:, :, np.newaxis]
print(f"均值滤波完成 - 单波段3D数据 (H,W,1) -> (H,W,1)")
else:
# 2D图像或单波段图像
print(f"应用2D均值滤波到形状 {self.data.shape} 的数据")
filtered_data = ndimage.convolve(self.data.astype(np.float32), kernel)
print(f"2D均值滤波完成 - 输出形状: {filtered_data.shape}")
return filtered_data
except Exception as e:
raise RuntimeError(f"均值滤波失败: {e}")
def _median_filter(self, kernel_size: int) -> np.ndarray:
"""中值滤波"""
try:
# 检查数据维度
if len(self.data.shape) == 3 and self.data.shape[2] > 1:
# 多波段高光谱图像 - 对每个波段分别处理
filtered_data = np.zeros_like(self.data, dtype=self.data.dtype)
for band in range(self.data.shape[2]):
filtered_data[:, :, band] = ndimage.median_filter(self.data[:, :, band], size=kernel_size)
print(f"中值滤波完成 - 处理了 {self.data.shape[2]} 个波段")
elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
# 单波段3D数据 (H, W, 1) - 压缩并处理
data_2d = np.squeeze(self.data, axis=2)
filtered_data = ndimage.median_filter(data_2d, size=kernel_size)
# 保持输出为3D形状以保持一致性
filtered_data = filtered_data[:, :, np.newaxis]
print(f"中值滤波完成 - 单波段3D数据 (H,W,1) -> (H,W,1)")
else:
# 2D图像或单波段图像
filtered_data = ndimage.median_filter(self.data, size=kernel_size)
return filtered_data
except Exception as e:
raise RuntimeError(f"中值滤波失败: {e}")
def _gaussian_filter(self, kernel_size: int, sigma: float) -> np.ndarray:
"""高斯滤波"""
try:
# 检查数据维度
if len(self.data.shape) == 3 and self.data.shape[2] > 1:
# 多波段高光谱图像 - 对每个波段分别处理
filtered_data = np.zeros_like(self.data, dtype=np.float32)
for band in range(self.data.shape[2]):
filtered_data[:, :, band] = ndimage.gaussian_filter(
self.data[:, :, band].astype(np.float32), sigma=sigma
)
print(f"高斯滤波完成 (sigma={sigma}) - 处理了 {self.data.shape[2]} 个波段")
elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
# 单波段3D数据 (H, W, 1) - 压缩并处理
data_2d = np.squeeze(self.data, axis=2).astype(np.float32)
filtered_data = ndimage.gaussian_filter(data_2d, sigma=sigma)
# 保持输出为3D形状以保持一致性
filtered_data = filtered_data[:, :, np.newaxis]
print(f"高斯滤波完成 (sigma={sigma}) - 单波段3D数据 (H,W,1) -> (H,W,1)")
else:
# 2D图像或单波段图像
filtered_data = ndimage.gaussian_filter(self.data.astype(np.float32), sigma=sigma)
print(f"高斯滤波完成 (sigma={sigma})")
return filtered_data
except Exception as e:
raise RuntimeError(f"高斯滤波失败: {e}")
def _bilateral_filter(self, kernel_size: int, sigma_color: float, sigma_space: float) -> np.ndarray:
    """
    Edge-preserving bilateral filter.

    Routing:
      * a non-hyperspectral 3-channel uint8/float32 image is filtered
        jointly by OpenCV as an RGB image;
      * any other multi-band array is filtered band by band (previously
        such arrays, e.g. RGBA, fell into the 2D path and failed in OpenCV);
      * (H, W, 1) and (H, W) arrays are filtered as a single plane.

    Single planes are rescaled to 0-255 and quantised to uint8 before
    filtering, so ``sigma_color`` refers to that 8-bit scale.
    NOTE(review): the quantisation limits the output to 256 levels per band.

    Parameters:
        kernel_size: diameter of the pixel neighbourhood
        sigma_color: filter sigma in colour space
        sigma_space: filter sigma in coordinate space
    Returns:
        The filtered data; float32 for everything except the joint RGB path,
        which keeps the input dtype
    Raises:
        RuntimeError: if the filtering fails for any reason
    """
    def filter_plane(plane):
        # Filter one 2D plane through an 8-bit normalised copy.
        plane = plane.astype(np.float32)
        low, high = plane.min(), plane.max()
        if not high > low:
            # Constant (or all-NaN) plane: nothing to smooth.
            return plane
        span = high - low
        quantised = ((plane - low) / span * 255).astype(np.uint8)
        smoothed = cv2.bilateralFilter(quantised, kernel_size, sigma_color, sigma_space)
        # Map the 8-bit result back onto the original value range.
        return smoothed.astype(np.float32) / 255 * span + low

    try:
        data = self.data
        is_stack = len(data.shape) == 3
        joint_rgb = (
            is_stack
            and data.shape[2] == 3
            and not self.is_hyperspectral
            and data.dtype in (np.dtype(np.uint8), np.dtype(np.float32))
        )

        if joint_rgb:
            filtered_data = cv2.bilateralFilter(data, kernel_size, sigma_color, sigma_space)
            print("RGB双边滤波完成")
        elif is_stack and data.shape[2] > 1:
            filtered_data = np.zeros_like(data, dtype=np.float32)
            for band in range(data.shape[2]):
                filtered_data[:, :, band] = filter_plane(data[:, :, band])
            print(f"双边滤波完成 - 处理了 {self.data.shape[2]} 个波段")
        elif is_stack and data.shape[2] == 1:
            # Keep the trailing band axis so the output shape matches the input.
            filtered_data = filter_plane(np.squeeze(data, axis=2))[:, :, np.newaxis]
            print("单波段双边滤波完成 (3D -> 3D)")
        else:
            filtered_data = filter_plane(data)
            print("单波段双边滤波完成")
        return filtered_data
    except Exception as e:
        raise RuntimeError(f"双边滤波失败: {e}") from e
def save_envi(self, output_path: str, filtered_data: np.ndarray, filter_type: str,
original_header: Dict = None) -> Tuple[str, str]:
"""
保存滤波结果为ENVI格式
Parameters:
output_path: 输出文件路径(不含扩展名)
filtered_data: 滤波后的数据
filter_type: 滤波器类型,用于文件名
original_header: 原始图像的头文件信息
Returns:
数据文件路径和头文件路径
"""
try:
# 确保输出目录存在
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
# 生成文件名
base_name = os.path.basename(output_path)
dat_path = f"{output_path}_{filter_type}.dat"
hdr_path = f"{output_path}_{filter_type}.hdr"
# 保存数据文件 - 保持与原始数据一致的类型和范围
print(f"滤波前数据范围: {filtered_data.min():.2f} - {filtered_data.max():.2f}")
print(f"原始数据类型: {self._original_dtype}, 范围: {self._original_range}")
# 根据原始数据类型决定保存格式
if self._original_dtype is not None:
# 如果原始数据是整数类型,尝试保持整数格式
if np.issubdtype(self._original_dtype, np.integer):
# 对于整数类型的原始数据,检查滤波后的数据是否仍在合理范围内
if self._original_range is not None:
orig_min, orig_max = self._original_range
filtered_min, filtered_max = filtered_data.min(), filtered_data.max()
# 如果滤波后的数据范围仍在原始范围附近,使用原始数据类型
if (filtered_min >= orig_min - 100) and (filtered_max <= orig_max + 100):
print(f"保持原始整数数据类型: {self._original_dtype}")
filtered_data = np.clip(filtered_data, orig_min, orig_max).astype(self._original_dtype)
else:
print(f"数据范围变化较大使用float32保存")
filtered_data = filtered_data.astype(np.float32)
else:
filtered_data = filtered_data.astype(np.float32)
else:
# 原始数据就是浮点数保持float32
filtered_data = filtered_data.astype(np.float32)
else:
# 默认使用float32
filtered_data = filtered_data.astype(np.float32)
print(f"保存数据类型: {filtered_data.dtype}, 范围: {filtered_data.min():.2f} - {filtered_data.max():.2f}")
# 确保数据格式适合ENVI标准
if len(filtered_data.shape) == 2:
# 2D数据 (H, W) -> 3D数据 (H, W, 1)
filtered_data = filtered_data[:, :, np.newaxis]
print(f"转换2D数据为3D格式以符合ENVI标准: {filtered_data.shape}")
# ENVI使用BSQ格式存储
filtered_data.tofile(dat_path)
print(f"数据文件已保存: {dat_path}")
# 生成头文件
self._save_envi_header(hdr_path, filtered_data, filter_type, original_header)
print(f"头文件已保存: {hdr_path}")
return dat_path, hdr_path
except Exception as e:
raise IOError(f"保存ENVI文件失败: {e}")
def _save_envi_header(self, hdr_path: str, data: np.ndarray, filter_type: str,
original_header: Dict = None):
"""保存ENVI标准格式头文件"""
try:
from datetime import datetime
with open(hdr_path, 'w', encoding='utf-8') as f:
f.write("ENVI\n")
# 描述信息 - 包含处理时间和描述
current_time = datetime.now().strftime("%a %b %d %H:%M:%S %Y")
f.write("description = {\n")
f.write(f" Convolution Result [{current_time}]\n")
f.write("}\n")
# 图像尺寸
if len(data.shape) == 3:
samples, lines, bands = data.shape[1], data.shape[0], data.shape[2]
else:
samples, lines, bands = data.shape[1], data.shape[0], 1
f.write(f"samples = {samples}\n")
f.write(f"lines = {lines}\n")
f.write(f"bands = {bands}\n")
f.write("header offset = 0\n")
f.write("file type = ENVI Standard\n")
# 数据类型 - 使用实际保存的数据类型
if hasattr(self, '_original_dtype') and self._original_dtype is not None:
# 优先使用原始数据类型(如果范围合适)
data_type = self._get_envi_data_type(self._original_dtype)
else:
data_type = self._get_envi_data_type(data.dtype)
f.write(f"data type = {data_type}\n")
# 交织格式
f.write("interleave = bsq\n")
# 传感器类型
if original_header and 'sensor type' in original_header:
f.write(f"sensor type = {original_header['sensor type']}\n")
else:
f.write("sensor type = Unknown\n")
# 字节顺序
f.write("byte order = 0\n")
# 波长信息
if self.wavelengths is not None and len(self.wavelengths) > 0:
f.write("wavelength units = Nanometers\n")
if bands == 1 and hasattr(self, '_selected_band_index') and self._selected_band_index is not None:
# 单波段情况 - 使用选定波段的波长
band_idx = min(self._selected_band_index, len(self.wavelengths) - 1)
f.write(f"wavelength = {{\n")
f.write(f" {self.wavelengths[band_idx]:.6f}\n")
f.write("}\n")
else:
# 多波段情况 - 写入所有波长
f.write("wavelength = {\n")
wavelength_str = ", ".join([f"{w:.6f}" for w in self.wavelengths[:bands]])
f.write(f" {wavelength_str}\n")
f.write("}\n")
# 反射率比例因子
f.write("reflectance scale factor = 10000.000000\n")
# 从原始头文件复制其他重要参数
if original_header:
# 增益信息
if 'gain' in original_header:
f.write(f"gain = {original_header['gain']}\n")
# 分辨率信息
binning_keys = ['sample binning', 'spectral binning', 'line binning']
for key in binning_keys:
if key in original_header:
f.write(f"{key} = {original_header[key]}\n")
# 快门和帧率信息
if 'shutter' in original_header:
f.write(f"shutter = {original_header['shutter']}\n")
if 'framerate' in original_header:
f.write(f"framerate = {original_header['framerate']}\n")
# 设备序列号
if 'imager serial number' in original_header:
f.write(f"imager serial number = {original_header['imager serial number']}\n")
# 旋转矩阵
if 'rotation' in original_header:
rotation = original_header['rotation']
f.write(f"rotation = {rotation}\n")
# 标签
if 'label' in original_header:
f.write(f"label = {original_header['label']}\n")
# 处理历史
if 'history' in original_header:
f.write(f"history = {original_header['history']}\n")
except Exception as e:
raise IOError(f"保存头文件失败: {e}")
def process_image(self, input_path: str, output_path: str, filter_type: str,
                  band_index: Optional[int] = None, **kwargs) -> Tuple[str, str]:
    """
    Run the whole pipeline: load the image, filter it, save it as ENVI.

    Parameters:
        input_path: path of the input image
        output_path: output path without extension
        filter_type: name of the filter to apply
        band_index: optional band to select while loading
        **kwargs: filter hyper-parameters, forwarded to apply_filter
            - kernel_size: kernel size (odd, default 3)
            - sigma: Gaussian standard deviation (default 1.0)
            - sigma_color: bilateral colour-space sigma (default 75.0)
            - sigma_space: bilateral spatial sigma (default 75.0)
    Returns:
        Path of the data file and path of the header file
    """
    rule = "=" * 50
    for text in (rule, "高光谱图像平滑滤波器", rule):
        print(text)

    # Step 1: load (the data itself is kept on the instance).
    print(f"加载图像: {input_path}")
    _, header = self.load_image(input_path, band_index)

    # Step 2: filter.
    print(f"应用{filter_type}滤波...")
    filtered_data = self.apply_filter(filter_type, **kwargs)

    # Step 3: save.
    print(f"保存结果到: {output_path}")
    dat_path, hdr_path = self.save_envi(output_path, filtered_data, filter_type, header)

    for text in (rule, "处理完成!", f"数据文件: {dat_path}", f"头文件: {hdr_path}", rule):
        print(text)
    return dat_path, hdr_path
def main(argv=None):
    """
    Command-line entry point for the smoothing filter.

    Every option defaults to the value that used to be hard-coded here, so
    running the script without arguments behaves as before.

    Parameters:
        argv: argument list to parse; None means ``sys.argv[1:]``
    Returns:
        0 on success, 1 if processing raised an exception
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Smooth an ENVI (.hdr) or RGB image and save the result as ENVI."
    )
    parser.add_argument(
        'input', nargs='?',
        default=r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip.hdr",
        help="input image (.hdr, .jpg, .jpeg, .png, .tif, .tiff)",
    )
    parser.add_argument(
        '-o', '--output',
        default=r"E:\code\spectronon\single_classsfication\fliter",
        help="output path without extension",
    )
    parser.add_argument(
        '-f', '--filter', dest='filter_type', default='median',
        choices=['mean', 'median', 'gaussian', 'bilateral'],
        help="filter to apply",
    )
    parser.add_argument('-b', '--band', type=int, default=20,
                        help="band index to process for ENVI input")
    parser.add_argument('--all-bands', action='store_true',
                        help="process every band instead of a single one")
    parser.add_argument('-k', '--kernel-size', type=int, default=3,
                        help="kernel size")
    parser.add_argument('--sigma', type=float, default=1.0,
                        help="Gaussian standard deviation")
    parser.add_argument('--sigma-color', type=float, default=75.0,
                        help="bilateral colour-space sigma")
    parser.add_argument('--sigma-space', type=float, default=75.0,
                        help="bilateral spatial sigma")
    args = parser.parse_args(argv)

    # Hyper-parameters forwarded to apply_filter.
    kwargs = {
        'kernel_size': args.kernel_size,
        'sigma': args.sigma,
        'sigma_color': args.sigma_color,
        'sigma_space': args.sigma_space
    }
    try:
        filter_obj = HyperspectralImageFilter()
        filter_obj.process_image(
            input_path=args.input,
            output_path=args.output,
            filter_type=args.filter_type,
            band_index=None if args.all_bands else args.band,
            **kwargs
        )
    except Exception as e:
        print(f"错误: {e}")
        return 1
    return 0
if __name__ == "__main__":
    # Propagate main()'s status (0 = success, 1 = failure) to the shell.
    raise SystemExit(main())

View File

@ -0,0 +1,758 @@
import numpy as np
import os
import warnings
from typing import Tuple, List, Dict, Optional, Union
from pathlib import Path
import spectral as spy
from spectral import envi
import matplotlib.pyplot as plt
from skimage import morphology
from scipy import ndimage
import struct
warnings.filterwarnings('ignore')
class HyperspectralMorphologyProcessor:
"""高光谱图像形态学处理类"""
def __init__(self):
self.data = None
self.header = None
self.wavelengths = None
self.shape = None
self.dtype = None
def load_hyperspectral_image(self, file_path: str, format_type: str = None) -> Tuple[np.ndarray, Dict]:
    """
    Load a hyperspectral image (bil, bip, bsq or dat with an ENVI header).

    The cube is always kept as (rows, cols, bands), which is what
    ``extract_band`` indexes. Earlier versions permuted the axes for BIL/BSQ
    input after loading, which contradicted the documented shape.
    NOTE(review): this relies on ``spectral``'s ``load()`` returning
    (rows, cols, bands) whatever the on-disk interleave is — see the
    Spectral Python documentation.

    Parameters:
    -----------
    file_path : str
        Image path (.hdr, .bil, .bip, .bsq or .dat)
    format_type : str, optional
        'bil', 'bip', 'bsq' or 'dat'; detected automatically when None.
        Used for reporting only.
    Returns:
    --------
    data : np.ndarray
        Hyperspectral cube (rows, cols, bands)
    header : Dict
        Header metadata
    Raises:
    -------
    IOError
        If the header cannot be found or loading fails
    """
    try:
        if format_type is None:
            format_type = self._detect_format(file_path)

        if file_path.lower().endswith('.hdr'):
            hdr_file = file_path
        else:
            # Look for "<name>.hdr" and "<name>.<ext>.hdr" beside the data file.
            stem = os.path.splitext(file_path)[0]
            candidates = (stem + '.hdr', file_path + '.hdr')
            hdr_file = next((c for c in candidates if os.path.exists(c)), None)
            if hdr_file is None:
                raise FileNotFoundError(f"未找到头文件(.hdr)用于 {file_path}")

        img = envi.open(hdr_file)
        data = img.load()
        header = dict(img.metadata)

        self.data = data
        self.header = header
        self.shape = data.shape
        self.dtype = data.dtype

        # Wavelengths are optional; an unparsable list is treated as absent.
        self.wavelengths = None
        if 'wavelength' in header:
            try:
                self.wavelengths = np.array([float(w) for w in header['wavelength']])
            except (TypeError, ValueError):
                self.wavelengths = None

        print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}, 格式={format_type}")
        if self.wavelengths is not None and len(self.wavelengths) > 0:
            print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
        return data, header
    except Exception as e:
        raise IOError(f"加载图像失败: {e}") from e
def _detect_format(self, file_path: str) -> str:
"""自动检测图像格式"""
ext = Path(file_path).suffix.lower()
if ext == '.hdr':
# 读取hdr文件内容检测数据格式
with open(file_path, 'r') as f:
content = f.read()
if 'interleave = bil' in content.lower():
return 'bil'
elif 'interleave = bsq' in content.lower():
return 'bsq'
elif 'interleave = bip' in content.lower():
return 'bip'
else:
# 默认假设为BIP
return 'bip'
elif ext in ['.bil', '.bip', '.bsq', '.dat']:
# 从扩展名判断
return ext[1:] # 去掉点号
else:
# 默认假设为BIP
return 'bip'
def _convert_to_bil(self, data: np.ndarray) -> np.ndarray:
"""将数据转换为BIL格式"""
rows, cols, bands = data.shape
bil_data = np.zeros((rows, bands, cols), dtype=data.dtype)
for i in range(rows):
for b in range(bands):
bil_data[i, b, :] = data[i, :, b]
return bil_data
def _convert_to_bsq(self, data: np.ndarray) -> np.ndarray:
"""将数据转换为BSQ格式"""
rows, cols, bands = data.shape
bsq_data = np.zeros((bands, rows, cols), dtype=data.dtype)
for b in range(bands):
bsq_data[b, :, :] = data[:, :, b]
return bsq_data
def extract_band(self, band_index: int) -> np.ndarray:
"""
提取指定波段
Parameters:
-----------
band_index : int
波段索引从0开始
Returns:
--------
band_data : np.ndarray
单波段图像 (rows, cols)
"""
if self.data is None:
raise ValueError("请先加载图像数据")
if band_index < 0 or band_index >= self.shape[2]:
raise ValueError(f"波段索引 {band_index} 超出范围 [0, {self.shape[2]-1}]")
# 根据数据布局提取波段
if len(self.data.shape) == 3:
# 标准形状 (rows, cols, bands)
# 使用squeeze()来确保移除单维度
band_data = np.squeeze(self.data[:, :, band_index])
print(f"提取前形状: {self.data.shape}, band_index: {band_index}")
print(f"提取后初始形状: {band_data.shape}")
# 确保最终是2D
if len(band_data.shape) != 2:
raise ValueError(f"无法提取2D波段数据最终形状: {band_data.shape}")
elif len(self.data.shape) == 2:
# 已经是单波段
band_data = self.data
else:
raise ValueError(f"不支持的数据形状: {self.data.shape}")
print(f"提取波段 {band_index}: 最终形状={band_data.shape}")
return band_data
def apply_morphology_operation(self, band_data: np.ndarray,
operation: str = 'dilation',
se_shape: str = 'disk',
se_size: int = 3,
**kwargs) -> np.ndarray:
"""
应用形态学操作
Parameters:
-----------
band_data : np.ndarray
单波段图像数据 (2D)
operation : str
形态学操作类型:
- 'dilation': 膨胀
- 'erosion': 腐蚀
- 'opening': 开运算
- 'closing': 闭运算
- 'gradient': 形态学梯度 (膨胀 - 腐蚀)
- 'tophat': 顶帽变换 (原图 - 开运算)
- 'bottomhat': 底帽变换 (闭运算 - 原图)
- 'reconstruction': 形态学重建
se_shape : str
结构元素形状: 'disk', 'square', 'rectangle', 'diamond'
se_size : int
结构元素大小
Returns:
--------
result : np.ndarray
处理后的图像
"""
# 确保输入数据是2D的
if len(band_data.shape) != 2:
raise ValueError(f"输入数据必须是2D的当前形状: {band_data.shape}")
# 创建结构元素
selem = self._create_structuring_element(se_shape, se_size)
# 转换为浮点数以保证精度
data_float = band_data.astype(np.float32)
# 应用形态学操作
if operation == 'dilation':
result = ndimage.grey_dilation(data_float, footprint=selem)
elif operation == 'erosion':
result = ndimage.grey_erosion(data_float, footprint=selem)
elif operation == 'opening':
# 开运算: 先腐蚀后膨胀
eroded = ndimage.grey_erosion(data_float, footprint=selem)
result = ndimage.grey_dilation(eroded, footprint=selem)
elif operation == 'closing':
# 闭运算: 先膨胀后腐蚀
dilated = ndimage.grey_dilation(data_float, footprint=selem)
result = ndimage.grey_erosion(dilated, footprint=selem)
elif operation == 'gradient':
# 形态学梯度: 膨胀 - 腐蚀
dilated = ndimage.grey_dilation(data_float, footprint=selem)
eroded = ndimage.grey_erosion(data_float, footprint=selem)
result = dilated - eroded
elif operation == 'tophat':
# 顶帽变换: 原图 - 开运算
eroded = ndimage.grey_erosion(data_float, footprint=selem)
opened = ndimage.grey_dilation(eroded, footprint=selem)
result = data_float - opened
elif operation == 'bottomhat':
# 底帽变换: 闭运算 - 原图
dilated = ndimage.grey_dilation(data_float, footprint=selem)
closed = ndimage.grey_erosion(dilated, footprint=selem)
result = closed - data_float
elif operation == 'reconstruction':
# 形态学重建(基于膨胀的重建)
# 使用标记图像(这里使用腐蚀后的图像作为标记)
marker = ndimage.grey_erosion(data_float, footprint=selem)
result = self._morphological_reconstruction(marker, data_float, selem)
else:
raise ValueError(f"不支持的形态学操作: {operation}")
# 转换为原始数据类型
if operation in ['gradient', 'tophat', 'bottomhat']:
# 这些操作可能产生负值,保持浮点类型
return result
else:
return self._convert_to_original_type(result, band_data.dtype)
def _create_structuring_element(self, shape: str, size: int) -> np.ndarray:
"""创建结构元素"""
if shape == 'disk':
# 创建圆形结构元素
radius = size // 2
y, x = np.ogrid[-radius:radius+1, -radius:radius+1]
mask = x*x + y*y <= radius*radius
selem = np.zeros((2*radius+1, 2*radius+1), dtype=bool)
selem[mask] = True
elif shape == 'square':
# 创建方形结构元素
selem = np.ones((size, size), dtype=bool)
elif shape == 'rectangle':
# 创建矩形结构元素
selem = np.ones((size, size*2), dtype=bool)
elif shape == 'diamond':
# 创建菱形结构元素
selem = morphology.diamond(size)
else:
raise ValueError(f"不支持的结构元素形状: {shape}")
return selem
def _morphological_reconstruction(self, marker: np.ndarray, mask: np.ndarray,
selem: np.ndarray) -> np.ndarray:
"""
形态学重建(基于膨胀)
Parameters:
-----------
marker : np.ndarray
标记图像
mask : np.ndarray
掩模图像
selem : np.ndarray
结构元素
Returns:
--------
reconstructed : np.ndarray
重建后的图像
"""
# 确保标记图像不超过掩模图像
marker = np.minimum(marker, mask)
# 迭代膨胀直到收敛
prev_marker = np.zeros_like(marker)
while not np.array_equal(marker, prev_marker):
prev_marker = marker.copy()
# 条件膨胀:在掩模限制下膨胀
dilated = ndimage.grey_dilation(marker, footprint=selem)
marker = np.minimum(dilated, mask)
return marker
def _convert_to_original_type(self, data: np.ndarray, original_dtype: np.dtype) -> np.ndarray:
"""将数据转换回原始数据类型"""
if np.issubdtype(original_dtype, np.integer):
# 对于整数类型,进行裁剪和取整
data = np.clip(data, np.iinfo(original_dtype).min, np.iinfo(original_dtype).max)
return data.astype(original_dtype)
else:
# 对于浮点类型,直接转换
return data.astype(original_dtype)
def apply_to_multiple_bands(self, band_indices: List[int], operation: str,
                            se_shape: str = 'disk', se_size: int = 3) -> Dict[int, np.ndarray]:
    """Run one morphological operation on several bands.

    Parameters
    ----------
    band_indices : List[int]
        Indices handed one by one to ``self.extract_band``.
    operation : str
        Operation name forwarded to ``self.apply_morphology_operation``.
    se_shape : str
        Structuring-element shape forwarded unchanged.
    se_size : int
        Structuring-element size forwarded unchanged.

    Returns
    -------
    Dict[int, np.ndarray]
        Processed band for every requested index, in request order.

    Raises
    ------
    ValueError
        If an extracted band is not two-dimensional.
    """
    processed = {}
    for index in band_indices:
        print(f"处理波段 {index}...")
        band = self.extract_band(index)
        # Defensive check: the operation itself only accepts 2-D input.
        if len(band.shape) != 2:
            raise ValueError(f"波段 {index} 数据不是2D的形状: {band.shape}")
        processed[index] = self.apply_morphology_operation(band, operation, se_shape, se_size)
    return processed
def save_as_envi(self, data: np.ndarray, output_path: str,
                 description: str = "形态学处理结果") -> None:
    """Write ``data`` as a raw ``.dat`` file plus a matching ENVI ``.hdr``.

    Parameters
    ----------
    data : np.ndarray
        Array to save; its raw bytes are written with ``ndarray.tofile``.
    output_path : str
        Destination path.  ``.dat`` is appended unless the path already ends
        with it (case-insensitive).  A bare file name without a directory
        part is written to the current working directory.
    description : str
        Text stored in the header's description field.

    Raises
    ------
    IOError
        If creating the directory or writing either file fails; the original
        exception is attached as the cause.
    """
    try:
        # os.makedirs('') raises, so only create a directory when the path
        # actually contains one.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        if output_path.lower().endswith('.dat'):
            dat_path = output_path
        else:
            dat_path = output_path + '.dat'

        data.tofile(dat_path)

        # The header shares the data file's stem.
        hdr_path = os.path.splitext(dat_path)[0] + '.hdr'
        self._create_morphology_hdr_file(hdr_path, data.shape, description, data.dtype)

        print(f"ENVI格式已保存:")
        print(f" 数据文件: {dat_path}")
        print(f" 头文件: {hdr_path}")
        print(f" 数据形状: {data.shape}, 数据类型: {data.dtype}")
    except Exception as e:
        raise IOError(f"保存ENVI文件失败: {e}") from e
def _create_morphology_hdr_file(self, hdr_path: str, data_shape: Tuple,
description: str, data_dtype: np.dtype = None) -> None:
"""
创建形态学处理结果的ENVI头文件
Parameters:
-----------
hdr_path : str
头文件路径
data_shape : Tuple
数据形状
description : str
图像描述
data_dtype : np.dtype, optional
数据类型如果为None则使用默认值
"""
try:
# 从原始头文件获取基本信息(如果有的话)
if self.header is not None:
# 复制原始头文件的关键信息
samples = self.header.get('samples', data_shape[1] if len(data_shape) > 1 else data_shape[0])
lines = self.header.get('lines', data_shape[0] if len(data_shape) > 1 else 1)
bands = 1 # 形态学处理结果是单波段
interleave = 'bsq'
# 根据数据类型确定ENVI数据类型
if data_dtype is not None:
dtype = data_dtype
else:
dtype = np.float32 # 默认类型
if np.issubdtype(dtype, np.float32):
data_type = 4 # float32
elif np.issubdtype(dtype, np.float64):
data_type = 5 # float64
elif np.issubdtype(dtype, np.int32):
data_type = 3 # int32
elif np.issubdtype(dtype, np.int16):
data_type = 2 # int16
elif np.issubdtype(dtype, np.uint16):
data_type = 12 # uint16
else:
data_type = 4 # 默认float32
byte_order = self.header.get('byte order', 0)
wavelength_units = self.header.get('wavelength units', 'nm')
else:
# 默认值适用于2D形态学处理结果
if len(data_shape) == 2:
lines, samples = data_shape
else:
lines = data_shape[0]
samples = data_shape[1] if len(data_shape) > 1 else data_shape[0]
bands = 1
interleave = 'bsq'
data_type = 4 # float32形态学处理通常涉及浮点数
byte_order = 0
wavelength_units = 'Unknown'
# 写入hdr文件
with open(hdr_path, 'w') as f:
f.write("ENVI\n")
f.write("description = {\n")
f.write(f" {description}\n")
f.write(" 单波段形态学处理结果\n")
f.write("}\n")
f.write(f"samples = {samples}\n")
f.write(f"lines = {lines}\n")
f.write(f"bands = {bands}\n")
f.write(f"header offset = 0\n")
f.write(f"file type = ENVI Standard\n")
f.write(f"data type = {data_type}\n")
f.write(f"interleave = {interleave}\n")
f.write(f"byte order = {byte_order}\n")
# 如果有波长信息,添加虚拟波长
if self.header and 'wavelength' in self.header:
f.write("wavelength = {\n")
f.write(" 形态学处理结果\n")
f.write("}\n")
f.write(f"wavelength units = {wavelength_units}\n")
except Exception as e:
raise IOError(f"创建HDR文件失败: {e}")
def visualize_results(self, original_band: np.ndarray,
                      processed_band: np.ndarray,
                      operation_name: str,
                      save_path: Optional[str] = None) -> None:
    """Show the original band, the processed band and their difference.

    Parameters
    ----------
    original_band : np.ndarray
        Band before processing.
    processed_band : np.ndarray
        Band after processing.
    operation_name : str
        Name used in the panel and figure titles.
    save_path : str, optional
        When given, the figure is also written to this path at 300 dpi.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        im1 = axes[0].imshow(original_band, cmap='gray')
        axes[0].set_title('原始波段')
        axes[0].axis('off')
        plt.colorbar(im1, ax=axes[0], shrink=0.8)

        im2 = axes[1].imshow(processed_band, cmap='gray')
        axes[1].set_title(f'{operation_name} 结果')
        axes[1].axis('off')
        plt.colorbar(im2, ax=axes[1], shrink=0.8)

        # Symmetric colour range around zero so unchanged pixels sit on the
        # neutral colour of the diverging map.
        diff = processed_band.astype(np.float32) - original_band.astype(np.float32)
        limit = np.abs(diff).max()
        if not limit > 0:
            # All-zero (or NaN) difference: avoid a degenerate vmin == vmax.
            limit = 1.0
        im3 = axes[2].imshow(diff, cmap='RdBu_r', vmin=-limit, vmax=limit)
        axes[2].set_title('差异 (处理后 - 原始)')
        axes[2].axis('off')
        plt.colorbar(im3, ax=axes[2], shrink=0.8)

        plt.suptitle(f'形态学操作: {operation_name}', fontsize=16)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"可视化结果已保存到: {save_path}")
        plt.show()
    finally:
        # Release the figure; callers invoke this in loops and unclosed
        # figures accumulate in memory.
        plt.close(fig)
def main():
    """Command-line entry point for hyperspectral morphology processing.

    Parses the command line, loads the image, applies the requested operation
    to one band (``--band``) or several (``--bands``), saves every result as
    an ENVI file pair and optionally writes a visualization.

    Returns
    -------
    int
        0 on success, 1 if any processing step raised (the traceback is
        printed).
    """
    import argparse

    morph_operations = [
        'dilation', 'erosion', 'opening', 'closing',
        'gradient', 'tophat', 'bottomhat', 'reconstruction'
    ]
    se_shapes = ['disk', 'square', 'rectangle', 'diamond']

    parser = argparse.ArgumentParser(description='高光谱图像形态学处理工具')
    parser.add_argument('input_file', help='输入高光谱图像文件路径')
    parser.add_argument('--format', '-f', default='auto',
                        choices=['auto', 'bil', 'bip', 'bsq', 'dat'],
                        help='图像格式 (默认: auto)')
    parser.add_argument('--band', '-b', type=int, default=0,
                        help='要处理的波段索引 (默认: 0)')
    parser.add_argument('--bands', '-B', type=str, default=None,
                        help='多个波段索引,用逗号分隔,如 "0,10,20"')
    parser.add_argument('--operation', '-o', default='dilation',
                        choices=morph_operations,
                        help='形态学操作类型 (默认: dilation)')
    parser.add_argument('--se_shape', '-s', default='disk',
                        choices=se_shapes,
                        help='结构元素形状 (默认: disk)')
    parser.add_argument('--se_size', '-S', type=int, default=3,
                        help='结构元素大小 (默认: 3)')
    parser.add_argument('--output_dir', '-d', default='output',
                        help='输出目录 (默认: output)')
    parser.add_argument('--output_name', '-n', default='morphology_result',
                        help='输出文件名 (不含扩展名) (默认: morphology_result)')
    parser.add_argument('--visualize', '-v', action='store_true',
                        help='是否生成可视化结果')
    args = parser.parse_args()

    try:
        processor = HyperspectralMorphologyProcessor()
        format_type = None if args.format == 'auto' else args.format

        print(f"加载图像: {args.input_file}")
        processor.load_hyperspectral_image(args.input_file, format_type)

        if args.bands:
            # Skip blank tokens so input such as "0,10," does not crash int().
            band_indices = [int(token) for token in args.bands.split(',') if token.strip()]
            if not band_indices:
                raise ValueError(f"--bands 未包含有效的波段索引: {args.bands!r}")
            print(f"处理波段: {band_indices}")
            single_band = False
        else:
            band_indices = [args.band]
            print(f"处理波段: {args.band}")
            single_band = True

        if single_band:
            band_data = processor.extract_band(args.band)
            print(f"主函数中波段数据形状: {band_data.shape}")
            result = processor.apply_morphology_operation(
                band_data, args.operation, args.se_shape, args.se_size
            )

            output_path = os.path.join(args.output_dir, f"{args.output_name}_band{args.band}")
            processor.save_as_envi(result, output_path,
                                   f"{args.operation.capitalize()} 处理结果 - 波段 {args.band}")

            if args.visualize:
                vis_path = os.path.join(args.output_dir, f"visualization_band{args.band}.png")
                processor.visualize_results(band_data, result, args.operation, vis_path)

            print(f"\n=== 统计信息 ===")
            print(f"原始波段 {args.band}: min={band_data.min():.4f}, max={band_data.max():.4f}, "
                  f"mean={band_data.mean():.4f}")
            print(f"处理后: min={result.min():.4f}, max={result.max():.4f}, "
                  f"mean={result.mean():.4f}")
        else:
            results = processor.apply_to_multiple_bands(
                band_indices, args.operation, args.se_shape, args.se_size
            )
            for band_idx, result in results.items():
                output_path = os.path.join(args.output_dir, f"{args.output_name}_band{band_idx}")
                processor.save_as_envi(result, output_path,
                                       f"{args.operation.capitalize()} 处理结果 - 波段 {band_idx}")

            # A 2x2 overview is only drawn for up to four bands.
            if len(band_indices) <= 4 and args.visualize:
                fig, axes = plt.subplots(2, 2, figsize=(12, 10))
                axes = axes.flatten()
                # Turn every panel off first so unused ones stay blank.
                for ax in axes:
                    ax.axis('off')
                for ax, band_idx in zip(axes, band_indices):
                    ax.imshow(results[band_idx], cmap='gray')
                    ax.set_title(f'波段 {band_idx} - {args.operation}')
                plt.suptitle(f'多波段形态学处理: {args.operation}', fontsize=16)
                plt.tight_layout()
                vis_path = os.path.join(args.output_dir, "multi_band_visualization.png")
                plt.savefig(vis_path, dpi=300, bbox_inches='tight')
                print(f"多波段可视化结果已保存到: {vis_path}")
                plt.show()
                plt.close(fig)

        print(f"\n✓ 形态学处理完成!")
        print(f"输出目录: {args.output_dir}")
        return 0
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the
        # return code.
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
# 使用示例函数
def run_example():
    """Demo: apply every supported operation to one band of a sample image.

    Loads the placeholder image path below (replace it with a real file),
    extracts band 50, then runs, saves and visualizes each operation in turn.
    A failing operation is reported and skipped.

    Returns
    -------
    dict
        Operation name mapped to its result, for the operations that
        succeeded.
    """
    processor = HyperspectralMorphologyProcessor()

    # Placeholder path: may be a .hdr, .bil, .bip, .bsq or .dat file.
    input_file = "path/to/your/hyperspectral_image.hdr"
    data, header = processor.load_hyperspectral_image(input_file, format_type='bip')

    target_band = 50
    band = processor.extract_band(target_band)

    demo_operations = ('dilation', 'erosion', 'opening', 'closing',
                       'gradient', 'tophat', 'bottomhat', 'reconstruction')
    outcomes = {}
    for op_name in demo_operations:
        print(f"\n应用 {op_name} 操作...")
        try:
            processed = processor.apply_morphology_operation(
                band, op_name, se_shape='disk', se_size=3
            )
            outcomes[op_name] = processed
            processor.save_as_envi(processed,
                                   f"output/{op_name}_band{target_band}",
                                   f"{op_name.capitalize()} 处理结果")
            processor.visualize_results(band, processed, op_name,
                                        save_path=f"output/visualization_{op_name}.png")
        except Exception as e:
            print(f"操作 {op_name} 失败: {e}")
    return outcomes
if __name__ == '__main__':
    # Run the argparse-driven CLI and propagate its 0/1 result as the process
    # exit status.  The previous body duplicated run_example() with a
    # machine-specific path; run_example() remains available for the demo.
    raise SystemExit(main())

View File

@ -0,0 +1,608 @@
import numpy as np
from scipy import signal
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import pandas as pd
import pywt
from copy import deepcopy
import joblib # 用于保存和加载模型
import os
from pathlib import Path
try:
import spectral
SPECTRAL_AVAILABLE = True
except ImportError:
SPECTRAL_AVAILABLE = False
print("警告: spectral库不可用将使用内置方法读取数据")
class HyperspectralPreprocessor:
    """Spectral preprocessing for hyperspectral tables and image cubes.

    Supported file formats:
      - CSV files: the name of the first spectral column must be given.
      - ENVI files: .bil, .bsq, .bip, .dat plus a .hdr header (needs the
        ``spectral`` library).

    Data is held either as a 2-D array ``(n_spectra, n_bands)`` or as a 3-D
    cube ``(rows, cols, n_bands)``; every method works along the last axis.
    Preprocessing methods return new arrays and leave ``self.data`` and
    ``self.wavelengths`` untouched.
    """

    _ENVI_DATA_SUFFIXES = ('.dat', '.bil', '.bsq', '.bip')

    def __init__(self):
        self.data = None
        self.wavelengths = None
        self.data_shape = None
        self.input_format = None  # 'csv' or 'envi'

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def load_data(self, file_path, spectral_start_col=None):
        """Load a hyperspectral data file.

        Parameters
        ----------
        file_path : str or Path
            File to load; the suffix selects the reader.
        spectral_start_col : str, optional
            For CSV files, name of the first spectral column (required).

        Returns
        -------
        The loaded data array (also stored in ``self.data``).

        Raises
        ------
        ValueError
            If a CSV is given without ``spectral_start_col``, or the suffix is
            unsupported / the ``spectral`` library is unavailable.
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()
        if suffix == '.csv':
            if spectral_start_col is None:
                raise ValueError("对于CSV文件必须指定spectral_start_col参数")
            self.input_format = 'csv'
            return self._load_csv_data(file_path, spectral_start_col)
        if SPECTRAL_AVAILABLE and suffix in self._ENVI_DATA_SUFFIXES + ('.hdr',):
            self.input_format = 'envi'
            return self._load_envi_data(file_path)
        raise ValueError(f"不支持的文件格式: {suffix}。请使用CSV或ENVI格式文件并确保已安装spectral库。")

    def _load_csv_data(self, file_path, spectral_start_col):
        """Load spectra from a CSV whose spectral column names are wavelengths."""
        df = pd.read_csv(file_path)
        if spectral_start_col not in df.columns:
            raise ValueError(f"列名 '{spectral_start_col}' 不存在于CSV文件中")
        start_idx = df.columns.get_loc(spectral_start_col)

        # Everything from the start column onwards is spectral data; the
        # column names double as wavelengths.
        self.wavelengths = df.columns[start_idx:].astype(float).values
        self.data = df.iloc[:, start_idx:].values
        self.data_shape = self.data.shape

        print(f"CSV数据加载成功: {self.data_shape}")
        print(f"波段数量: {len(self.wavelengths)}")
        if len(self.wavelengths) > 0:
            print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
        return self.data

    @staticmethod
    def _resolve_envi_paths(path):
        """Return ``(header_path, data_path_or_None)`` given either ENVI file.

        For a ``.hdr`` input the data file is searched next to it (suffix
        removed, then the common data suffixes).  For a data-file input the
        header is searched as ``<file>.hdr`` and ``<stem>.hdr``.

        Raises
        ------
        FileNotFoundError
            If a data file is given and no header can be found.
        """
        if path.suffix.lower() == '.hdr':
            candidates = [path.with_suffix('')] + [
                path.with_suffix(ext)
                for ext in HyperspectralPreprocessor._ENVI_DATA_SUFFIXES
            ]
            data_path = next((c for c in candidates if c.is_file()), None)
            return path, data_path
        for header in (Path(str(path) + '.hdr'), path.with_suffix('.hdr')):
            if header.is_file():
                return header, path
        raise FileNotFoundError(f"未找到与 {path} 对应的ENVI头文件(.hdr)")

    @staticmethod
    def _parse_wavelengths(raw):
        """Convert header wavelength metadata (string or sequence) to floats."""
        if isinstance(raw, str):
            tokens = raw.strip('{}[]').replace(',', ' ').split()
            return np.array([float(token) for token in tokens])
        if isinstance(raw, (list, tuple)):
            return np.array([float(value) for value in raw])
        return np.array(raw, dtype=float)

    def _load_envi_data(self, file_path):
        """Load an ENVI image through the ``spectral`` library.

        Either the header or the data file may be given; the other half is
        located by ``_resolve_envi_paths`` and both are passed explicitly to
        ``spectral.io.envi.open``.
        """
        header_path, data_path = self._resolve_envi_paths(Path(file_path))
        print(f"使用spectral库加载: {data_path if data_path is not None else header_path}")

        img = spectral.io.envi.open(
            str(header_path), str(data_path) if data_path is not None else None)
        self.data = img.load()
        n_bands = self.data.shape[2]

        raw_wavelengths = getattr(img, 'metadata', {}).get('wavelength')
        if raw_wavelengths is None:
            self.wavelengths = np.arange(n_bands)
            print(f"警告: spectral未找到波长信息使用默认值: 0-{self.data.shape[2]-1}")
        else:
            try:
                wavelengths = self._parse_wavelengths(raw_wavelengths)
                if len(wavelengths) != n_bands:
                    raise ValueError(
                        f"波长数量 {len(wavelengths)} 与波段数量 {n_bands} 不一致")
                self.wavelengths = wavelengths
                print(f"spectral解析波长: {len(self.wavelengths)} 个波段")
                print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
            except (ValueError, TypeError, IndexError) as e:
                # Unusable metadata: fall back to plain band indices.
                print(f"波长解析失败,使用默认值: {e}")
                self.wavelengths = np.arange(n_bands)

        self.data_shape = self.data.shape
        print(f"spectral数据加载成功: {self.data_shape}")
        print(f"数据类型: {self.data.dtype}")
        print(f"值范围: [{self.data.min():.3f}, {self.data.max():.3f}]")
        return self.data

    # ------------------------------------------------------------------
    # Saving
    # ------------------------------------------------------------------
    def save_data(self, output_path, data=None, wavelengths=None):
        """Save data as CSV or ENVI depending on the output suffix.

        Parameters
        ----------
        output_path : str or Path
            Destination; ``.csv`` writes a CSV, ENVI data suffixes (or any
            other suffix when the input was ENVI) write ``.dat`` + ``.hdr``.
        data : array, optional
            Data to save; defaults to ``self.data``.
        wavelengths : array, optional
            Wavelengths to store; defaults to ``self.wavelengths``.

        Raises
        ------
        ValueError
            If the output format is not supported.
        """
        if data is None:
            data = self.data
        if wavelengths is None:
            wavelengths = self.wavelengths
        output_path = Path(output_path)
        suffix = output_path.suffix.lower()
        if suffix == '.csv':
            self._save_csv_data(output_path, data, wavelengths)
        elif self.input_format == 'envi' or suffix in self._ENVI_DATA_SUFFIXES:
            self._save_envi_data(output_path, data, wavelengths)
        else:
            raise ValueError(f"不支持的输出格式: {suffix}")

    def _save_csv_data(self, output_path, data, wavelengths):
        """Write spectra to CSV, one row per spectrum.

        A 3-D cube is flattened to ``(rows * cols, bands)`` first, because a
        DataFrame cannot hold 3-D data.
        """
        table = np.asarray(data)
        if table.ndim == 3:
            table = table.reshape(-1, table.shape[2])
        if wavelengths is not None:
            df = pd.DataFrame(table, columns=[f"{w:.1f}" for w in wavelengths])
        else:
            df = pd.DataFrame(table)
        df.to_csv(output_path, index=False)
        print(f"已保存CSV文件: {output_path}")

    def _save_envi_data(self, output_path, data, wavelengths):
        """Write little-endian float32 data (``.dat``) and its ENVI header."""
        data_file = output_path.with_suffix('.dat')
        hdr_file = output_path.with_suffix('.hdr')
        data_file.parent.mkdir(parents=True, exist_ok=True)

        data = np.asarray(data)
        # The header declares "byte order = 0", so force little-endian bytes
        # regardless of the host platform.
        data.astype('<f4').tofile(str(data_file))

        if data.ndim == 1:
            lines, samples, bands = 1, data.shape[0], 1
        elif data.ndim == 2:
            lines, samples, bands = data.shape[0], data.shape[1], 1
        else:
            lines, samples, bands = data.shape[0], data.shape[1], data.shape[2]

        header_lines = [
            "ENVI",
            f"samples = {samples}",
            f"lines = {lines}",
            f"bands = {bands}",
            "header offset = 0",
            "file type = ENVI Standard",
            "data type = 4",
            # A C-ordered (rows, cols, bands) array has the band index
            # varying fastest, i.e. BIP.
            "interleave = bip",
            "byte order = 0",
            "wavelength units = nm",
        ]
        if wavelengths is not None and len(wavelengths) == bands:
            wavelength_str = ', '.join([f"{w:.1f}" for w in wavelengths])
            header_lines.append(f"wavelength = {{{wavelength_str}}}")
        with open(hdr_file, 'w', encoding='utf-8') as f:
            f.write("\n".join(header_lines) + "\n")
        print(f"已保存ENVI文件: {data_file}, {hdr_file}")

    # ------------------------------------------------------------------
    # Driver
    # ------------------------------------------------------------------
    def preprocess(self, method, output_path, **kwargs):
        """Run one preprocessing method and save its result.

        Parameters
        ----------
        method : str
            One of 'MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT',
            'MSC', 'wave'.
        output_path : str or Path
            Destination handed to ``save_data``.
        **kwargs
            Forwarded to the preprocessing method.

        Returns
        -------
        The processed array.

        Raises
        ------
        ValueError
            If no data is loaded or the method is unknown.
        """
        if self.data is None:
            raise ValueError("请先加载数据")
        processed_data = self._apply_preprocessing(method, **kwargs)
        self.save_data(output_path, processed_data,
                       self._wavelengths_for(processed_data))
        return processed_data

    def _wavelengths_for(self, processed):
        """Return the wavelengths matching ``processed``'s band count.

        Derivatives drop trailing bands; the stored wavelengths are sliced
        for saving instead of being modified in place.
        """
        if self.wavelengths is None:
            return None
        return self.wavelengths[:np.shape(processed)[-1]]

    def _apply_preprocessing(self, method, **kwargs):
        """Dispatch ``method`` to its implementation."""
        method_funcs = {
            'MMS': self._MMS,
            'SS': self._SS,
            'CT': self._CT,
            'SNV': self._SNV,
            'MA': self._MA,
            'SG': self._SG,
            'D1': self._D1,
            'D2': self._D2,
            'DT': self._DT,
            'MSC': self._MSC,
            'wave': self._wave,
        }
        if method not in method_funcs:
            raise ValueError(f"不支持的预处理方法: {method}")
        return method_funcs[method](**kwargs)

    def _checked_data(self):
        """Return ``self.data`` as an ndarray, requiring 2-D or 3-D input."""
        data = np.asarray(self.data)
        if data.ndim not in (2, 3):
            raise ValueError("不支持的数据维度")
        return data

    # ------------------------------------------------------------------
    # Preprocessing methods (all operate along the last axis)
    # ------------------------------------------------------------------
    def _MMS(self, **kwargs):
        """Min-max normalisation of every band (column-wise scaling)."""
        print("执行最大最小值归一化...")
        data = self._checked_data()
        scaler = MinMaxScaler()
        # The scaler expects (n_samples, n_features); flatten pixels to rows.
        scaled = scaler.fit_transform(data.reshape(-1, data.shape[-1]))
        return scaled.reshape(data.shape)

    def _SS(self, save_path=None, **kwargs):
        """Standardise every band; optionally persist the fitted scaler.

        Parameters
        ----------
        save_path : str, optional
            When given, the fitted ``StandardScaler`` is dumped there with
            joblib.
        """
        print("执行标准化...")
        data = self._checked_data()
        scaler = StandardScaler()
        result = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)
        if save_path:
            joblib.dump(scaler, save_path)
            print(f"Scaler参数已保存到: {save_path}")
        return result

    def _CT(self, **kwargs):
        """Mean-centre every spectrum (subtract its own mean)."""
        print("执行均值中心化...")
        data = self._checked_data()
        return data - np.mean(data, axis=-1, keepdims=True)

    def _SNV(self, **kwargs):
        """Standard normal variate: centre and scale every spectrum.

        Works for 2-D tables and 3-D cubes alike; spectra with zero standard
        deviation are only centred.
        """
        print("执行标准正态变换...")
        data = self._checked_data()
        spectrum_mean = np.mean(data, axis=-1, keepdims=True)
        spectrum_std = np.std(data, axis=-1, keepdims=True)
        # Avoid division by zero for flat spectra.
        spectrum_std = np.where(spectrum_std == 0, 1, spectrum_std)
        return (data - spectrum_mean) / spectrum_std

    def _MA(self, WSZ=11, **kwargs):
        """Moving-average smoothing with shrinking windows at both ends.

        Parameters
        ----------
        WSZ : int
            Window size; must be a positive odd number no larger than the
            number of bands.

        Returns
        -------
        np.ndarray
            float64 array with the same shape as the data.

        Raises
        ------
        ValueError
            If ``WSZ`` is even, non-positive or larger than the band count.
        """
        print(f"执行移动平均平滑 (窗口大小: {WSZ})...")
        data = self._checked_data()
        if WSZ < 1 or WSZ % 2 == 0:
            raise ValueError(f"移动平均窗口必须为正奇数: {WSZ}")
        if WSZ > data.shape[-1]:
            raise ValueError(f"移动平均窗口 {WSZ} 大于波段数 {data.shape[-1]}")

        kernel = np.ones(WSZ)
        edge_counts = np.arange(1, WSZ - 1, 2)

        def smooth(spectrum):
            # Full windows in the middle, growing odd windows at the edges.
            core = np.convolve(spectrum, kernel, 'valid') / WSZ
            head = np.cumsum(spectrum[:WSZ - 1])[::2] / edge_counts
            tail = (np.cumsum(spectrum[:-WSZ:-1])[::2] / edge_counts)[::-1]
            return np.concatenate((head, core, tail))

        # Work in float so integer input is not truncated on the way out.
        return np.apply_along_axis(smooth, -1, data.astype(np.float64))

    def _SG(self, w=15, p=2, **kwargs):
        """Savitzky-Golay smoothing along the spectral axis.

        Parameters
        ----------
        w : int
            Window length passed to ``scipy.signal.savgol_filter``.
        p : int
            Polynomial order.
        """
        print(f"执行Savitzky-Golay平滑 (窗口: {w}, 阶数: {p})...")
        return signal.savgol_filter(self._checked_data(), w, p, axis=-1)

    def _D1(self, **kwargs):
        """First-order difference along the spectral axis (one band fewer)."""
        print("执行一阶导数...")
        return np.diff(self._checked_data(), axis=-1).astype(np.float64)

    def _D2(self, **kwargs):
        """Second-order difference along the spectral axis (two bands fewer)."""
        print("执行二阶导数...")
        return np.diff(self._checked_data(), n=2, axis=-1)

    def _DT(self, **kwargs):
        """Detrend: subtract each spectrum's least-squares line over band index.

        The fit is done in closed form for all spectra at once.  Floating
        input keeps its dtype; other input is returned as float64.
        """
        print("执行趋势校正...")
        data = self._checked_data()
        spectra = data.astype(np.float64)
        band_index = np.arange(spectra.shape[-1], dtype=np.float64)
        centred_index = band_index - band_index.mean()
        sxx = np.dot(centred_index, centred_index)
        spectrum_mean = spectra.mean(axis=-1, keepdims=True)
        if sxx > 0:
            slope = ((spectra - spectrum_mean) * centred_index).sum(
                axis=-1, keepdims=True) / sxx
        else:
            # A single band has no trend to remove.
            slope = np.zeros_like(spectrum_mean)
        detrended = spectra - (spectrum_mean + slope * centred_index)
        if np.issubdtype(data.dtype, np.floating):
            return detrended.astype(data.dtype)
        return detrended

    def _MSC(self, **kwargs):
        """Multiplicative scatter correction against the mean spectrum.

        Every spectrum ``y`` is regressed on the mean spectrum
        (``y = k * mean + b``) and replaced by ``(y - b) / k``.  A zero slope
        is replaced by 1e-10 to avoid division by zero.  2-D input yields
        float64, 3-D input float32.
        """
        print("执行多元散射校正...")
        data = self._checked_data()
        spectra = data.astype(np.float64)
        reference = spectra.reshape(-1, spectra.shape[-1]).mean(axis=0)
        centred_ref = reference - reference.mean()
        ref_ss = np.dot(centred_ref, centred_ref)
        spectrum_mean = spectra.mean(axis=-1, keepdims=True)
        if ref_ss > 0:
            slope = ((spectra - spectrum_mean) * centred_ref).sum(
                axis=-1, keepdims=True) / ref_ss
        else:
            slope = np.zeros_like(spectrum_mean)
        intercept = spectrum_mean - slope * reference.mean()
        slope = np.where(slope == 0, 1e-10, slope)
        corrected = (spectra - intercept) / slope
        return corrected.astype(np.float32) if data.ndim == 3 else corrected

    def _wave(self, **kwargs):
        """Wavelet denoising (db8) of every spectrum.

        Detail coefficients are thresholded at 4 % of their largest magnitude
        on each level, and the reconstruction is cut back to the original
        band count (reconstruction can be one sample longer for odd lengths).
        """
        print("执行小波变换...")
        data = self._checked_data()
        n_bands = data.shape[-1]
        wavelet = pywt.Wavelet('db8')
        max_level = pywt.dwt_max_level(n_bands, wavelet.dec_len)
        relative_threshold = 0.04

        def denoise(spectrum):
            coeffs = pywt.wavedec(spectrum, wavelet, level=max_level)
            for level in range(1, len(coeffs)):
                # Relative to the largest magnitude so the cutoff can never
                # become negative.
                cutoff = relative_threshold * np.max(np.abs(coeffs[level]))
                coeffs[level] = pywt.threshold(coeffs[level], cutoff)
            return pywt.waverec(coeffs, wavelet)[:n_bands]

        return np.apply_along_axis(denoise, -1, data)
# ===== 便捷函数 =====
def preprocess_file(input_file, output_file, method, spectral_start_col=None, **kwargs):
    """Load one file, run a single preprocessing method and save the result.

    Parameters
    ----------
    input_file : str
        Path of the file to load.
    output_file : str
        Path the processed data is written to.
    method : str
        Preprocessing method name, forwarded to ``preprocess``.
    spectral_start_col : str, optional
        Name of the first spectral column (CSV input).
    **kwargs
        Forwarded to the preprocessing method.

    Returns
    -------
    Whatever ``HyperspectralPreprocessor.preprocess`` returns.
    """
    worker = HyperspectralPreprocessor()
    worker.load_data(input_file, spectral_start_col)
    processed = worker.preprocess(method, output_file, **kwargs)
    return processed
# 保持向后兼容的函数
def MMS(input_spectrum):
    """Min-max normalisation (backward-compatible wrapper).

    ``_MMS`` reads its input from ``processor.data`` and accepts no
    positional argument, so the spectrum is assigned first; passing it
    positionally raised TypeError.
    """
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._MMS()
def SS(input_spectrum, save_path=None):
    """Standardisation (backward-compatible wrapper around ``_SS``)."""
    wrapper = HyperspectralPreprocessor()
    wrapper.data = input_spectrum
    standardised = wrapper._SS(save_path=save_path)
    return standardised
def CT(input_spectrum):
    """Mean centring (backward-compatible wrapper around ``_CT``)."""
    wrapper = HyperspectralPreprocessor()
    wrapper.data = input_spectrum
    centred = wrapper._CT()
    return centred
def SNV(input_spectrum):
    """Standard normal variate (backward-compatible wrapper around ``_SNV``)."""
    wrapper = HyperspectralPreprocessor()
    wrapper.data = input_spectrum
    transformed = wrapper._SNV()
    return transformed
def MA(input_spectrum, WSZ=11):
    """Moving-average smoothing (backward-compatible wrapper around ``_MA``)."""
    wrapper = HyperspectralPreprocessor()
    wrapper.data = input_spectrum
    smoothed = wrapper._MA(WSZ=WSZ)
    return smoothed
def SG(input_spectrum, w=15, p=2):
    """Savitzky-Golay smoothing (backward-compatible wrapper around ``_SG``)."""
    wrapper = HyperspectralPreprocessor()
    wrapper.data = input_spectrum
    filtered = wrapper._SG(w=w, p=p)
    return filtered
def D1(input_spectrum):
    """First derivative (backward-compatible wrapper around ``_D1``)."""
    wrapper = HyperspectralPreprocessor()
    wrapper.data = input_spectrum
    derivative = wrapper._D1()
    return derivative
def D2(input_spectrum):
    """Second derivative (backward-compatible wrapper around ``_D2``)."""
    wrapper = HyperspectralPreprocessor()
    wrapper.data = input_spectrum
    derivative = wrapper._D2()
    return derivative
def DT(input_spectrum):
    """Detrending (backward-compatible wrapper around ``_DT``)."""
    wrapper = HyperspectralPreprocessor()
    wrapper.data = input_spectrum
    detrended = wrapper._DT()
    return detrended
def MSC(input_spectrum):
    """Multiplicative scatter correction (backward-compatible wrapper around ``_MSC``)."""
    wrapper = HyperspectralPreprocessor()
    wrapper.data = input_spectrum
    corrected = wrapper._MSC()
    return corrected
def wave(input_spectrum):
    """Wavelet transform (backward-compatible wrapper around ``_wave``)."""
    wrapper = HyperspectralPreprocessor()
    wrapper.data = input_spectrum
    denoised = wrapper._wave()
    return denoised
# ===== 使用示例 =====
if __name__ == "__main__":
    # Example 1: ENVI image processing (only constructs a processor).
    print("=== 示例1: 处理ENVI格式图像 ===")
    processor = HyperspectralPreprocessor()

    # Example 2: CSV processing.  The spectral columns of the sample file
    # start at the column named '374.285004'.
    print("\n=== 示例2: 处理CSV格式数据 ===")
    csv_jobs = (
        ('SNV', r'E:\code\spectronon\single_classsfication\tsst\output_snv.csv'),
        ('MSC', r'E:\code\spectronon\single_classsfication\tsst\output_msc.csv'),
    )
    try:
        processor_csv = HyperspectralPreprocessor()
        processor_csv.load_data(r"E:\code\spectronon\single_classsfication\tsst\train.csv",
                                spectral_start_col='374.285004')
        for method_name, destination in csv_jobs:
            processor_csv.preprocess(method_name, destination)
    except Exception as e:
        print(f"CSV示例失败: {e}")

32
requirements.txt Normal file
View File

@ -0,0 +1,32 @@
# 核心科学计算和数据处理
numpy>=1.21.0
pandas>=1.3.0
scipy>=1.7.0
scikit-learn>=1.0.0
# 机器学习和深度学习
xgboost>=1.5.0
lightgbm>=3.3.0
# 可选:深度学习框架(选择一个或多个)
tensorflow>=2.8.0 # TensorFlow/Keras 模型
torch>=1.11.0 # PyTorch 模型
torchvision>=0.12.0 # PyTorch 视觉模型
# 可视化
matplotlib>=3.5.0
seaborn>=0.11.0
# 图像处理
opencv-python>=4.5.0
# 光谱数据处理(可选)
spectral>=0.22.0 # ENVI 文件处理
# 工具库
joblib>=1.1.0
typing-extensions>=4.0.0 # 类型提示
# 开发和测试(可选)
pytest>=6.2.0
jupyter>=1.0.0

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,880 @@
import numpy as np
import pandas as pd
import os
import re
import sys
from pathlib import Path
from typing import Optional, Dict, Any, Union
from dataclasses import dataclass, field
import warnings
try:
import spectral
SPECTRAL_AVAILABLE = True
except ImportError:
SPECTRAL_AVAILABLE = False
print("警告: spectral库不可用将使用内置方法读取数据")
warnings.filterwarnings('ignore')
@dataclass
class DataConfig:
    """Input-data settings."""
    # Path of the data file to analyse.
    data_file_path: str = ""
    # Input format: 'csv', 'envi' or 'auto'.
    data_format: str = "auto"
    # Index of the wavelength column for CSV input (None when unset).
    wavelength_column: Optional[int] = None
@dataclass
class IndexConfig:
    """Spectral-index settings."""
    # CSV holding the band definitions.
    spectral_index_csv: str = r"E:\code\spectronon\spectral_index.csv"
    # CSV holding the index formulas.
    formula_csv: str = r"E:\code\spectronon\famula.csv"
    # Indices to compute; None means all of them.
    indices_to_calculate: Optional[list] = None
@dataclass
class OutputConfig:
    """Output settings."""
    # Directory the results are written to.
    output_dir: str = "results"
    # Whether each index is saved on its own.
    save_individual_indices: bool = True
    # Whether the combined result table is saved.
    save_combined_results: bool = True
    # Output format: 'csv', 'excel' or 'both'.
    output_format: str = "csv"
@dataclass
class SpectralIndexConfig:
    """Complete configuration for spectral-index calculation.

    Groups the data, index and output settings behind one object (the
    standardised interface intended for the GUI) and validates them on
    construction.
    """
    data: DataConfig = field(default_factory=DataConfig)
    indices: IndexConfig = field(default_factory=IndexConfig)
    output: OutputConfig = field(default_factory=OutputConfig)

    def __post_init__(self):
        """Validate the settings as soon as the object exists."""
        self._validate_parameters()

    def _validate_parameters(self):
        """Check formats and warn about missing definition files.

        Raises
        ------
        ValueError
            If the data format or the output format is not an accepted value.
        """
        accepted_inputs = ('csv', 'envi', 'auto')
        if self.data.data_format not in accepted_inputs:
            raise ValueError("Data format must be 'csv', 'envi', or 'auto'")

        accepted_outputs = ('csv', 'excel', 'both')
        if self.output.output_format not in accepted_outputs:
            raise ValueError("Output format must be 'csv', 'excel', or 'both'")

        # Missing definition files only produce a warning here.
        definition_files = (
            ("Spectral index", self.indices.spectral_index_csv),
            ("Formula", self.indices.formula_csv),
        )
        for label, path in definition_files:
            if not os.path.exists(path):
                print(f"Warning: {label} CSV file not found: {path}")

    @classmethod
    def create_default(cls) -> 'SpectralIndexConfig':
        """Return a configuration with every default value."""
        return cls()

    @classmethod
    def create_quick_analysis(cls, data_file_path: str,
                              indices_to_calculate: Optional[list] = None) -> 'SpectralIndexConfig':
        """Return a configuration for a quick run on ``data_file_path``.

        A non-empty ``indices_to_calculate`` restricts the indices, and
        per-index output files are switched off.
        """
        quick = cls()
        quick.data.data_file_path = data_file_path
        if indices_to_calculate:
            quick.indices.indices_to_calculate = indices_to_calculate
        quick.output.save_individual_indices = False
        return quick
class HyperspectralIndexCalculator:
"""
高光谱指数计算器 - 支持GUI对接的标准化接口
支持多种光谱指数的自动计算
"""
def __init__(self, config: Optional[SpectralIndexConfig] = None,
             spectral_index_csv=None, formula_csv=None):
    """Initialise the calculator and load its definition tables.

    Parameters
    ----------
    config : SpectralIndexConfig, optional
        Configuration object; a default one is created when omitted.
    spectral_index_csv : str, optional
        Band-definition CSV path (legacy argument, used only when ``config``
        is not given).
    formula_csv : str, optional
        Index-formula CSV path (legacy argument, used only when ``config``
        is not given).
    """
    legacy_paths_given = spectral_index_csv is not None or formula_csv is not None
    if config is None and legacy_paths_given:
        # Legacy call style: build a default config and override the paths.
        config = SpectralIndexConfig()
        if spectral_index_csv is not None:
            config.indices.spectral_index_csv = spectral_index_csv
        if formula_csv is not None:
            config.indices.formula_csv = formula_csv

    self.config = config if config else SpectralIndexConfig()
    self._validate_config()

    # Definition tables named by the configuration.
    index_settings = self.config.indices
    self.band_info = self.load_band_info(index_settings.spectral_index_csv)
    self.formulas = self.load_formulas(index_settings.formula_csv)

    # State of the currently loaded data set.
    self.data = None
    self.wavelengths = None
    self.data_shape = None
    self.band_mapping = None
def update_config(self, config: "SpectralIndexConfig"):
    """Replace the configuration (hook for dynamic GUI configuration).

    The new configuration is validated first.  If a definition CSV path
    differs from the previous configuration, the matching table is reloaded
    so ``band_info`` / ``formulas`` never lag behind ``config``.  On any
    failure the previous configuration and tables are kept and the exception
    propagates.

    Parameters
    ----------
    config : SpectralIndexConfig
        The new configuration object.
    """
    previous = getattr(self, 'config', None)
    previous_indices = getattr(previous, 'indices', None)
    self.config = config
    try:
        self._validate_config()

        # Load into locals first so a failure on the second file does not
        # leave the instance half-updated.
        reloaded = {}
        new_band_csv = config.indices.spectral_index_csv
        if getattr(previous_indices, 'spectral_index_csv', None) != new_band_csv:
            reloaded['band_info'] = self.load_band_info(new_band_csv)
        new_formula_csv = config.indices.formula_csv
        if getattr(previous_indices, 'formula_csv', None) != new_formula_csv:
            reloaded['formulas'] = self.load_formulas(new_formula_csv)
    except Exception:
        self.config = previous
        raise
    for attribute, table in reloaded.items():
        setattr(self, attribute, table)
def _validate_config(self):
"""配置校验"""
try:
self.config._validate_parameters()
except ValueError as e:
raise ValueError(f"Configuration validation failed: {e}")
def load_band_info(self, csv_path):
    """Load the band definitions from a CSV file.

    The CSV needs the columns ``name``, ``min``, ``center`` and ``max``.

    Returns
    -------
    dict
        Band name mapped to ``{'min', 'center', 'max'}`` floats.  The common
        names nir/red/green/blue/swir1/swir2 are additionally reachable under
        their upper-case spelling (same dict object).
    """
    aliased_names = ('nir', 'red', 'green', 'blue', 'swir1', 'swir2')
    limit_keys = ('min', 'center', 'max')

    table = pd.read_csv(csv_path)
    band_info = {}
    for _, record in table.iterrows():
        band_name = record['name']
        limits = {key: float(record[key]) for key in limit_keys}
        band_info[band_name] = limits
        if band_name in aliased_names:
            # Upper-case alias shares the very same limits dict.
            band_info[band_name.upper()] = limits
    return band_info
def load_formulas(self, csv_path):
    """Read the spectral index formula table.

    The CSV must provide the columns ``Index``, ``Name``, ``Type``,
    ``Equation`` and ``Bands``. ``Type`` and ``Bands`` hold Python list
    literals written as text (e.g. ``"['nir', 'red']"``).

    Returns:
        dict: index short name -> ``{'name', 'type', 'equation', 'bands'}``.

    Raises:
        ValueError: If a ``Type`` or ``Bands`` cell is not a valid literal.
    """
    # Local import: only this method needs it.
    import ast

    def parse_literal(raw, column, index):
        # SECURITY: the cells come from an external file, so they are parsed
        # with ast.literal_eval (literals only) instead of eval(), which
        # would execute arbitrary code embedded in the CSV.
        try:
            return ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError) as exc:
            raise ValueError(
                f"Invalid {column} literal for index {index!r}: {raw!r}"
            ) from exc

    df = pd.read_csv(csv_path)
    formulas = {}
    for _, row in df.iterrows():
        index = row['Index']
        formulas[index] = {
            'name': row['Name'],
            'type': parse_literal(row['Type'], 'Type', index),
            'equation': row['Equation'],
            'bands': parse_literal(row['Bands'], 'Bands', index)
        }
    return formulas
def load_hyperspectral_data(self, file_path):
    """
    Load a hyperspectral data file, dispatching on the file extension.

    Supported formats:
    - CSV: handled by ``load_csv_data``
    - ENVI (.bil, .bsq, .bip, .dat, .hdr): handled by ``load_with_spectral``
      when the spectral library is available

    Raises:
        ValueError: For any other extension, or for an ENVI extension when
            the spectral library is not available.
    """
    file_path = Path(file_path)
    suffix = file_path.suffix.lower()

    if suffix == '.csv':
        return self.load_csv_data(file_path)

    envi_suffixes = ['.bil', '.bsq', '.bip', '.dat', '.hdr']
    if not (SPECTRAL_AVAILABLE and suffix in envi_suffixes):
        raise ValueError(f"不支持的文件格式: {suffix}。请使用CSV或ENVI格式文件并确保已安装spectral库。")
    return self.load_with_spectral(file_path)
def load_csv_data(self, file_path):
    """Load hyperspectral samples from a CSV file.

    The header row is interpreted as the wavelengths (converted to float);
    the remaining rows become the data matrix. Stores ``wavelengths``,
    ``data`` and ``data_shape`` on the instance and returns the data array.
    """
    table = pd.read_csv(file_path)

    # The column labels carry the wavelengths.
    self.wavelengths = table.columns.astype(float).values
    self.data = table.values
    self.data_shape = self.data.shape

    band_count = len(self.wavelengths)
    print(f"CSV数据加载成功: {self.data_shape}")
    print(f"波段数量: {band_count}")
    print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
    return self.data
def load_with_spectral(self, file_path):
    """
    Load an ENVI image (.bil, .bsq, .bip, .dat + .hdr) with the spectral library.

    ``spectral.open_image`` reads the ENVI *header*, so when a data file is
    passed in, its sibling header (``name.hdr`` or ``name.ext.hdr``) is
    opened instead whenever one exists. If none is found the given path is
    used unchanged.

    Stores ``data`` (rows, cols, bands), ``wavelengths`` and ``data_shape``
    on the instance and returns the data array. When the header has no
    usable wavelength list -- missing, unparsable, or of a length different
    from the band count -- band indices ``0..bands-1`` are used instead.
    """
    requested = Path(file_path)
    header_path = requested
    if requested.suffix.lower() != '.hdr':
        candidates = (
            requested.with_suffix('.hdr'),                 # scene.bil -> scene.hdr
            requested.with_name(requested.name + '.hdr'),  # scene.bil -> scene.bil.hdr
        )
        header_path = next((c for c in candidates if c.exists()), requested)
    print(f"使用spectral库加载: {header_path}")

    img = spectral.open_image(str(header_path))
    # load() returns the whole cube as a (rows, cols, bands) array.
    self.data = img.load()
    band_count = self.data.shape[2]

    if hasattr(img, 'metadata') and 'wavelength' in img.metadata:
        wavelength_data = img.metadata['wavelength']
        try:
            if isinstance(wavelength_data, str):
                # Raw header text such as "{400.1, 402.3, ...}".
                tokens = wavelength_data.strip('{}[]').replace(',', ' ').split()
                wavelengths = np.array([float(tok) for tok in tokens])
            elif isinstance(wavelength_data, list):
                wavelengths = np.array([float(w) for w in wavelength_data])
            else:
                wavelengths = np.array(wavelength_data, dtype=float)
            if len(wavelengths) != band_count:
                # A mismatching list would make band lookups index the
                # wrong (or a non-existent) band later on.
                raise ValueError(
                    f"{len(wavelengths)} wavelengths for {band_count} bands"
                )
            self.wavelengths = wavelengths
            print(f"spectral解析波长: {len(self.wavelengths)} 个波段")
            print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
        except Exception as e:
            print(f"波长解析失败,使用默认值: {e}")
            self.wavelengths = np.arange(band_count)
    else:
        self.wavelengths = np.arange(band_count)
        print(f"警告: spectral未找到波长信息使用默认值: 0-{self.data.shape[2]-1}")

    self.data_shape = self.data.shape
    print(f"spectral数据加载成功: {self.data_shape}")
    print(f"数据类型: {self.data.dtype}")
    print(f"值范围: [{self.data.min():.3f}, {self.data.max():.3f}]")
    return self.data
def find_band_index(self, band_name, tolerance=5):
    """
    Find the index of the data band closest to a named band.

    Parameters:
        band_name: Band name (e.g. 'w550', 'nir', 'red'). Names missing
            from the band table are resolved through a fixed alias map
            (e.g. 'nir' -> 'w860').
        tolerance: Maximum distance in nm between the band centre and the
            matched wavelength before a warning is printed. The match is
            still returned when the tolerance is exceeded.

    Returns:
        Zero-based index into ``self.wavelengths``.

    Raises:
        ValueError: If no data is loaded, the band name (or its alias
            target) is not defined, or the band range lies outside the
            wavelength range of the data.
    """
    if self.wavelengths is None or len(self.wavelengths) == 0:
        raise ValueError("No wavelength data loaded; call load_hyperspectral_data() first")

    if band_name not in self.band_info:
        # Fall back to the standard alias names.
        std_names = {
            'nir': 'w860', 'red': 'w680', 'green': 'w550',
            'blue': 'w470', 'swir1': 'w1650', 'swir2': 'w2200',
            'tm1': 'w485', 'tm2': 'w569', 'tm3': 'w660',
            'tm4': 'w833', 'tm5': 'w1676', 'tm7': 'w2223'
        }
        resolved = std_names.get(band_name)
        # The alias target itself must exist in the band table; otherwise
        # the lookup below would leak a KeyError instead of a ValueError.
        if resolved is None or resolved not in self.band_info:
            raise ValueError(f"未定义的波段名称: {band_name}")
        band_name = resolved

    target_info = self.band_info[band_name]
    target_center = target_info['center']
    target_min = target_info['min']
    target_max = target_info['max']

    # Use the true extremes so descending or unsorted wavelength lists work too.
    wavelengths = np.asarray(self.wavelengths, dtype=float)
    data_min = wavelengths.min()
    data_max = wavelengths.max()
    if target_min < data_min or target_max > data_max:
        raise ValueError(
            f"波段 {band_name} 超出数据波长范围: "
            f"目标波段 [{target_min:.1f}-{target_max:.1f}nm], "
            f"数据范围 [{data_min:.1f}-{data_max:.1f}nm]"
        )

    # Nearest wavelength to the band centre.
    diffs = np.abs(wavelengths - target_center)
    min_idx = np.argmin(diffs)
    if diffs[min_idx] > tolerance:
        print(
            f"警告: 波段 {band_name} (目标: {target_center}nm) 匹配到 {wavelengths[min_idx]:.1f}nm (差异: {diffs[min_idx]:.1f}nm)")
    return min_idx
def calculate_index(self, index_name, output_type='image'):
    """
    Compute one spectral index.

    Parameters:
        index_name: Short name of the index (e.g. 'NDVI', 'EVI').
        output_type: 'values' returns the flattened result; any other
            value (including 'image') returns the result unchanged.

    Returns:
        The evaluated index.

    Raises:
        ValueError: If ``index_name`` is not a known formula.

    An equation may be followed by ``;``-separated helper definitions of
    the form ``name = expression``. They are evaluated in order, before the
    main equation, and each one can use the helpers defined before it.
    """
    if index_name not in self.formulas:
        available = list(self.formulas.keys())
        raise ValueError(f"未找到指数: {index_name}。可用指数: {available}")

    spec = self.formulas[index_name]
    equation = spec['equation']
    required_bands = spec['bands']
    print(f"计算指数: {spec['name']} ({index_name})")
    print(f"所需波段: {required_bands}")
    print(f"公式: {equation}")

    # First segment is the main formula; the rest are helper definitions.
    segments = equation.split(';') if ';' in equation else [equation]
    main_eq = segments[0].strip()
    local_vars = {}
    for definition in segments[1:]:
        if '=' not in definition:
            continue
        var_name, _, expr = definition.partition('=')
        local_vars[var_name.strip()] = self.evaluate_expression(
            expr.strip(), required_bands, local_vars
        )

    index_result = self.evaluate_expression(main_eq, required_bands, local_vars)
    if output_type == 'values':
        # Flatten to a 1-D array.
        return index_result.flatten()
    return index_result
def evaluate_expression(self, expression, required_bands, local_vars):
    """
    Evaluate an index formula against the loaded data.

    Parameters:
        expression: Formula text. Band variables (``w<nm>``, ``tm<n>``,
            nir, red, green, blue, swir1, swir2, thermal) are replaced by
            the matching data band; ``^`` is accepted as the power operator.
        required_bands: Band list of the formula (not used here; kept for
            interface compatibility).
        local_vars: Previously computed helper variables. A helper whose
            name looks like a band variable takes precedence over the band.

    Returns:
        The evaluated result. In array results NaN is replaced by 0.0,
        +inf by 1.0 and -inf by -1.0.

    Raises:
        ValueError: If the expression cannot be evaluated (the original
            error is chained), or if a band cannot be resolved.
    """
    pattern = r'\b(w\d+|tm\d+|nir|red|green|blue|swir1|swir2|thermal)\b'
    is_table = len(self.data_shape) == 2  # CSV input: (samples, bands)

    band_data = {}
    for band_var in set(re.findall(pattern, expression)):
        if band_var in local_vars:
            continue
        band_idx = self.find_band_index(band_var)
        band = self.data[:, band_idx] if is_table else self.data[:, :, band_idx]
        # Integer data (e.g. raw uint16 counts) would wrap around in
        # differences such as ``nir - red``; do the arithmetic in floats.
        if not np.issubdtype(band.dtype, np.floating):
            band = band.astype(np.float64)
        band_data[band_var] = band

    math_funcs = {
        'sqrt': np.sqrt,
        'log': np.log,
        'log10': np.log10,
        'exp': np.exp,
        'sin': np.sin,
        'cos': np.cos,
        'tan': np.tan,
        'abs': np.abs,
        'pow': np.power
    }
    # Same precedence as before: bands < helper variables < math functions.
    safe_env = {**band_data, **local_vars, **math_funcs, 'np': np}

    expression = expression.replace('^', '**')
    try:
        # NOTE(security): eval() with builtins removed is NOT a real
        # sandbox. Formulas must come from a trusted formula file.
        # Division by zero is expected for some pixels and is cleaned up
        # below, so the corresponding numpy warnings are silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            result = eval(expression, {"__builtins__": {}}, safe_env)
    except Exception as e:
        raise ValueError(f"表达式评估失败: {expression}\n错误: {str(e)}") from e

    if isinstance(result, np.ndarray):
        result = np.nan_to_num(result, nan=0.0, posinf=1.0, neginf=-1.0)
    return result
def calculate_multiple_indices(self, index_list, output_format='separate'):
    """
    Compute several spectral indices.

    Parameters:
        index_list: Names of the indices to compute.
        output_format: 'separate' returns a dict; 'stack' returns the
            successful results stacked along a new last axis.

    Returns:
        A dict mapping each index name to its result (``None`` for an index
        whose computation failed), or the stacked array. With 'stack' the
        dict is still returned when nothing could be stacked.

    A failing index is reported on stdout and does not abort the others.
    """
    computed = {}
    for name in index_list:
        try:
            image = self.calculate_index(name, output_type='image')
            computed[name] = image
            print(f"{name}: 计算完成 (形状: {image.shape})")
        except Exception as exc:
            print(f"{name}: 计算失败 - {str(exc)}")
            computed[name] = None

    if output_format != 'stack' or not computed:
        return computed

    # Stack all successful indices into one multi-band image.
    layers = [layer for layer in computed.values() if layer is not None]
    if not layers:
        return computed
    return np.stack(layers, axis=-1)
def save_results(self, results, output_prefix='output'):
    """
    Save computed indices, choosing the writer from the input data layout.

    Parameters:
        results: Dict of index results, or an already stacked array.
        output_prefix: Prefix of the output file names.

    A dict is written as one merged CSV when the loaded data is 2-D (CSV
    input), otherwise as one multi-band .dat file. An ndarray is written
    through ``save_single_result``. Any other type is ignored.
    """
    # 2-D source data means the input came from a CSV table.
    tabular_input = len(self.data_shape) == 2

    if isinstance(results, dict):
        writer = self.save_merged_csv if tabular_input else self.save_multiband_dat
        writer(results, output_prefix)
    elif isinstance(results, np.ndarray):
        self.save_single_result(results, "multiple_indices", f"{output_prefix}_indices_stack")
def save_single_result(self, data, index_name, output_file):
    """Save one result array as CSV or as an ENVI .dat/.hdr pair.

    Parameters:
        data: Result array (1-D, 2-D or 3-D).
        index_name: Name of the result (not written anywhere; kept for
            interface compatibility).
        output_file: Target path. A path ending in '.csv' produces a CSV
            file; anything else is used as the base name for
            ``<output_file>.dat`` and ``<output_file>.hdr``.
    """
    if output_file.endswith('.csv'):
        if len(data.shape) == 2:
            table = pd.DataFrame(data)
        elif len(data.shape) == 3:
            # One row per pixel, one column per band.
            table = pd.DataFrame(data.reshape(-1, data.shape[-1]))
        else:
            table = pd.DataFrame(data.flatten())
        table.to_csv(output_file, index=False)
        print(f"已保存: {output_file} (CSV格式)")
        return

    data_file = output_file + '.dat'
    hdr_file = output_file + '.hdr'
    # Raw float32 dump in the array's C order.
    data.astype('float32').tofile(data_file)

    if len(data.shape) == 1:
        lines, samples, bands = 1, data.shape[0], 1
    elif len(data.shape) == 2:
        lines, samples, bands = data.shape[0], data.shape[1], 1
    else:
        lines, samples, bands = data.shape[0], data.shape[1], data.shape[2]

    # A (lines, samples, bands) array dumped in C order is band-interleaved
    # by pixel. Declaring it as bsq would scramble multi-band output; for a
    # single band both layouts are identical, so 'bsq' is kept there.
    interleave = 'bip' if bands > 1 else 'bsq'
    header_content = (
        "ENVI\n"
        f"samples = {samples}\n"
        f"lines = {lines}\n"
        f"bands = {bands}\n"
        "header offset = 0\n"
        "file type = ENVI Standard\n"
        "data type = 4\n"
        f"interleave = {interleave}\n"
        "byte order = 0\n"
        "wavelength units = Unknown\n"
        "data ignore value = 0\n"
    )
    with open(hdr_file, 'w', encoding='utf-8') as f:
        f.write(header_content)
    print(f"已保存: {data_file}{hdr_file} (ENVI格式)")
    print(f"数据形状: {data.shape}")
def save_merged_csv(self, results_dict, output_prefix):
    """
    Write several index results side by side into one CSV file.

    Parameters:
        results_dict: Dict of index name -> 1-D result (``None`` entries
            are skipped).
        output_prefix: Output prefix; the file is named
            ``<output_prefix>_merged_indices.csv``.

    Nothing is written when no valid result exists or when the first
    valid result is not 1-D; a message is printed instead.
    """
    output_file = f"{output_prefix}_merged_indices.csv"

    valid_results = {
        name: values for name, values in results_dict.items() if values is not None
    }
    if not valid_results:
        print("没有有效的指数结果可保存")
        return

    # The first valid result decides whether the data is 1-D.
    reference = next(iter(valid_results.values()))
    if len(reference.shape) != 1:
        print("合并CSV仅支持1D数据")
        return

    # One column per index.
    merged = pd.DataFrame(dict(valid_results))
    merged.to_csv(output_file, index=False)
    print(f"已合并保存 {len(valid_results)} 个指数到: {output_file}")
    print(f"总样本数: {len(merged)}")
def save_multiband_dat(self, results_dict, output_prefix):
    """
    Write several index images into one multi-band ENVI .dat/.hdr pair.

    Parameters:
        results_dict: Dict of index name -> result (``None`` entries are
            skipped).
        output_prefix: Output prefix; the files are named
            ``<output_prefix>_indices.dat`` and ``<output_prefix>_indices.hdr``.

    Each index becomes one band of a float32 (lines, samples, bands) cube
    written in BIP order; the band names are recorded in the header.
    """
    valid_results = {
        name: image for name, image in results_dict.items() if image is not None
    }
    if not valid_results:
        print("没有有效的指数结果可保存")
        return

    index_names = list(valid_results)

    # Pass 1: drop a trailing singleton band axis, e.g. (1066, 1148, 1).
    layers = []
    for image in valid_results.values():
        if image.ndim == 3 and image.shape[2] == 1:
            image = np.squeeze(image, axis=2)
        layers.append(image)

    # Pass 2: force anything that is still not 2-D into a 2-D shape.
    for position, image in enumerate(layers):
        if image.ndim == 2:
            continue
        print(f"警告: 指数 {index_names[position]} 的形状为 {image.shape},不是二维数组")
        if image.ndim > 2:
            layers[position] = image.reshape(image.shape[0], image.shape[1])
        else:
            layers[position] = image.reshape(1, -1)

    # (lines, samples, bands)
    stacked_data = np.stack(layers, axis=-1)

    output_file = f"{output_prefix}_indices"
    data_file = output_file + '.dat'
    hdr_file = output_file + '.hdr'
    stacked_data.astype('float32').tofile(data_file)

    lines, samples, bands = stacked_data.shape[0], stacked_data.shape[1], stacked_data.shape[2]
    header_content = (
        "ENVI\n"
        f"samples = {samples}\n"
        f"lines = {lines}\n"
        f"bands = {bands}\n"
        "header offset = 0\n"
        "file type = ENVI Standard\n"
        "data type = 4\n"
        "interleave = bip\n"
        "byte order = 0\n"
        "wavelength units = Unknown\n"
        "data ignore value = 0\n"
        f"band names = {{{', '.join(index_names)}}}\n"
    )
    with open(hdr_file, 'w', encoding='utf-8') as f:
        f.write(header_content)

    print(f"已保存多波段数据: {data_file}{hdr_file}")
    print(f"包含 {bands} 个指数: {', '.join(index_names)}")
    print(f"数据形状: {stacked_data.shape}")
def list_available_indices(self, category=None):
    """
    Print the available spectral indices grouped by category.

    Parameters:
        category: When given, only this category is listed (e.g.
            'Vegetation', 'Mineral'). Otherwise every category is listed,
            showing at most its first 10 indices in sorted order.
    """
    print("\n=== 可用光谱指数 ===")

    # category -> [(short name, full name), ...]
    grouped = {}
    for short_name, info in self.formulas.items():
        for cat in info['type']:
            grouped.setdefault(cat, []).append((short_name, info['name']))

    if not category:
        for cat in sorted(grouped.keys()):
            members = grouped[cat]
            print(f"\n{cat} ({len(members)}个指数):")
            for short_name, full_name in sorted(members)[:10]:
                print(f" {short_name:10} - {full_name}")
            if len(members) > 10:
                print(f" ... 还有 {len(members) - 10}")
        return

    if category not in grouped:
        print(f"未找到类别: {category}")
        return

    print(f"\n{category} 指数:")
    for short_name, full_name in sorted(grouped[category]):
        print(f" {short_name:10} - {full_name}")
def run_analysis_from_config(self) -> Dict[str, Any]:
    """
    Run the complete analysis described by ``self.config`` (recommended for GUI use).

    Steps: load the configured data file, select the indices (all known
    formulas when none are configured; unknown names are reported and
    dropped), compute them, and save the results under the
    ``config_results`` prefix when either output option is enabled.

    Returns:
        Dict[str, Any]: Index name -> result (``None`` for failed indices).

    Raises:
        ValueError: If no data file path is configured.
    """
    print("Starting spectral index analysis from configuration...")

    # 1. Load the data.
    data_path = self.config.data.data_file_path
    if not data_path:
        raise ValueError("Data file path must be specified in configuration")
    print(f"Loading data from: {data_path}")
    self.load_hyperspectral_data(data_path)

    # 2. Decide which indices to compute.
    requested = self.config.indices.indices_to_calculate
    if requested is None:
        indices_to_calculate = list(self.formulas.keys())
        print(f"Calculating all {len(indices_to_calculate)} available indices")
    else:
        invalid_indices = [idx for idx in requested if idx not in self.formulas]
        if invalid_indices:
            print(f"Warning: The following indices do not exist: {invalid_indices}")
        indices_to_calculate = [idx for idx in requested if idx in self.formulas]
        print(f"Calculating {len(indices_to_calculate)} specified indices: {indices_to_calculate}")

    # 3. Compute.
    results = self.calculate_multiple_indices(indices_to_calculate, output_format='separate')

    # 4. Save.
    output_cfg = self.config.output
    if output_cfg.save_combined_results or output_cfg.save_individual_indices:
        self.save_results(results, 'config_results')

    print("Analysis completed!")
    # The two statements that used to follow this return (a "total indices"
    # summary print) could never execute and have been removed.
    return results
def main():
    """Command-line entry point.

    Parses the arguments, optionally lists the known indices, otherwise
    loads the input file, computes the selected indices and saves them.

    Index selection, in order of precedence: ``-i`` (explicit names),
    ``-a`` (every index typed 'Vegetation'), ``-c`` (every index of one
    category), and finally a default set of common vegetation indices.
    """
    import argparse

    parser = argparse.ArgumentParser(description='高光谱光谱指数计算工具')
    # Optional on the parser level so that ``--list`` works without an
    # input file; its presence is enforced below for every other mode.
    parser.add_argument('input_file', nargs='?', help='输入的高光谱文件路径 (CSV, BIL, BSQ, BIP, DAT)')
    parser.add_argument('-i', '--indices', nargs='+', help='要计算的指数列表 (如 NDVI EVI)')
    parser.add_argument('-a', '--all', action='store_true', help='计算所有植被指数')
    parser.add_argument('-c', '--category', help='计算特定类别的所有指数 (Vegetation, Mineral等)')
    parser.add_argument('-o', '--output', default='output', help='输出文件前缀')
    parser.add_argument('-f', '--format', choices=['separate', 'stack'], default='separate',
                        help='输出格式: separate(分开) 或 stack(堆叠)')
    parser.add_argument('-l', '--list', action='store_true', help='列出所有可用指数')
    args = parser.parse_args()

    if not args.list and not args.input_file:
        parser.error('input_file is required unless -l/--list is given')

    calculator = HyperspectralIndexCalculator()

    # Listing mode: print the catalogue and stop.
    if args.list:
        calculator.list_available_indices()
        return

    print(f"加载数据: {args.input_file}")
    calculator.load_hyperspectral_data(args.input_file)

    # Decide which indices to compute.
    if args.indices:
        indices_to_calculate = args.indices
    elif args.all:
        indices_to_calculate = [
            idx for idx, info in calculator.formulas.items()
            if 'Vegetation' in info['type']
        ]
    elif args.category:
        indices_to_calculate = [
            idx for idx, info in calculator.formulas.items()
            if args.category in info['type']
        ]
    else:
        # Default: common vegetation indices that exist in the formula table.
        default_indices = ['NDVI', 'EVI', 'NDWI', 'NDII', 'MSI', 'GNDVI']
        indices_to_calculate = [idx for idx in default_indices if idx in calculator.formulas]

    if not indices_to_calculate:
        print("错误: 未指定要计算的指数")
        calculator.list_available_indices()
        return

    print(f"\n将计算 {len(indices_to_calculate)} 个指数:")
    for idx in indices_to_calculate[:10]:
        if idx in calculator.formulas:
            print(f" - {idx}: {calculator.formulas[idx]['name']}")
    if len(indices_to_calculate) > 10:
        print(f" ... 还有 {len(indices_to_calculate) - 10}")

    results = calculator.calculate_multiple_indices(
        indices_to_calculate,
        output_format=args.format
    )
    calculator.save_results(results, args.output)
    print("\n计算完成!")
# ===== Main examples: configuration-driven and backward-compatible usage =====
# NOTE(review): this block runs the hard-coded demos below and does not call
# main(); the argparse CLI is only reachable by calling main() explicitly.
if __name__ == "__main__":
    print("="*60)
    print("Spectral Index Calculator - Configuration-Driven Interface")
    print("="*60)
    # Method 1: configuration-driven (recommended for GUI integration)
    print("\n--- Method 1: Configuration-Driven (Recommended for GUI) ---")
    # Build the configuration object
    config = SpectralIndexConfig.create_quick_analysis(
        data_file_path=r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip.hdr",
        indices_to_calculate=['NDVI', 'EVI', 'NDWI']  # indices to compute
    )
    # Optional: customise the configuration
    config.output.output_dir = "config_results"
    config.output.save_individual_indices = False  # only save the combined result
    # Create the calculator from the configuration
    calculator = HyperspectralIndexCalculator(config)
    # Run the configuration-driven analysis
    results = calculator.run_analysis_from_config()
    print(f"Configuration-driven analysis completed! Calculated {len(results)} indices.")
    # Method 2: backward-compatible usage (direct calls)
    print("\n--- Method 2: Backward Compatible (Legacy Direct Usage) ---")
    calculator2 = HyperspectralIndexCalculator()  # default configuration
    # Load the data the traditional way
    calculator2.load_hyperspectral_data(r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip.hdr")
    # Compute several indices
    results2 = calculator2.calculate_multiple_indices(['NDVI', 'EVI', 'NDWI'])
    print(f"Legacy analysis completed! Calculated {len(results2)} indices.")
    # Save the results
    calculator2.save_results(results2, 'legacy_results')
    # Method 3: command line (if needed)
    print("\n--- Method 3: Command Line (if needed) ---")
    print("Run: python spectral_index.py input_file.dat -i NDVI EVI -o results")
    print("\n" + "="*60)
    print("Both configuration-driven and legacy methods are supported.")
    print("Configuration-driven is recommended for GUI integration.")
    print("="*60)

View File

@ -0,0 +1,846 @@
import numpy as np
from scipy import stats
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import spectral as spy
from spectral import envi
import matplotlib.pyplot as plt
import os
from typing import Tuple, List, Dict, Optional, Union
import warnings
import xml.etree.ElementTree as ET
from shapely.geometry import Polygon, Point
import warnings
warnings.filterwarnings('ignore')
class HyperspectralDistanceMetrics:
    """Distance and similarity measures for hyperspectral spectra.

    The scalar methods compare two 1-D spectra. The ``vectorized_*``
    methods compare every row of ``X`` (n_samples, n_features) with every
    row of ``centers`` (n_clusters, n_features) and return an array of
    shape (n_samples, n_clusters).
    """

    @staticmethod
    def cosine_similarity(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Return the cosine similarity, or 0 if either spectrum has zero norm."""
        dot_product = np.dot(spectrum1, spectrum2)
        norm1 = np.linalg.norm(spectrum1)
        norm2 = np.linalg.norm(spectrum2)
        return dot_product / (norm1 * norm2) if norm1 != 0 and norm2 != 0 else 0

    @staticmethod
    def euclidean_distance(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Return the Euclidean distance between two spectra."""
        return np.linalg.norm(spectrum1 - spectrum2)

    @staticmethod
    def information_divergence(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Return the symmetrised Kullback-Leibler divergence.

        Both spectra are shifted by a small constant and normalised to sum
        to 1; the result is the mean of KL(p||q) and KL(q||p).
        """
        eps = 1e-10  # avoids division by zero / log(0)
        spectrum1 = spectrum1 + eps
        spectrum2 = spectrum2 + eps
        p = spectrum1 / np.sum(spectrum1)
        q = spectrum2 / np.sum(spectrum2)
        kl_pq = np.sum(p * np.log(p / q))
        kl_qp = np.sum(q * np.log(q / p))
        return (kl_pq + kl_qp) / 2

    @staticmethod
    def correlation_coefficient(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Return the Pearson correlation coefficient of two spectra."""
        return np.corrcoef(spectrum1, spectrum2)[0, 1]

    @staticmethod
    def jeffreys_matusita_distance(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Return the Jeffreys-Matusita distance ``sqrt(2 * (1 - BC))``.

        ``BC`` is the Bhattacharyya coefficient of the two spectra after
        normalising them to sum to 1.
        """
        eps = 1e-10
        p = (spectrum1 + eps) / np.sum(spectrum1 + eps)
        q = (spectrum2 + eps) / np.sum(spectrum2 + eps)
        bc = np.sum(np.sqrt(p * q))
        # Rounding can push bc marginally above 1 for (near-)identical
        # spectra; clamp so the square root never sees a negative number.
        return np.sqrt(2 * np.maximum(1 - bc, 0.0))

    @staticmethod
    def spectral_angle_mapper(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Return the spectral angle between two spectra, in radians."""
        cos_sim = HyperspectralDistanceMetrics.cosine_similarity(spectrum1, spectrum2)
        return np.arccos(np.clip(cos_sim, -1, 1))

    @staticmethod
    def sid_sa_combined(spectrum1: np.ndarray, spectrum2: np.ndarray,
                        alpha: float = 0.5) -> float:
        """Return ``alpha * SID' + (1 - alpha) * SAM'``.

        ``SID' = sid / (1 + sid)`` and ``SAM' = angle / (pi / 2)`` are the
        rescaled information divergence and spectral angle.
        """
        sid = HyperspectralDistanceMetrics.information_divergence(spectrum1, spectrum2)
        sam = HyperspectralDistanceMetrics.spectral_angle_mapper(spectrum1, spectrum2)
        sid_norm = sid / (1 + sid)
        sam_norm = sam / (np.pi / 2)
        return alpha * sid_norm + (1 - alpha) * sam_norm

    # ===== Vectorized distance functions =====

    @staticmethod
    def _cosine_similarity_matrix(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Return pairwise cosine similarities clipped to [-1, 1].

        Zero-norm rows are divided by 1 instead of 0, so their similarity
        to anything is 0.
        """
        X_norm = np.linalg.norm(X, axis=1, keepdims=True)
        centers_norm = np.linalg.norm(centers, axis=1, keepdims=True)
        X_norm = np.where(X_norm == 0, 1, X_norm)
        centers_norm = np.where(centers_norm == 0, 1, centers_norm)
        similarity = np.dot(X, centers.T) / (X_norm * centers_norm.T)
        return np.clip(similarity, -1, 1)

    @staticmethod
    def vectorized_cosine_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Return pairwise cosine distances (1 - cosine similarity)."""
        return 1 - HyperspectralDistanceMetrics._cosine_similarity_matrix(X, centers)

    @staticmethod
    def vectorized_euclidean_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Return pairwise Euclidean distances (computed by broadcasting)."""
        diff = X[:, np.newaxis, :] - centers[np.newaxis, :, :]
        return np.linalg.norm(diff, axis=2)

    @staticmethod
    def vectorized_correlation_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Return pairwise correlation distances ``1 - |r|``.

        ``r`` is the Pearson correlation between a sample spectrum and a
        centre spectrum, computed as the cosine similarity of the
        mean-centred rows. This needs only an (n_samples, n_clusters)
        result instead of a full correlation matrix over all rows. A
        constant spectrum has no defined correlation; it is treated as
        ``r = 0`` (distance 1).
        """
        X_centered = X - np.mean(X, axis=1, keepdims=True)
        centers_centered = centers - np.mean(centers, axis=1, keepdims=True)
        corr_matrix = HyperspectralDistanceMetrics._cosine_similarity_matrix(
            X_centered, centers_centered
        )
        return 1 - np.abs(corr_matrix)

    @staticmethod
    def vectorized_information_divergence(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Return pairwise symmetrised KL divergences.

        Builds intermediate arrays of shape
        (n_samples, n_clusters, n_features).
        """
        eps = 1e-10
        X_shifted = X + eps
        centers_shifted = centers + eps
        X_prob = X_shifted / np.sum(X_shifted, axis=1, keepdims=True)
        centers_prob = centers_shifted / np.sum(centers_shifted, axis=1, keepdims=True)
        p = X_prob[:, np.newaxis, :]
        q = centers_prob[np.newaxis, :, :]
        kl_pq = np.sum(p * np.log(p / q), axis=2)
        kl_qp = np.sum(q * np.log(q / p), axis=2)
        return (kl_pq + kl_qp) / 2

    @staticmethod
    def vectorized_jeffreys_matusita_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Return pairwise Jeffreys-Matusita distances."""
        eps = 1e-10
        X_prob = (X + eps) / np.sum(X + eps, axis=1, keepdims=True)
        centers_prob = (centers + eps) / np.sum(centers + eps, axis=1, keepdims=True)
        bc = np.sum(np.sqrt(X_prob[:, np.newaxis, :] * centers_prob[np.newaxis, :, :]), axis=2)
        # Clamp: rounding may make bc slightly larger than 1 (see scalar version).
        return np.sqrt(2 * np.maximum(1 - bc, 0.0))

    @staticmethod
    def vectorized_sid_sa_combined(X: np.ndarray, centers: np.ndarray, alpha: float = 0.5) -> np.ndarray:
        """Return pairwise ``alpha * SID' + (1 - alpha) * SAM'`` distances."""
        sid = HyperspectralDistanceMetrics.vectorized_information_divergence(X, centers)
        sam = HyperspectralDistanceMetrics.vectorized_spectral_angle(X, centers)
        sid_norm = sid / (1 + sid)
        sam_norm = sam / (np.pi / 2)
        return alpha * sid_norm + (1 - alpha) * sam_norm

    @staticmethod
    def vectorized_spectral_angle(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Return pairwise spectral angles in radians."""
        return np.arccos(HyperspectralDistanceMetrics._cosine_similarity_matrix(X, centers))
class HyperspectralClassification:
    """Supervised classifier that assigns each pixel to its nearest reference spectrum."""

    def __init__(self, random_state: int = 42, distance_params: Optional[Dict[str, Dict]] = None):
        """Set up the distance functions and their parameters.

        Parameters:
            random_state: Stored on the instance; not used by the methods
                of this class.
            distance_params: Optional per-method parameter dicts that
                replace the defaults of the same method name (e.g.
                ``{'sid_sa': {'alpha': 0.7}}``).
        """
        self.random_state = random_state
        self.distance_metrics = HyperspectralDistanceMetrics()
        self.reference_spectra_ = None
        self.labels_ = None
        self.class_names_ = None

        # Default parameters per distance measure.
        self.default_distance_params = {
            'cosine': {},
            'euclidean': {},
            'correlation': {},
            'information_divergence': {},
            'sam': {},
            'jm_distance': {'alpha': 0.5},
            'sid_sa': {'alpha': 0.5}  # weight of SID versus SAM
        }
        # Copy the inner dicts too, so later edits of ``distance_params``
        # cannot silently change the defaults.
        self.distance_params = {
            name: dict(params) for name, params in self.default_distance_params.items()
        }
        if distance_params:
            self.distance_params.update(distance_params)

        self.vectorized_distance_functions = {
            'cosine': self.distance_metrics.vectorized_cosine_distance,
            'euclidean': self.distance_metrics.vectorized_euclidean_distance,
            'correlation': self.distance_metrics.vectorized_correlation_distance,
            'information_divergence': self.distance_metrics.vectorized_information_divergence,
            'sam': self.distance_metrics.vectorized_spectral_angle,
            'jm_distance': self.distance_metrics.vectorized_jeffreys_matusita_distance,
            'sid_sa': self.distance_metrics.vectorized_sid_sa_combined
        }

    def fit_predict_with_references(self, X: np.ndarray,
                                    reference_spectra: Dict[str, np.ndarray],
                                    method: str = 'euclidean') -> np.ndarray:
        """
        Classify spectra against reference spectra.

        Parameters:
            X: Data of shape (n_samples, n_bands) or (rows, cols, n_bands).
            reference_spectra: Class name -> reference spectrum (n_bands,).
            method: Name of the distance measure.

        Returns:
            Integer labels of shape (n_samples,) or (rows, cols). A label
            is the position of the class in ``reference_spectra``. Spectra
            containing NaN or infinite values get the label -1, for 2-D
            input as well as for images.

        Raises:
            ValueError: For an unknown method, missing reference spectra,
                a reference spectrum whose length differs from the band
                count, or data without any valid spectrum.
        """
        if method not in self.vectorized_distance_functions:
            raise ValueError(f"不支持的距离度量方法: {method}")
        if not reference_spectra:
            raise ValueError("必须提供参考光谱")

        print(f"DEBUG fit_predict_with_references: X.shape = {X.shape}")
        print(f"DEBUG fit_predict_with_references: reference_spectra keys = {list(reference_spectra.keys())}")
        for name, spectrum in reference_spectra.items():
            print(f"DEBUG fit_predict_with_references: {name}.shape = {np.shape(spectrum)}")

        # Work on a (n_pixels, n_bands) view and skip non-finite spectra.
        pixels = X.reshape(-1, X.shape[2]) if X.ndim == 3 else X
        valid_mask = np.isfinite(pixels).all(axis=1)
        valid_pixels = pixels[valid_mask]
        if len(valid_pixels) == 0:
            raise ValueError("没有有效的光谱数据")

        # Build the (n_classes, n_bands) reference matrix.
        n_bands = pixels.shape[1]
        class_names = list(reference_spectra.keys())
        cleaned_reference_spectra = []
        for name in class_names:
            spectrum = np.asarray(reference_spectra[name]).ravel()
            if spectrum.size != n_bands:
                raise ValueError(
                    f"Reference spectrum '{name}' has {spectrum.size} bands, expected {n_bands}"
                )
            cleaned_reference_spectra.append(spectrum)
        reference_array = np.array(cleaned_reference_spectra)
        print(f"DEBUG fit_predict_with_references: reference_array.shape = {reference_array.shape}")

        labels = self._classify_pixels(valid_pixels, reference_array, method)

        # Scatter the labels back; skipped spectra keep the -1 marker.
        flat_result = np.full(len(pixels), -1, dtype=int)
        flat_result[valid_mask] = labels
        result = flat_result.reshape(X.shape[:-1])

        self.reference_spectra_ = reference_spectra
        self.class_names_ = class_names
        self.labels_ = result
        return result

    def _classify_pixels(self, X: np.ndarray, reference_spectra: np.ndarray, method: str) -> np.ndarray:
        """
        Label each spectrum with the index of its nearest reference spectrum.

        Parameters:
            X: Spectra (n_samples, n_bands).
            reference_spectra: Reference spectra (n_classes, n_bands).
            method: Name of the distance measure.

        Returns:
            Labels of shape (n_samples,).
        """
        print(f"DEBUG: X.shape = {X.shape}")
        print(f"DEBUG: reference_spectra.shape = {reference_spectra.shape}")
        print(f"DEBUG: method = {method}")

        distance_fn = self.vectorized_distance_functions.get(method)
        if distance_fn is None:
            raise ValueError(f"不支持向量化计算的距离度量方法: {method}")

        if method == 'sid_sa':
            # Honour the configured SID/SAM weight, as the non-vectorized
            # path does.
            alpha = self.distance_params.get(method, {}).get('alpha', 0.5)
            distances = distance_fn(X, reference_spectra, alpha)
        else:
            distances = distance_fn(X, reference_spectra)
        return np.argmin(distances, axis=1)

    def _classify_with_complex_distance(self, X: np.ndarray, reference_spectra: np.ndarray, method: str) -> np.ndarray:
        """
        Classify pixel by pixel with the scalar distance functions.

        Supports 'jm_distance' and 'sid_sa'; any other method name falls
        back to the Euclidean distance. Ties keep the lowest class index.
        """
        print(f"使用 {method} 距离度量进行分类...")
        n_samples = X.shape[0]
        n_classes = reference_spectra.shape[0]
        labels = np.zeros(n_samples, dtype=int)
        for i in range(n_samples):
            pixel = X[i]
            min_distance = float('inf')
            best_class = 0
            for j in range(n_classes):
                reference = reference_spectra[j]
                if method == 'jm_distance':
                    distance = self.distance_metrics.jeffreys_matusita_distance(pixel, reference)
                elif method == 'sid_sa':
                    alpha = self.distance_params[method].get('alpha', 0.5)
                    distance = self.distance_metrics.sid_sa_combined(pixel, reference, alpha)
                else:
                    distance = self.distance_metrics.euclidean_distance(pixel, reference)
                if distance < min_distance:
                    min_distance = distance
                    best_class = j
            labels[i] = best_class
        return labels

    def fit_predict_all_methods(self, X: np.ndarray, reference_spectra: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Classify with every registered distance measure.

        Returns a dict mapping each method name to its label array, or to
        ``None`` when that method failed (the error is printed).
        """
        results = {}
        methods = list(self.vectorized_distance_functions.keys())
        print("开始使用不同距离度量进行分类...")
        for method in methods:
            print(f"正在使用 {method} 距离度量...")
            try:
                results[method] = self.fit_predict_with_references(X, reference_spectra, method)
                print(f"{method} 分类完成")
            except Exception as e:
                print(f"{method} 分类失败: {e}")
                results[method] = None
        return results
class HyperspectralImageProcessor:
"""高光谱图像处理类"""
def __init__(self):
self.data = None
self.header = None
self.wavelengths = None
def load_image(self, hdr_path: str) -> Tuple[np.ndarray, Dict]:
    """Load an ENVI image through its header file.

    Stores the loaded data, the header metadata (as a dict) and the
    wavelengths (``None`` when the header has no 'wavelength' entry) on the
    instance and returns ``(data, header)``.

    Raises:
        IOError: If anything goes wrong while loading.

    NOTE(review): ``load_image`` is defined a second time further down in
    this file; if both definitions belong to the same class the later one
    replaces this one, which would make this copy dead code -- confirm and
    remove one of them.
    """
    try:
        # Open the ENVI file
        img = envi.open(hdr_path)
        data = img.load()
        header = dict(img.metadata)
        # Extract the wavelength list
        if 'wavelength' in header:
            wavelengths = np.array([float(w) for w in header['wavelength']])
        else:
            wavelengths = None
        self.data = data
        self.header = header
        self.wavelengths = wavelengths
        print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}")
        if wavelengths is not None:
            print(f"波长范围: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")
        return data, header
    except Exception as e:
        raise IOError(f"加载图像失败: {e}")
def parse_roi_xml(self, xml_path: str) -> Dict[str, List[Tuple[float, float]]]:
    """
    Parse an ENVI ROI XML file and extract the coordinates of each region.

    Parameters:
        xml_path: Path of the XML file.

    Returns:
        Dict mapping each ROI name to a coordinate list
        ``[(x1, y1), (x2, y2), ...]``. Only ``Region`` elements directly
        below the root are read, and only the first ``Coordinates`` element
        of each region is used. Regions without coordinate text are
        omitted; a trailing unpaired value is ignored.

    Raises:
        IOError: If the file cannot be read or parsed (the original error
            is chained as ``__cause__``).
    """
    try:
        root = ET.parse(xml_path).getroot()
        roi_coordinates = {}
        for region in root.findall('Region'):
            coord_elem = region.find('.//Coordinates')
            if coord_elem is None or not coord_elem.text:
                continue
            coordinates_text = coord_elem.text.strip()
            if not coordinates_text:
                continue
            # The text is a flat "x y x y ..." sequence.
            parts = coordinates_text.split()
            roi_coordinates[region.get('name')] = [
                (float(x), float(y)) for x, y in zip(parts[0::2], parts[1::2])
            ]
        return roi_coordinates
    except Exception as e:
        raise IOError(f"解析XML文件失败: {e}") from e
def extract_reference_spectra(self, data: np.ndarray,
                              roi_coordinates: Dict[str, List[Tuple[float, float]]]) -> Dict[str, np.ndarray]:
    """
    Extract one mean reference spectrum per ROI polygon.

    Parameters:
        data: Hyperspectral image (rows, cols, bands).
        roi_coordinates: ROI name -> polygon vertices ``[(x, y), ...]``,
            where x is the column and y the row.

    Returns:
        Dict mapping each ROI name to its mean spectrum (bands,). ROIs
        that contain no pixel, or no finite spectrum, are reported on
        stdout and left out.

    A pixel belongs to an ROI when its (column, row) point lies strictly
    inside the polygon.
    """
    reference_spectra = {}
    rows, cols, bands = data.shape

    for roi_name, coords in roi_coordinates.items():
        # Fewer than three vertices cannot enclose any pixel (and cannot
        # form a polygon at all).
        if len(coords) < 3:
            print(f"警告: ROI '{roi_name}' 中没有找到像素")
            continue
        polygon = Polygon(coords)

        # Only pixels inside the polygon's bounding box can be contained,
        # so the scan is limited to that window instead of the full image.
        min_x, min_y, max_x, max_y = polygon.bounds
        row_start = max(0, int(np.floor(min_y)))
        row_stop = min(rows, int(np.ceil(max_y)) + 1)
        col_start = max(0, int(np.floor(min_x)))
        col_stop = min(cols, int(np.ceil(max_x)) + 1)

        pixel_count = 0
        spectra = []
        for i in range(row_start, row_stop):
            for j in range(col_start, col_stop):
                # x is the column, y is the row.
                if not polygon.contains(Point(j, i)):
                    continue
                pixel_count += 1
                spectrum = data[i, j, :]
                if spectrum.ndim > 1:
                    spectrum = spectrum.flatten()
                # Keep only complete, finite spectra.
                if np.isfinite(spectrum).all() and spectrum.size == bands:
                    spectra.append(spectrum)

        if pixel_count == 0:
            print(f"警告: ROI '{roi_name}' 中没有找到像素")
            continue
        if len(spectra) == 0:
            print(f"警告: ROI '{roi_name}' 中没有有效的光谱数据")
            continue

        avg_spectrum = np.mean(spectra, axis=0)
        if avg_spectrum.ndim > 1:
            avg_spectrum = avg_spectrum.flatten()
        reference_spectra[roi_name] = avg_spectrum
        print(f"ROI '{roi_name}': 找到 {len(spectra)} 个有效像素, 光谱维度 {avg_spectrum.shape}")
    return reference_spectra
def load_image(self, hdr_path: str) -> Tuple[np.ndarray, Dict]:
    """Load an ENVI hyperspectral image through its header file.

    Stores the data, the header metadata (as a dict) and the wavelengths on
    the instance. ``wavelengths`` is ``None`` when the header has no
    'wavelength' entry or the entry is empty.

    Returns:
        ``(data, header)``.

    Raises:
        IOError: If anything goes wrong while loading (the original error
            is chained as ``__cause__``).
    """
    try:
        img = envi.open(hdr_path)
        data = img.load()
        header = dict(img.metadata)

        raw_wavelengths = header.get('wavelength')
        # An empty wavelength list is treated like a missing one; indexing
        # its first element below would otherwise abort the whole load.
        if raw_wavelengths is not None and len(raw_wavelengths) > 0:
            wavelengths = np.array([float(w) for w in raw_wavelengths])
        else:
            wavelengths = None

        self.data = data
        self.header = header
        self.wavelengths = wavelengths
        print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}")
        if wavelengths is not None:
            print(f"波长范围: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")
        return data, header
    except Exception as e:
        raise IOError(f"加载图像失败: {e}") from e
def save_single_band_image(self, data: np.ndarray, output_path: str,
                           method_name: str = "", original_header: Dict = None) -> None:
    """
    Save a single-band label image as a raw .dat file plus an ENVI .hdr file.

    Parameters:
        data: label array; it is converted to int32 before being written
        output_path: destination path; any extension other than .dat is
            replaced by .dat
        method_name: name of the method that produced the labels, forwarded
            to the header writer
        original_header: metadata of the source image, forwarded to the
            header writer

    Raises:
        IOError: if writing the data file fails; the original exception is
            chained as the cause.
    """
    try:
        # Force a .dat extension. splitext only looks at the final path
        # component, so a dot in a directory name (e.g. "./out/result") is
        # never mistaken for the extension separator.
        if not output_path.lower().endswith('.dat'):
            output_path = os.path.splitext(output_path)[0] + '.dat'

        # Create the target directory only when the path has one;
        # os.makedirs('') raises for a bare file name.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Labels are stored as int32; the header writer declares data type 3.
        if data.dtype != np.int32:
            data_to_save = data.astype(np.int32)
        else:
            data_to_save = data

        # Raw binary dump of the array.
        data_to_save.tofile(output_path)

        # The header sits next to the data file, same stem with .hdr.
        hdr_path = output_path[:-len('.dat')] + '.hdr'
        self._create_cluster_hdr_file(hdr_path, data_to_save.shape, method_name, original_header)
        print(f"聚类结果已保存到:")
        print(f" 数据文件: {output_path}")
        print(f" 头文件: {hdr_path}")
        print(f"数据类型: {data_to_save.dtype}, 形状: {data_to_save.shape}")
    except Exception as e:
        raise IOError(f"保存图像失败: {e}") from e
def _create_cluster_hdr_file(self, hdr_path: str, data_shape: Tuple,
                             method_name: str, original_header: Dict = None) -> None:
    """
    Write the ENVI header describing a single-band int32 label image.

    The header must describe the bytes that were actually written, so the
    geometry comes from ``data_shape`` and the byte order from the running
    platform rather than from the source image's header.

    Parameters:
        hdr_path: path of the .hdr file to create
        data_shape: shape of the saved array; (lines, samples[, ...]) for
            2-D or higher, (samples,) for 1-D data (stored as one line)
        method_name: name of the method, written into the description and
            the band name
        original_header: metadata of the source image. Accepted for
            interface compatibility; no field is copied from it because the
            label band has no wavelength and its layout is defined by
            ``data_shape``.

    Failures are reported with a printed message and are not raised.
    """
    import sys

    try:
        if len(data_shape) >= 2:
            lines, samples = data_shape[0], data_shape[1]
        else:
            # 1-D data is described as a single line of N samples.
            lines, samples = 1, data_shape[0]
        bands = 1  # the label image has exactly one band
        # With a single band every interleave is byte-identical; use bsq.
        interleave = 'bsq'
        data_type = 3  # ENVI data type code: 3 = int32, 4 = float32
        # The caller writes the array in native byte order, so declare that
        # (ENVI: 0 = little-endian, 1 = big-endian).
        byte_order = 0 if sys.byteorder == 'little' else 1

        header_lines = [
            "ENVI",
            "description = {",
            f" 聚类结果 - {method_name}",
            " 单波段聚类标签图像",
            "}",
            f"samples = {samples}",
            f"lines = {lines}",
            f"bands = {bands}",
            "header offset = 0",
            "file type = ENVI Standard",
            f"data type = {data_type}",
            f"interleave = {interleave}",
            f"byte order = {byte_order}",
            # No 'wavelength' entry: a label band has no wavelength, and a
            # non-numeric placeholder breaks readers that parse it as float.
            "band names = {",
            f" 聚类结果_{method_name}",
            "}",
        ]

        # NOTE(review): the header holds non-ASCII text and is written with
        # the platform default encoding; confirm what the readers expect.
        with open(hdr_path, 'w') as f:
            f.write("\n".join(header_lines) + "\n")
        print(f"成功创建头文件: {hdr_path}")
    except Exception as e:
        # Deliberately best-effort: callers rely on this never raising.
        print(f"创建头文件失败: {e}")
def visualize_clusters(self, cluster_results: Dict[str, np.ndarray],
                       save_path: Optional[str] = None) -> None:
    """
    Show the label images of several methods side by side.

    Parameters:
        cluster_results: method name -> 2-D label image; entries whose
            value is None are skipped
        save_path: if given, the figure is also written to this path

    When there is nothing to draw a warning is printed and no figure is
    created.
    """
    # Lay out only the results that can be drawn, so a None entry never
    # leaves an empty framed subplot behind.
    drawable = [(method, result)
                for method, result in (cluster_results or {}).items()
                if result is not None]
    if not drawable:
        print("警告: 没有可用于可视化的聚类结果")
        return

    # Two rows, as many columns as needed. squeeze=False keeps `axes` a
    # 2-D array for every grid size.
    n_cols = (len(drawable) + 1) // 2
    fig, axes = plt.subplots(2, n_cols, figsize=(15, 10), squeeze=False)
    axes = axes.flatten()

    for ax, (method, result) in zip(axes, drawable):
        im = ax.imshow(result, cmap='tab10')
        ax.set_title(f'{method.replace("_", " ").title()}')
        ax.axis('off')
        plt.colorbar(im, ax=ax, shrink=0.8)

    # Hide the unused cells of the grid.
    for ax in axes[len(drawable):]:
        ax.axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"可视化结果已保存到: {save_path}")
    plt.show()
def main(argv=None):
    """
    Command-line entry point for supervised hyperspectral classification.

    Parameters:
        argv: list of argument strings to parse; None (the default) makes
            argparse read ``sys.argv[1:]``

    Returns:
        The value returned by ``run_hsi_classification``, or 1 if that call
        raises.
    """
    import argparse
    import json
    import traceback

    parser = argparse.ArgumentParser(description='高光谱图像监督分类分析')
    parser.add_argument('input_file', help='输入ENVI格式高光谱图像文件 (.hdr)')
    parser.add_argument('xml_file', help='ENVI ROI XML文件路径')
    parser.add_argument('--output_dir', '-o', default='output',
                        help='输出目录 (默认: output)')
    parser.add_argument('--method', '-m', default='all',
                        choices=['all', 'euclidean', 'cosine', 'correlation',
                                 'information_divergence', 'jm_distance', 'sid_sa'],
                        help='分类距离度量方法 (默认: all使用所有方法)')
    parser.add_argument('--distance_params', '-p', type=str, default=None,
                        help='距离度量方法的超参数JSON格式字符串例如: {"jm_distance": {"alpha": 0.7}}')
    parser.add_argument('--visualize', '-v', action='store_true',
                        help='是否生成可视化结果')
    args = parser.parse_args(argv)

    # Decode the optional per-method hyper-parameters. Anything that is not
    # a JSON object falls back to the defaults with a warning.
    distance_params = None
    if args.distance_params:
        try:
            distance_params = json.loads(args.distance_params)
        except json.JSONDecodeError:
            distance_params = None
        else:
            if not isinstance(distance_params, dict):
                distance_params = None
        if distance_params is None:
            print(f"警告: 无法解析距离参数 '{args.distance_params}',使用默认参数")

    try:
        return run_hsi_classification(
            input_file=args.input_file,
            xml_file=args.xml_file,
            output_dir=args.output_dir,
            method=args.method,
            distance_params=distance_params,
            visualize=args.visualize
        )
    except Exception as e:
        print(f"✗ 处理失败: {e}")
        traceback.print_exc()
        return 1
def run_hsi_classification(input_file, xml_file, output_dir='output', method='all',
                           distance_params=None, visualize=False):
    """
    Run supervised classification of a hyperspectral image against ROI spectra.

    Parameters:
        input_file: ENVI-format hyperspectral image header (.hdr)
        xml_file: ENVI ROI XML file
        output_dir: directory that receives the results (default: 'output')
        method: 'all' to run every distance metric, or one metric name
        distance_params: dict of per-metric hyper-parameters (default: None)
        visualize: when True, also draw the label images and save the
            figure as 'classification_results.png' in ``output_dir``

    Returns:
        0 on success, 1 on failure. Errors are printed with a traceback and
        are not raised.
    """
    try:
        os.makedirs(output_dir, exist_ok=True)

        processor = HyperspectralImageProcessor()

        # Load the image cube and its header.
        print(f"加载高光谱图像: {input_file}")
        data, header = processor.load_image(input_file)

        # Read the ROI polygons.
        print(f"\n解析ROI XML文件: {xml_file}")
        roi_coordinates = processor.parse_roi_xml(xml_file)
        if not roi_coordinates:
            raise ValueError("XML文件中没有找到有效的ROI区域")
        print(f"找到 {len(roi_coordinates)} 个ROI区域")
        for roi_name, coords in roi_coordinates.items():
            print(f" {roi_name}: {len(coords)} 个顶点")

        # One mean reference spectrum per ROI.
        print(f"\n图像数据形状: {data.shape}")
        print("提取参考光谱...")
        reference_spectra = processor.extract_reference_spectra(data, roi_coordinates)
        if not reference_spectra:
            raise ValueError("无法从ROI区域提取有效的参考光谱")
        print(f"成功提取 {len(reference_spectra)} 个参考光谱:")
        for roi_name, spectrum in reference_spectra.items():
            print(f" {roi_name}: 形状 {spectrum.shape}, 数据类型 {spectrum.dtype}")

        classifier = HyperspectralClassification(distance_params=distance_params)

        # Classify with every metric or only the requested one.
        print(f"\n开始分类分析 (类别数: {len(reference_spectra)}, 方法: {method})...")
        if method == 'all':
            results = classifier.fit_predict_all_methods(data, reference_spectra)
        else:
            result = classifier.fit_predict_with_references(data, reference_spectra, method)
            results = {method: result}

        # Save one label image per method that produced a result.
        print("\n保存分类结果...")
        saved_files = []
        for method_name, result in results.items():
            if result is not None:
                output_file = os.path.join(output_dir, f'classification_{method_name}.dat')
                processor.save_single_band_image(result, output_file, method_name, header)
                saved_files.append((method_name, output_file))

        # Per-method statistics; negative labels are not counted as classes.
        print("\n=== 分类统计信息 ===")
        for method_name, result in results.items():
            if result is not None:
                unique_labels = np.unique(result[result >= 0])
                n_classes_found = len(unique_labels)
                print(f"\n{method_name}: 识别 {n_classes_found} 个类别")
                for class_idx, class_name in enumerate(classifier.class_names_):
                    pixel_count = np.sum(result == class_idx)
                    print(f" {class_name}: {pixel_count} 个像素")

        # Honour the visualize flag, which was previously accepted but
        # ignored. Skip it when there is nothing to draw.
        if visualize and any(result is not None for result in results.values()):
            figure_path = os.path.join(output_dir, 'classification_results.png')
            processor.visualize_clusters(results, save_path=figure_path)

        print("\n✓ 分类分析完成!")
        print(f"输出目录: {output_dir}")
        print(f"生成文件: {len(saved_files)}")
        return 0
    except Exception as e:
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
# Script entry point: command-line interface, or the built-in example run
if __name__ == '__main__':
    import sys

    if len(sys.argv) > 1:
        # Arguments were supplied: use the argparse-based CLI.
        result = main()
    else:
        # No arguments: run the built-in example with every distance metric.
        result = run_hsi_classification(
            input_file=r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip.hdr",  # hyperspectral image header
            xml_file=r"E:\code\spectronon\single_classsfication\data\roi.xml",  # ROI XML file
            output_dir=r'E:\code\spectronon\single_classsfication\tsst',  # output directory
            method='all',  # run every distance metric
            distance_params=None,  # default hyper-parameters
            visualize=False  # no figure
        )
    # Example 2: one specific metric with custom hyper-parameters
    # result = run_hsi_classification(
    #     input_file='your_image.hdr',
    #     xml_file='roi_regions.xml',
    #     method='jm_distance',  # JM distance only
    #     distance_params={'jm_distance': {'alpha': 0.7}},
    #     output_dir='output_jm'
    # )
    # Example 3: compare all metrics
    # result = run_hsi_classification(
    #     input_file='data/image.hdr',
    #     xml_file='data/roi.xml',
    #     method='all'  # every distance metric
    # )
    # Report the outcome and propagate it as the process exit status.
    if result == 0:
        print("程序执行成功!")
    else:
        print("程序执行失败!")
    sys.exit(result)