First commit
8  .idea/.gitignore  generated vendored  Normal file
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
35  .idea/inspectionProfiles/Project_Default.xml  generated  Normal file
@@ -0,0 +1,35 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredPackages">
        <value>
          <list size="22">
            <item index="0" class="java.lang.String" itemvalue="catboost" />
            <item index="1" class="java.lang.String" itemvalue="opencv-python" />
            <item index="2" class="java.lang.String" itemvalue="xgboost" />
            <item index="3" class="java.lang.String" itemvalue="opencv-python-headless" />
            <item index="4" class="java.lang.String" itemvalue="zstandard" />
            <item index="5" class="java.lang.String" itemvalue="scikit-image" />
            <item index="6" class="java.lang.String" itemvalue="scipy" />
            <item index="7" class="java.lang.String" itemvalue="tensorflow_gpu" />
            <item index="8" class="java.lang.String" itemvalue="h5py" />
            <item index="9" class="java.lang.String" itemvalue="matplotlib" />
            <item index="10" class="java.lang.String" itemvalue="numpy" />
            <item index="11" class="java.lang.String" itemvalue="opencv_python" />
            <item index="12" class="java.lang.String" itemvalue="Pillow" />
            <item index="13" class="java.lang.String" itemvalue="tensorflow" />
            <item index="14" class="java.lang.String" itemvalue="jupyter" />
            <item index="15" class="java.lang.String" itemvalue="ipykernel" />
            <item index="16" class="java.lang.String" itemvalue="pandas" />
            <item index="17" class="java.lang.String" itemvalue="Werkzeug" />
            <item index="18" class="java.lang.String" itemvalue="cellpose" />
            <item index="19" class="java.lang.String" itemvalue="torchvision" />
            <item index="20" class="java.lang.String" itemvalue="Flask" />
            <item index="21" class="java.lang.String" itemvalue="fiona" />
          </list>
        </value>
      </option>
    </inspection_tool>
  </profile>
</component>
6  .idea/inspectionProfiles/profiles_settings.xml  generated  Normal file
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
6  .idea/misc.xml  generated  Normal file
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Black">
    <option name="sdkName" value="insect" />
  </component>
</project>
7  .idea/single_classsfication.iml  generated  Normal file
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<module version="4">
  <component name="PyDocumentationSettings">
    <option name="format" value="PLAIN" />
    <option name="myDocStringFormat" value="Plain" />
  </component>
</module>
209  README.md  Normal file
@@ -0,0 +1,209 @@
# Hyperspectral Image Classification and Analysis Toolkit

[](https://www.python.org/)
[](LICENSE)

A comprehensive toolkit for hyperspectral image processing, classification, and spectral analysis, with a GUI-ready architecture.

## 🌟 Features

### 🔬 Spectral Analysis
- **Spectral index calculation**: compute a range of vegetation, mineral, and water indices
- **Custom formula support**: define and evaluate user-defined spectral indices (see the sketch below)
- **Multi-format support**: handles CSV, ENVI (.hdr/.dat), and other spectral data formats
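To make concrete what a custom index amounts to, here is a minimal sketch of a normalized-difference index over two bands. This is illustrative only: `normalized_difference` is a hypothetical helper, not the toolkit's `HyperspectralIndexCalculator` API, and the band positions are assumptions.

```python
import numpy as np

def normalized_difference(band_a: np.ndarray, band_b: np.ndarray) -> np.ndarray:
    """Generic normalized-difference index, e.g. NDVI = (NIR - Red) / (NIR + Red)."""
    band_a = band_a.astype(np.float32)
    band_b = band_b.astype(np.float32)
    denom = band_a + band_b
    out = np.zeros_like(denom)
    # Only divide where the denominator is nonzero, to avoid warnings on dark pixels
    np.divide(band_a - band_b, denom, out=out, where=denom != 0)
    return out

# nir and red would be 2D slices of a hyperspectral cube, e.g. cube[:, :, 120]
# ndvi = normalized_difference(nir, red)
```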
### 🤖 Machine Learning Models
- **Classical models**: linear regression, ridge, LASSO, elastic net, Bayesian ridge
- **Ensemble methods**: random forest, gradient boosting, XGBoost, LightGBM, AdaBoost
- **Neural networks**: MLP, LSTM, GRU (TensorFlow/PyTorch backends)
- **Specialized models**: support vector regression, Gaussian processes, KNN

### 🖼️ Image Processing
- **Spatial filtering**: mean, median, Gaussian, and bilateral filters
- **Multi-band support**: process a single band or a full hyperspectral data cube
- **Format preservation**: keeps the original data type and value range

### 🎯 GUI-Ready Architecture
- **Configuration-driven**: centralized parameter management and validation
- **Modular design**: business logic fully decoupled from parameter input
- **Error handling**: thorough validation with user-friendly error messages

## 📋 Requirements

### System
- Python 3.8 or later
- 4GB+ RAM (8GB+ recommended for large datasets)
- GPU support (optional, for deep learning models)

### Dependencies

Core packages (see `requirements.txt` for the full list):
```
numpy>=1.21.0
pandas>=1.3.0
scikit-learn>=1.0.0
matplotlib>=3.5.0
xgboost>=1.5.0
lightgbm>=3.3.0
```

Optional packages (the modules degrade gracefully when these are missing; see the sketch below):
- `tensorflow` or `torch` for deep learning models
- `spectral` for ENVI file handling
- `opencv-python` for image filtering
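The modules guard these optional imports at module level instead of failing outright; a minimal sketch of the pattern, mirroring the `HAS_SPECTRAL` flag used in `fliter_method/Smooth_filter.py` (the `load_envi` wrapper below is illustrative):

```python
try:
    from spectral import envi
    HAS_SPECTRAL = True
except ImportError:
    HAS_SPECTRAL = False

def load_envi(hdr_path):
    # Fail with a clear message only when the optional feature is actually used
    if not HAS_SPECTRAL:
        raise ImportError("The 'spectral' package is required to read ENVI files")
    return envi.open(hdr_path).load()
```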
## 🚀 Installation

### 1. Clone the repository
```bash
git clone http://git.iris-rs.cn/zhanghuilai/HSI.git
cd HSI
```

### 2. Create a virtual environment (recommended)
```bash
python -m venv hyperspectral_env
source hyperspectral_env/bin/activate  # Windows: hyperspectral_env\Scripts\activate
```

### 3. Install dependencies
```bash
pip install -r requirements.txt
```

### 4. Install optional packages as needed
```bash
# Deep learning support
pip install tensorflow  # or torch

# ENVI file handling
pip install spectral

# Development environment
pip install pytest jupyter
```
## 📖 Usage

### Basic usage

```python
from rgression_method.regression import RegressionAnalyzer, RegressionConfig
from spectral_index import HyperspectralIndexCalculator, SpectralIndexConfig

# 1. Configuration-driven regression analysis
config = RegressionConfig.create_default(
    csv_path="your_data.csv",
    label_column="target"
)
config.models.model_names = ['linear', 'xgboost', 'lightgbm']

analyzer = RegressionAnalyzer(config)
results = analyzer.run_analysis_from_config()

# 2. Spectral index calculation
index_config = SpectralIndexConfig.create_default()
calculator = HyperspectralIndexCalculator(index_config)
results = calculator.run_analysis_from_config("hyperspectral_data.hdr")
```
### Command-line usage

```bash
# Spectral index calculation
python spectral_index.py input.hdr -i NDVI EVI -o results

# Image filtering
python fliter_method/Smooth_filter.py input.hdr -f mean -k 3 -b 20 -o filtered_output
```
### Advanced configuration

```python
# Custom regression configuration
config = RegressionConfig()
config.data.csv_path = "data.csv"
config.data.label_column = "chlorophyll"
config.data.spectrum_columns = "10:200"  # wavelength range
config.models.tune_hyperparams = True
config.models.model_names = ['ridge', 'lasso', 'xgboost']
config.output.save_models = True

analyzer = RegressionAnalyzer(config)
results = analyzer.run_analysis_from_config()
```
## 📁 Project Structure

```
hyperspectral-toolkit/
├── rgression_method/
│   ├── regression.py              # main regression analysis module
│   └── __init__.py
├── fliter_method/
│   ├── Smooth_filter.py           # image smoothing filters
│   └── morphological_fliter.py    # morphological filtering
├── classfication_method/
│   └── classfication.py           # classification tools
├── cluster_method/
│   └── cluster.py                 # clustering algorithms
├── preprocessing_method/
│   └── Preprocessing.py           # spectral preprocessing
├── spectral_index.py              # spectral index calculator
├── requirements.txt               # dependency list
├── README.md                      # this file
└── examples/                      # usage examples
```
## 🔧 Configuration System

The toolkit uses a hierarchical configuration system designed for GUI integration:

### Data configuration
```python
@dataclass
class DataConfig:
    csv_path: str = ""
    label_column: Union[str, int] = ""
    spectrum_columns: Optional[Union[str, List]] = None
    test_size: float = 0.2
    scale_method: str = 'standard'
```

### Model configuration
```python
@dataclass
class ModelConfig:
    model_names: Optional[List[str]] = None  # None = all models
    tune_hyperparams: bool = True
    tuning_method: str = 'grid'
    cv_folds: int = 5
```

### Training configuration
```python
@dataclass
class TrainingConfig:
    epochs: int = 100
    batch_size: int = 32
    learning_rate: float = 0.001
```
## 🎨 Available Models

### Regression models
- **Linear models**: linear regression, ridge, LASSO, elastic net, Bayesian ridge
- **Ensemble models**: random forest, gradient boosting, AdaBoost
- **Boosting models**: XGBoost, LightGBM
- **Neural networks**: MLP, LSTM, GRU
- **Specialized models**: SVR, Gaussian processes, KNN

### Spectral indices
- **Vegetation indices**: NDVI, EVI, ARVI, SAVI, etc.
- **Water indices**: NDWI, MNDWI, etc.
- **Mineral indices**: various mineral-specific indices
- **Custom indices**: user-defined formulas
## 📊 Output Formats

- **CSV**: tabular results with statistics
- **ENVI**: standard hyperspectral format (.dat + .hdr)
- **Images**: PNG plots and visualizations
- **Models**: trained sklearn models pickled for deployment (see the loading sketch below)
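A minimal sketch of reloading a saved model for inference. The repository's code persists models with `joblib`; the file names and CSV layout below are illustrative assumptions, not the toolkit's actual output paths.

```python
import joblib
import numpy as np

# Load a previously saved sklearn model (path is illustrative)
model = joblib.load("results/xgboost_model.pkl")

# Predict on new spectra: rows = samples, columns = bands (layout assumed)
spectra = np.loadtxt("new_spectra.csv", delimiter=",")
predictions = model.predict(spectra)
```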
BIN  __pycache__/supervize_cluster.cpython-312.pyc  Normal file
Binary file not shown.
1041  classfication_method/classfication.py  Normal file
File diff suppressed because it is too large.
1012  cluster_method/cluster.py  Normal file
File diff suppressed because it is too large.
73  data/roi.xml  Normal file
@@ -0,0 +1,73 @@
<?xml version="1.0" encoding="UTF-8"?>
<RegionsOfInterest version="1.1">
  <Region name="ROI #1" color="255,0,0">
    <GeometryDef>
      <CoordSysStr>none</CoordSysStr>
      <Polygon>
        <Exterior>
          <LinearRing>
            <Coordinates>
              293.49999851 419.74999976 319.75 419.74999976 319.75 442.25 293.49999851 442.25 293.49999851 419.74999976
            </Coordinates>
          </LinearRing>
        </Exterior>
      </Polygon>
    </GeometryDef>
  </Region>
  <Region name="ROI #2" color="0,128,0">
    <GeometryDef>
      <CoordSysStr>none</CoordSysStr>
      <Polygon>
        <Exterior>
          <LinearRing>
            <Coordinates>
              684.74999356 430.99999962 708.5 430.99999962 708.5 448.5 684.74999356 448.5 684.74999356 430.99999962
            </Coordinates>
          </LinearRing>
        </Exterior>
      </Polygon>
    </GeometryDef>
  </Region>
  <Region name="ROI #3" color="0,0,255">
    <GeometryDef>
      <CoordSysStr>none</CoordSysStr>
      <Polygon>
        <Exterior>
          <LinearRing>
            <Coordinates>
              319.74999818 844.74999438 346 844.74999438 346 862.25 319.74999818 862.25 319.74999818 844.74999438
            </Coordinates>
          </LinearRing>
        </Exterior>
      </Polygon>
    </GeometryDef>
  </Region>
  <Region name="ROI #4" color="255,255,0">
    <GeometryDef>
      <CoordSysStr>none</CoordSysStr>
      <Polygon>
        <Exterior>
          <LinearRing>
            <Coordinates>
              672.24999372 854.74999425 697.25 854.74999425 697.25 873.5 672.24999372 873.5 672.24999372 854.74999425
            </Coordinates>
          </LinearRing>
        </Exterior>
      </Polygon>
    </GeometryDef>
  </Region>
  <Region name="ROI #5" color="0,255,255">
    <GeometryDef>
      <CoordSysStr>none</CoordSysStr>
      <Polygon>
        <Exterior>
          <LinearRing>
            <Coordinates>
              573.49999497 800.99999494 589.75 800.99999494 589.75 822.25 573.49999497 822.25 573.49999497 800.99999494
            </Coordinates>
          </LinearRing>
        </Exterior>
      </Polygon>
    </GeometryDef>
  </Region>
</RegionsOfInterest>
678  fliter_method/Smooth_filter.py  Normal file
@@ -0,0 +1,678 @@
import numpy as np
import cv2
import os
from typing import Tuple, Optional, Dict, Union
from scipy import ndimage
import warnings
warnings.filterwarnings('ignore')

try:
    import spectral as spy
    from spectral import envi
    HAS_SPECTRAL = True
except ImportError:
    HAS_SPECTRAL = False
    print("Warning: the 'spectral' package is not installed; only RGB images can be processed")


class HyperspectralImageFilter:
    """Smoothing filters for hyperspectral images

    Usage examples:
        # Create a filter instance
        filter_obj = HyperspectralImageFilter()

        # Process an image - mean filter
        dat_path, hdr_path = filter_obj.process_image(
            input_path='input.hdr',
            output_path='output_mean',
            filter_type='mean',
            kernel_size=5
        )

        # Process an image - Gaussian filter
        dat_path, hdr_path = filter_obj.process_image(
            input_path='input.hdr',
            output_path='output_gaussian',
            filter_type='gaussian',
            kernel_size=3,
            sigma=2.0
        )

        # Process an image - bilateral filter
        dat_path, hdr_path = filter_obj.process_image(
            input_path='input.hdr',
            output_path='output_bilateral',
            filter_type='bilateral',
            kernel_size=7,
            sigma_color=50.0,
            sigma_space=50.0
        )

        # Or call apply_filter directly
        data, header = filter_obj.load_image('input.hdr')
        filtered_data = filter_obj.apply_filter('mean', kernel_size=5)
        dat_path, hdr_path = filter_obj.save_envi('output', filtered_data, 'mean', header)
    """
    def __init__(self):
        self.data = None
        self.header = None
        self.wavelengths = None
        self.is_hyperspectral = False
        self._selected_band_index = None  # index of the selected band, if any
        self._original_dtype = None       # original data type
        self._original_range = None       # original value range

    def load_image(self, image_path: str, band_index: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
        """
        Load an image file

        Parameters:
            image_path: image file path (.hdr/.dat or .jpg/.png/.tif)
            band_index: for hyperspectral images, the band index to process (None means all bands)

        Returns:
            the loaded image data and its metadata
        """
        try:
            # Check the file extension
            file_ext = os.path.splitext(image_path)[1].lower()

            if file_ext in ['.hdr']:
                # ENVI hyperspectral image
                if not HAS_SPECTRAL:
                    raise ImportError("The 'spectral' package is required to handle ENVI files")

                return self._load_envi_image(image_path, band_index)

            elif file_ext in ['.jpg', '.jpeg', '.png', '.tif', '.tiff']:
                # RGB image
                return self._load_rgb_image(image_path)

            else:
                raise ValueError(f"Unsupported file format: {file_ext}")

        except Exception as e:
            raise IOError(f"Failed to load image: {e}")

    def _load_envi_image(self, hdr_path: str, band_index: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
        """Load an ENVI-format hyperspectral image"""
        try:
            # Read the ENVI file
            img = envi.open(hdr_path)
            data = img.load()
            header = dict(img.metadata)

            # Extract wavelength information
            wavelengths = None
            if 'wavelength' in header:
                try:
                    wavelengths = np.array([float(w) for w in header['wavelength']])
                except (ValueError, TypeError):
                    wavelengths = None

            # If a band was requested, select it
            if band_index is not None:
                if band_index < 0 or band_index >= data.shape[2]:
                    raise ValueError(f"Band index {band_index} out of range [0, {data.shape[2]-1}]")

                print(f"Original data shape: {data.shape}")
                # Remember the selected band index for header generation
                self._selected_band_index = band_index

                data = data[:, :, band_index]
                print(f"Data shape after selecting band {band_index}: {data.shape}")

                # If the result has shape (H, W, 1), squeeze it to (H, W)
                if len(data.shape) == 3 and data.shape[2] == 1:
                    data = np.squeeze(data, axis=2)  # drop the last dimension
                    print(f"Squeezed single-band data to 2D: {data.shape}")

                print(f"Processing band {band_index}")
                if wavelengths is not None:
                    print(f"Wavelength: {wavelengths[band_index]:.1f} nm")
                # A specific band was selected, so the data is now 2D
                self.is_hyperspectral = False
            else:
                print("Processing all bands")
                self.is_hyperspectral = True
                self._selected_band_index = None

            # Record the original data type and range
            self._original_dtype = data.dtype
            if data.size > 0:  # make sure the data is not empty
                self._original_range = (data.min(), data.max())
                print(f"Original data range: {self._original_range[0]} - {self._original_range[1]}")

            self.data = data
            self.header = header
            self.wavelengths = wavelengths

            print(f"Loaded ENVI image: shape={data.shape}, dtype={data.dtype}")
            if wavelengths is not None and band_index is not None:
                print(f"Wavelength range: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")

            return data, header

        except Exception as e:
            raise IOError(f"Failed to load ENVI image: {e}")
    def _load_rgb_image(self, image_path: str) -> Tuple[np.ndarray, Dict]:
        """Load an RGB image"""
        try:
            # Read the image with OpenCV
            data = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)

            if data is None:
                raise ValueError(f"Could not read image file: {image_path}")

            # Convert to RGB (OpenCV loads BGR by default)
            if len(data.shape) == 3 and data.shape[2] == 3:
                data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)

            # Record the original data type and range
            self._original_dtype = data.dtype
            if data.size > 0:
                self._original_range = (data.min(), data.max())
                print(f"Original data range: {self._original_range[0]} - {self._original_range[1]}")

            # Create basic metadata
            header = {
                'samples': data.shape[1],
                'lines': data.shape[0],
                'bands': data.shape[2] if len(data.shape) == 3 else 1,
                'data_type': self._get_envi_data_type(data.dtype),
                'interleave': 'bsq',
                'byte_order': 0
            }

            self.data = data
            self.header = header
            self.wavelengths = None
            self.is_hyperspectral = False

            print(f"Loaded RGB image: shape={data.shape}, dtype={data.dtype}")

            return data, header

        except Exception as e:
            raise IOError(f"Failed to load RGB image: {e}")

    def _get_envi_data_type(self, dtype: np.dtype) -> int:
        """Map a NumPy dtype to the ENVI data type code"""
        if dtype == np.uint8:
            return 1
        elif dtype == np.int16:
            return 2
        elif dtype == np.int32:
            return 3
        elif dtype == np.float32:
            return 4
        elif dtype == np.float64:
            return 5
        elif dtype == np.uint16:
            return 12
        elif dtype == np.uint32:
            return 13
        else:
            return 4  # default to float32
    def apply_filter(self, filter_type: str, **kwargs) -> np.ndarray:
        """
        Apply the requested filter

        Parameters:
            filter_type: filter type ('mean', 'median', 'gaussian', 'bilateral')
            **kwargs: per-filter hyperparameters
                - kernel_size: kernel size (odd number, default 3)
                - sigma: standard deviation for the Gaussian filter (default 1.0)
                - sigma_color: color-space sigma for the bilateral filter (default 75.0)
                - sigma_space: spatial sigma for the bilateral filter (default 75.0)

        Returns:
            the filtered image data
        """
        if self.data is None:
            raise ValueError("Load image data first")

        # Read parameters, falling back to defaults
        kernel_size = kwargs.get('kernel_size', 3)
        sigma = kwargs.get('sigma', 1.0)
        sigma_color = kwargs.get('sigma_color', 75.0)
        sigma_space = kwargs.get('sigma_space', 75.0)

        if kernel_size % 2 == 0:
            kernel_size += 1  # force an odd kernel size

        print(f"Applying {filter_type} filter, parameters: kernel_size={kernel_size}")

        if filter_type.lower() == 'mean':
            return self._mean_filter(kernel_size)
        elif filter_type.lower() == 'median':
            return self._median_filter(kernel_size)
        elif filter_type.lower() == 'gaussian':
            print(f"  sigma={sigma}")
            return self._gaussian_filter(kernel_size, sigma)
        elif filter_type.lower() == 'bilateral':
            print(f"  sigma_color={sigma_color}, sigma_space={sigma_space}")
            return self._bilateral_filter(kernel_size, sigma_color, sigma_space)
        else:
            raise ValueError(f"Unsupported filter type: {filter_type}")
    def _mean_filter(self, kernel_size: int) -> np.ndarray:
        """Mean filter"""
        try:
            print(f"Mean filter start - data shape: {self.data.shape}, is_hyperspectral: {self.is_hyperspectral}")

            # Build the averaging kernel
            kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) / (kernel_size * kernel_size)

            # Check data dimensionality
            if len(self.data.shape) == 3 and self.data.shape[2] > 1:
                # Multi-band hyperspectral image - filter each band separately
                filtered_data = np.zeros_like(self.data, dtype=np.float32)
                for band in range(self.data.shape[2]):
                    filtered_data[:, :, band] = ndimage.convolve(self.data[:, :, band].astype(np.float32), kernel)
                print(f"Mean filter done - processed {self.data.shape[2]} bands")
            elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
                # Single-band 3D data (H, W, 1) - squeeze and filter
                data_2d = np.squeeze(self.data, axis=2).astype(np.float32)
                filtered_data = ndimage.convolve(data_2d, kernel)
                # Keep the output 3D for consistency
                filtered_data = filtered_data[:, :, np.newaxis]
                print(f"Mean filter done - single-band 3D data (H,W,1) -> (H,W,1)")
            else:
                # 2D or single-band image
                print(f"Applying 2D mean filter to data of shape {self.data.shape}")
                filtered_data = ndimage.convolve(self.data.astype(np.float32), kernel)
                print(f"2D mean filter done - output shape: {filtered_data.shape}")

            return filtered_data

        except Exception as e:
            raise RuntimeError(f"Mean filter failed: {e}")

    def _median_filter(self, kernel_size: int) -> np.ndarray:
        """Median filter"""
        try:
            # Check data dimensionality
            if len(self.data.shape) == 3 and self.data.shape[2] > 1:
                # Multi-band hyperspectral image - filter each band separately
                filtered_data = np.zeros_like(self.data, dtype=self.data.dtype)
                for band in range(self.data.shape[2]):
                    filtered_data[:, :, band] = ndimage.median_filter(self.data[:, :, band], size=kernel_size)
                print(f"Median filter done - processed {self.data.shape[2]} bands")
            elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
                # Single-band 3D data (H, W, 1) - squeeze and filter
                data_2d = np.squeeze(self.data, axis=2)
                filtered_data = ndimage.median_filter(data_2d, size=kernel_size)
                # Keep the output 3D for consistency
                filtered_data = filtered_data[:, :, np.newaxis]
                print(f"Median filter done - single-band 3D data (H,W,1) -> (H,W,1)")
            else:
                # 2D or single-band image
                filtered_data = ndimage.median_filter(self.data, size=kernel_size)

            return filtered_data

        except Exception as e:
            raise RuntimeError(f"Median filter failed: {e}")
    def _gaussian_filter(self, kernel_size: int, sigma: float) -> np.ndarray:
        """Gaussian filter"""
        try:
            # Check data dimensionality
            if len(self.data.shape) == 3 and self.data.shape[2] > 1:
                # Multi-band hyperspectral image - filter each band separately
                filtered_data = np.zeros_like(self.data, dtype=np.float32)
                for band in range(self.data.shape[2]):
                    filtered_data[:, :, band] = ndimage.gaussian_filter(
                        self.data[:, :, band].astype(np.float32), sigma=sigma
                    )
                print(f"Gaussian filter done (sigma={sigma}) - processed {self.data.shape[2]} bands")
            elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
                # Single-band 3D data (H, W, 1) - squeeze and filter
                data_2d = np.squeeze(self.data, axis=2).astype(np.float32)
                filtered_data = ndimage.gaussian_filter(data_2d, sigma=sigma)
                # Keep the output 3D for consistency
                filtered_data = filtered_data[:, :, np.newaxis]
                print(f"Gaussian filter done (sigma={sigma}) - single-band 3D data (H,W,1) -> (H,W,1)")
            else:
                # 2D or single-band image
                filtered_data = ndimage.gaussian_filter(self.data.astype(np.float32), sigma=sigma)
                print(f"Gaussian filter done (sigma={sigma})")

            return filtered_data

        except Exception as e:
            raise RuntimeError(f"Gaussian filter failed: {e}")
    def _bilateral_filter(self, kernel_size: int, sigma_color: float, sigma_space: float) -> np.ndarray:
        """Bilateral filter"""
        try:
            # Bilateral filtering is designed for RGB or single-band images.
            # For hyperspectral data we apply it band by band.

            if len(self.data.shape) == 3 and self.data.shape[2] > 1 and self.is_hyperspectral:
                # Multi-band hyperspectral image - filter each band separately
                filtered_data = np.zeros_like(self.data, dtype=np.float32)
                for band in range(self.data.shape[2]):
                    band_data = self.data[:, :, band].astype(np.float32)
                    # Normalize to 0-255 for the bilateral filter
                    if band_data.max() > band_data.min():
                        band_norm = ((band_data - band_data.min()) / (band_data.max() - band_data.min()) * 255).astype(np.uint8)
                        filtered_norm = cv2.bilateralFilter(band_norm, kernel_size, sigma_color, sigma_space)
                        # Restore the original range
                        filtered_data[:, :, band] = filtered_norm.astype(np.float32) / 255 * (band_data.max() - band_data.min()) + band_data.min()
                    else:
                        filtered_data[:, :, band] = band_data

                print(f"Bilateral filter done - processed {self.data.shape[2]} bands")

            elif len(self.data.shape) == 3 and self.data.shape[2] == 3:
                # RGB image - use OpenCV's bilateral filter directly
                filtered_data = cv2.bilateralFilter(self.data, kernel_size, sigma_color, sigma_space)
                print("RGB bilateral filter done")

            elif len(self.data.shape) == 3 and self.data.shape[2] == 1:
                # Single-band 3D data (H, W, 1) - squeeze and filter
                data_2d = np.squeeze(self.data, axis=2).astype(np.float32)
                if data_2d.max() > data_2d.min():
                    data_norm = ((data_2d - data_2d.min()) / (data_2d.max() - data_2d.min()) * 255).astype(np.uint8)
                    filtered_norm = cv2.bilateralFilter(data_norm, kernel_size, sigma_color, sigma_space)
                    filtered_data = filtered_norm.astype(np.float32) / 255 * (data_2d.max() - data_2d.min()) + data_2d.min()
                    # Keep the output 3D for consistency
                    filtered_data = filtered_data[:, :, np.newaxis]
                else:
                    filtered_data = self.data.copy()
                print("Single-band bilateral filter done (3D -> 3D)")

            else:
                # 2D single-band image - convert to uint8 for processing
                data_float = self.data.astype(np.float32)
                if data_float.max() > data_float.min():
                    data_norm = ((data_float - data_float.min()) / (data_float.max() - data_float.min()) * 255).astype(np.uint8)
                    filtered_norm = cv2.bilateralFilter(data_norm, kernel_size, sigma_color, sigma_space)
                    filtered_data = filtered_norm.astype(np.float32) / 255 * (data_float.max() - data_float.min()) + data_float.min()
                else:
                    filtered_data = data_float
                print("Single-band bilateral filter done")

            return filtered_data

        except Exception as e:
            raise RuntimeError(f"Bilateral filter failed: {e}")
    def save_envi(self, output_path: str, filtered_data: np.ndarray, filter_type: str,
                  original_header: Dict = None) -> Tuple[str, str]:
        """
        Save the filtered result in ENVI format

        Parameters:
            output_path: output file path (without extension)
            filtered_data: filtered data
            filter_type: filter type, used in the file name
            original_header: header metadata of the original image

        Returns:
            paths of the data file and the header file
        """
        try:
            # Make sure the output directory exists
            output_dir = os.path.dirname(output_path)
            if output_dir and not os.path.exists(output_dir):
                os.makedirs(output_dir)

            # Build the file names
            base_name = os.path.basename(output_path)
            dat_path = f"{output_path}_{filter_type}.dat"
            hdr_path = f"{output_path}_{filter_type}.hdr"

            # Save the data file - keep the type and range consistent with the original data
            print(f"Filtered data range: {filtered_data.min():.2f} - {filtered_data.max():.2f}")
            print(f"Original dtype: {self._original_dtype}, range: {self._original_range}")

            # Decide the output format from the original data type
            if self._original_dtype is not None:
                # If the original data was integer, try to keep an integer format
                if np.issubdtype(self._original_dtype, np.integer):
                    # For integer originals, check that the filtered data is still in a reasonable range
                    if self._original_range is not None:
                        orig_min, orig_max = self._original_range
                        filtered_min, filtered_max = filtered_data.min(), filtered_data.max()

                        # If the filtered range is still close to the original range, keep the original dtype
                        if (filtered_min >= orig_min - 100) and (filtered_max <= orig_max + 100):
                            print(f"Keeping original integer dtype: {self._original_dtype}")
                            filtered_data = np.clip(filtered_data, orig_min, orig_max).astype(self._original_dtype)
                        else:
                            print(f"Data range changed substantially, saving as float32")
                            filtered_data = filtered_data.astype(np.float32)
                    else:
                        filtered_data = filtered_data.astype(np.float32)
                else:
                    # Original data was floating point, keep float32
                    filtered_data = filtered_data.astype(np.float32)
            else:
                # Default to float32
                filtered_data = filtered_data.astype(np.float32)

            print(f"Saving dtype: {filtered_data.dtype}, range: {filtered_data.min():.2f} - {filtered_data.max():.2f}")

            # Make sure the data layout fits the ENVI standard
            if len(filtered_data.shape) == 2:
                # 2D data (H, W) -> 3D data (H, W, 1)
                filtered_data = filtered_data[:, :, np.newaxis]
                print(f"Converted 2D data to 3D for ENVI: {filtered_data.shape}")

            # The header declares BSQ (band sequential), so transpose
            # (rows, cols, bands) to (bands, rows, cols) before writing;
            # writing the (H, W, C) array directly would produce BIP order
            filtered_data.transpose(2, 0, 1).tofile(dat_path)
            print(f"Data file saved: {dat_path}")

            # Write the header file
            self._save_envi_header(hdr_path, filtered_data, filter_type, original_header)
            print(f"Header file saved: {hdr_path}")

            return dat_path, hdr_path

        except Exception as e:
            raise IOError(f"Failed to save ENVI file: {e}")
    def _save_envi_header(self, hdr_path: str, data: np.ndarray, filter_type: str,
                          original_header: Dict = None):
        """Write an ENVI standard header file"""
        try:
            from datetime import datetime

            with open(hdr_path, 'w', encoding='utf-8') as f:
                f.write("ENVI\n")

                # Description - includes the processing time
                current_time = datetime.now().strftime("%a %b %d %H:%M:%S %Y")
                f.write("description = {\n")
                f.write(f"  Convolution Result [{current_time}]\n")
                f.write("}\n")

                # Image dimensions
                if len(data.shape) == 3:
                    samples, lines, bands = data.shape[1], data.shape[0], data.shape[2]
                else:
                    samples, lines, bands = data.shape[1], data.shape[0], 1

                f.write(f"samples = {samples}\n")
                f.write(f"lines = {lines}\n")
                f.write(f"bands = {bands}\n")
                f.write("header offset = 0\n")
                f.write("file type = ENVI Standard\n")

                # Data type - use the type actually written to disk
                if hasattr(self, '_original_dtype') and self._original_dtype is not None:
                    # Prefer the original dtype (when the range allowed keeping it)
                    data_type = self._get_envi_data_type(self._original_dtype)
                else:
                    data_type = self._get_envi_data_type(data.dtype)
                f.write(f"data type = {data_type}\n")

                # Interleave
                f.write("interleave = bsq\n")

                # Sensor type
                if original_header and 'sensor type' in original_header:
                    f.write(f"sensor type = {original_header['sensor type']}\n")
                else:
                    f.write("sensor type = Unknown\n")

                # Byte order
                f.write("byte order = 0\n")

                # Wavelength information
                if self.wavelengths is not None and len(self.wavelengths) > 0:
                    f.write("wavelength units = Nanometers\n")
                    if bands == 1 and hasattr(self, '_selected_band_index') and self._selected_band_index is not None:
                        # Single band - use the wavelength of the selected band
                        band_idx = min(self._selected_band_index, len(self.wavelengths) - 1)
                        f.write(f"wavelength = {{\n")
                        f.write(f"  {self.wavelengths[band_idx]:.6f}\n")
                        f.write("}\n")
                    else:
                        # Multiple bands - write all wavelengths
                        f.write("wavelength = {\n")
                        wavelength_str = ", ".join([f"{w:.6f}" for w in self.wavelengths[:bands]])
                        f.write(f"  {wavelength_str}\n")
                        f.write("}\n")

                # Reflectance scale factor
                f.write("reflectance scale factor = 10000.000000\n")

                # Copy other important parameters from the original header
                if original_header:
                    # Gain
                    if 'gain' in original_header:
                        f.write(f"gain = {original_header['gain']}\n")

                    # Binning information
                    binning_keys = ['sample binning', 'spectral binning', 'line binning']
                    for key in binning_keys:
                        if key in original_header:
                            f.write(f"{key} = {original_header[key]}\n")

                    # Shutter and frame rate
                    if 'shutter' in original_header:
                        f.write(f"shutter = {original_header['shutter']}\n")
                    if 'framerate' in original_header:
                        f.write(f"framerate = {original_header['framerate']}\n")

                    # Imager serial number
                    if 'imager serial number' in original_header:
                        f.write(f"imager serial number = {original_header['imager serial number']}\n")

                    # Rotation matrix
                    if 'rotation' in original_header:
                        rotation = original_header['rotation']
                        f.write(f"rotation = {rotation}\n")

                    # Label
                    if 'label' in original_header:
                        f.write(f"label = {original_header['label']}\n")

                    # Processing history
                    if 'history' in original_header:
                        f.write(f"history = {original_header['history']}\n")

        except Exception as e:
            raise IOError(f"Failed to write header file: {e}")
    def process_image(self, input_path: str, output_path: str, filter_type: str,
                      band_index: Optional[int] = None, **kwargs) -> Tuple[str, str]:
        """
        Complete image-processing pipeline

        Parameters:
            input_path: input image path
            output_path: output path (without extension)
            filter_type: filter type
            band_index: band index (optional)
            **kwargs: per-filter hyperparameters
                - kernel_size: kernel size (odd number, default 3)
                - sigma: standard deviation for the Gaussian filter (default 1.0)
                - sigma_color: color-space sigma for the bilateral filter (default 75.0)
                - sigma_space: spatial sigma for the bilateral filter (default 75.0)

        Returns:
            paths of the data file and the header file
        """
        print("=" * 50)
        print("Hyperspectral image smoothing filter")
        print("=" * 50)

        # Load the image
        print(f"Loading image: {input_path}")
        data, header = self.load_image(input_path, band_index)

        # Apply the filter
        print(f"Applying {filter_type} filter...")
        filtered_data = self.apply_filter(filter_type, **kwargs)

        # Save the result
        print(f"Saving result to: {output_path}")
        dat_path, hdr_path = self.save_envi(output_path, filtered_data, filter_type, header)

        print("=" * 50)
        print("Done!")
        print(f"Data file: {dat_path}")
        print(f"Header file: {hdr_path}")
        print("=" * 50)

        return dat_path, hdr_path
def main():
    """Main function - command-line interface (currently uses hard-coded paths)"""

    input_path = r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip.hdr"
    output_path = r"E:\code\spectronon\single_classsfication\fliter"
    filter_type = 'median'
    band = 20
    kernel_size = 3
    sigma = 1.0
    sigma_color = 75.0
    sigma_space = 75.0

    # Build the hyperparameter dictionary
    kwargs = {
        'kernel_size': kernel_size,
        'sigma': sigma,
        'sigma_color': sigma_color,
        'sigma_space': sigma_space
    }

    try:
        # Create the filter instance
        filter_obj = HyperspectralImageFilter()

        # Process the image
        filter_obj.process_image(
            input_path=input_path,
            output_path=output_path,
            filter_type=filter_type,
            band_index=band,
            **kwargs
        )

    except Exception as e:
        print(f"Error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    main()  # run the original debug entry point
    # test_data_preservation()  # test that the data format is preserved
758  fliter_method/morphological_fliter.py  Normal file
@@ -0,0 +1,758 @@
import numpy as np
import os
import warnings
from typing import Tuple, List, Dict, Optional, Union
from pathlib import Path
import spectral as spy
from spectral import envi
import matplotlib.pyplot as plt
from skimage import morphology
from scipy import ndimage
import struct

warnings.filterwarnings('ignore')


class HyperspectralMorphologyProcessor:
    """Morphological processing for hyperspectral images"""

    def __init__(self):
        self.data = None
        self.header = None
        self.wavelengths = None
        self.shape = None
        self.dtype = None
    def load_hyperspectral_image(self, file_path: str, format_type: str = None) -> Tuple[np.ndarray, Dict]:
        """
        Load a hyperspectral image (supports bil, bip, bsq, and dat formats)

        Parameters:
        -----------
        file_path : str
            image file path (may be .hdr, .bil, .bip, .bsq, or .dat)
        format_type : str, optional
            format type ('bil', 'bip', 'bsq', 'dat'); auto-detected if None

        Returns:
        --------
        data : np.ndarray
            hyperspectral data cube (rows, cols, bands)
        header : Dict
            header metadata
        """
        try:
            # Auto-detect the format
            if format_type is None:
                format_type = self._detect_format(file_path)

            # If this is an .hdr file, load it directly with spectral
            if file_path.lower().endswith('.hdr'):
                img = envi.open(file_path)
                data = img.load()
                header = dict(img.metadata)

            # Otherwise, look for the matching .hdr file
            else:
                # Candidate hdr files
                hdr_candidates = [
                    file_path.rsplit('.', 1)[0] + '.hdr',
                    file_path + '.hdr',
                    file_path[:-4] + '.hdr' if len(file_path) > 4 else None
                ]

                hdr_file = None
                for candidate in hdr_candidates:
                    if candidate and os.path.exists(candidate):
                        hdr_file = candidate
                        break

                if hdr_file is None:
                    raise FileNotFoundError(f"No header file (.hdr) found for {file_path}")

                # Load the image
                img = envi.open(hdr_file)
                data = img.load()
                header = dict(img.metadata)

            # Rearrange the data layout according to the format
            if format_type.lower() == 'bil':
                # BIL: band interleaved by line
                data = self._convert_to_bil(data)
            elif format_type.lower() == 'bip':
                # BIP: band interleaved by pixel (spectral's default)
                pass  # spectral already returns BIP
            elif format_type.lower() == 'bsq':
                # BSQ: band sequential
                data = self._convert_to_bsq(data)

            self.data = data
            self.header = header
            self.shape = data.shape
            self.dtype = data.dtype

            # Extract wavelength information
            if 'wavelength' in header:
                self.wavelengths = np.array([float(w) for w in header['wavelength']])
            else:
                self.wavelengths = None

            print(f"Loaded image: shape={data.shape}, dtype={data.dtype}, format={format_type}")
            if self.wavelengths is not None:
                print(f"Wavelength range: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")

            return data, header

        except Exception as e:
            raise IOError(f"Failed to load image: {e}")
    def _detect_format(self, file_path: str) -> str:
        """Auto-detect the image format"""
        ext = Path(file_path).suffix.lower()

        if ext == '.hdr':
            # Read the hdr file to detect the data interleave
            with open(file_path, 'r') as f:
                content = f.read()
                if 'interleave = bil' in content.lower():
                    return 'bil'
                elif 'interleave = bsq' in content.lower():
                    return 'bsq'
                elif 'interleave = bip' in content.lower():
                    return 'bip'
                else:
                    # Assume BIP by default
                    return 'bip'
        elif ext in ['.bil', '.bip', '.bsq', '.dat']:
            # Infer from the extension
            return ext[1:]  # strip the dot
        else:
            # Assume BIP by default
            return 'bip'

    def _convert_to_bil(self, data: np.ndarray) -> np.ndarray:
        """Rearrange the data into BIL layout (rows, bands, cols)"""
        rows, cols, bands = data.shape
        bil_data = np.zeros((rows, bands, cols), dtype=data.dtype)
        for i in range(rows):
            for b in range(bands):
                bil_data[i, b, :] = data[i, :, b]
        return bil_data

    def _convert_to_bsq(self, data: np.ndarray) -> np.ndarray:
        """Rearrange the data into BSQ layout (bands, rows, cols)"""
        rows, cols, bands = data.shape
        bsq_data = np.zeros((bands, rows, cols), dtype=data.dtype)
        for b in range(bands):
            bsq_data[b, :, :] = data[:, :, b]
        return bsq_data
    def extract_band(self, band_index: int) -> np.ndarray:
        """
        Extract a single band

        Parameters:
        -----------
        band_index : int
            band index (0-based)

        Returns:
        --------
        band_data : np.ndarray
            single-band image (rows, cols)
        """
        if self.data is None:
            raise ValueError("Load image data first")

        # Bounds check only applies to 3D cubes; 2D data has a single band
        if len(self.shape) == 3 and (band_index < 0 or band_index >= self.shape[2]):
            raise ValueError(f"Band index {band_index} out of range [0, {self.shape[2]-1}]")

        # Extract the band according to the data layout
        if len(self.data.shape) == 3:
            # Standard layout (rows, cols, bands)
            # Use squeeze() to make sure singleton dimensions are removed
            band_data = np.squeeze(self.data[:, :, band_index])
            print(f"Shape before extraction: {self.data.shape}, band_index: {band_index}")
            print(f"Initial shape after extraction: {band_data.shape}")

            # The result must be 2D
            if len(band_data.shape) != 2:
                raise ValueError(f"Could not extract 2D band data, final shape: {band_data.shape}")

        elif len(self.data.shape) == 2:
            # Already a single band
            band_data = self.data
        else:
            raise ValueError(f"Unsupported data shape: {self.data.shape}")

        print(f"Extracted band {band_index}: final shape={band_data.shape}")

        return band_data
    def apply_morphology_operation(self, band_data: np.ndarray,
                                   operation: str = 'dilation',
                                   se_shape: str = 'disk',
                                   se_size: int = 3,
                                   **kwargs) -> np.ndarray:
        """
        Apply a morphological operation

        Parameters:
        -----------
        band_data : np.ndarray
            single-band image data (2D)
        operation : str
            morphological operation type:
            - 'dilation': dilation
            - 'erosion': erosion
            - 'opening': opening
            - 'closing': closing
            - 'gradient': morphological gradient (dilation - erosion)
            - 'tophat': top-hat transform (original - opening)
            - 'bottomhat': bottom-hat transform (closing - original)
            - 'reconstruction': morphological reconstruction
        se_shape : str
            structuring element shape: 'disk', 'square', 'rectangle', 'diamond'
        se_size : int
            structuring element size

        Returns:
        --------
        result : np.ndarray
            the processed image
        """
        # The input must be 2D
        if len(band_data.shape) != 2:
            raise ValueError(f"Input data must be 2D, current shape: {band_data.shape}")

        # Build the structuring element
        selem = self._create_structuring_element(se_shape, se_size)

        # Convert to float for numerical accuracy
        data_float = band_data.astype(np.float32)

        # Apply the morphological operation
        if operation == 'dilation':
            result = ndimage.grey_dilation(data_float, footprint=selem)

        elif operation == 'erosion':
            result = ndimage.grey_erosion(data_float, footprint=selem)

        elif operation == 'opening':
            # Opening: erosion followed by dilation
            eroded = ndimage.grey_erosion(data_float, footprint=selem)
            result = ndimage.grey_dilation(eroded, footprint=selem)

        elif operation == 'closing':
            # Closing: dilation followed by erosion
            dilated = ndimage.grey_dilation(data_float, footprint=selem)
            result = ndimage.grey_erosion(dilated, footprint=selem)

        elif operation == 'gradient':
            # Morphological gradient: dilation - erosion
            dilated = ndimage.grey_dilation(data_float, footprint=selem)
            eroded = ndimage.grey_erosion(data_float, footprint=selem)
            result = dilated - eroded

        elif operation == 'tophat':
            # Top-hat transform: original - opening
            eroded = ndimage.grey_erosion(data_float, footprint=selem)
            opened = ndimage.grey_dilation(eroded, footprint=selem)
            result = data_float - opened

        elif operation == 'bottomhat':
            # Bottom-hat transform: closing - original
            dilated = ndimage.grey_dilation(data_float, footprint=selem)
            closed = ndimage.grey_erosion(dilated, footprint=selem)
            result = closed - data_float

        elif operation == 'reconstruction':
            # Morphological reconstruction (reconstruction by dilation)
            # Use a marker image (here, the eroded image serves as the marker)
            marker = ndimage.grey_erosion(data_float, footprint=selem)
            result = self._morphological_reconstruction(marker, data_float, selem)

        else:
            raise ValueError(f"Unsupported morphological operation: {operation}")

        # Convert back to the original data type
        if operation in ['gradient', 'tophat', 'bottomhat']:
            # These operations can produce negative values; keep float
            return result
        else:
            return self._convert_to_original_type(result, band_data.dtype)
    def _create_structuring_element(self, shape: str, size: int) -> np.ndarray:
        """Build the structuring element"""
        if shape == 'disk':
            # Circular structuring element
            radius = size // 2
            y, x = np.ogrid[-radius:radius+1, -radius:radius+1]
            mask = x*x + y*y <= radius*radius
            selem = np.zeros((2*radius+1, 2*radius+1), dtype=bool)
            selem[mask] = True
        elif shape == 'square':
            # Square structuring element
            selem = np.ones((size, size), dtype=bool)
        elif shape == 'rectangle':
            # Rectangular structuring element
            selem = np.ones((size, size*2), dtype=bool)
        elif shape == 'diamond':
            # Diamond structuring element
            selem = morphology.diamond(size)
        else:
            raise ValueError(f"Unsupported structuring element shape: {shape}")

        return selem
    def _morphological_reconstruction(self, marker: np.ndarray, mask: np.ndarray,
                                      selem: np.ndarray) -> np.ndarray:
        """
        Morphological reconstruction (by dilation)

        Parameters:
        -----------
        marker : np.ndarray
            marker image
        mask : np.ndarray
            mask image
        selem : np.ndarray
            structuring element

        Returns:
        --------
        reconstructed : np.ndarray
            the reconstructed image
        """
        # The marker must not exceed the mask
        marker = np.minimum(marker, mask)

        # Iterate dilation until convergence
        prev_marker = np.zeros_like(marker)
        while not np.array_equal(marker, prev_marker):
            prev_marker = marker.copy()
            # Geodesic dilation: dilate, then clamp under the mask
            dilated = ndimage.grey_dilation(marker, footprint=selem)
            marker = np.minimum(dilated, mask)

        return marker
    def _convert_to_original_type(self, data: np.ndarray, original_dtype: np.dtype) -> np.ndarray:
        """Convert the data back to its original dtype"""
        if np.issubdtype(original_dtype, np.integer):
            # For integer types, clip to the representable range first
            data = np.clip(data, np.iinfo(original_dtype).min, np.iinfo(original_dtype).max)
            return data.astype(original_dtype)
        else:
            # For floating-point types, cast directly
            return data.astype(original_dtype)
    def apply_to_multiple_bands(self, band_indices: List[int], operation: str,
                                se_shape: str = 'disk', se_size: int = 3) -> Dict[int, np.ndarray]:
        """
        Apply a morphological operation to several bands

        Parameters:
        -----------
        band_indices : List[int]
            list of band indices
        operation : str
            morphological operation type
        se_shape : str
            structuring element shape
        se_size : int
            structuring element size

        Returns:
        --------
        results : Dict[int, np.ndarray]
            the result for each band
        """
        results = {}

        for band_idx in band_indices:
            print(f"Processing band {band_idx}...")
            band_data = self.extract_band(band_idx)
            # extract_band should now always return 2D data, but check to be safe
            if len(band_data.shape) != 2:
                raise ValueError(f"Band {band_idx} data is not 2D, shape: {band_data.shape}")
            result = self.apply_morphology_operation(band_data, operation, se_shape, se_size)
            results[band_idx] = result

        return results
    def save_as_envi(self, data: np.ndarray, output_path: str,
                     description: str = "Morphology result") -> None:
        """
        Save as ENVI-format dat and hdr files

        Parameters:
        -----------
        data : np.ndarray
            data to save (single-band 2D array)
        output_path : str
            output file path (without extension)
        description : str
            image description
        """
        try:
            # Make sure the output directory exists
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # Make sure the output file has a .dat extension
            if not output_path.lower().endswith('.dat'):
                dat_path = output_path + '.dat'
            else:
                dat_path = output_path

            # Save the binary dat file
            data.tofile(dat_path)

            # Create the matching hdr file
            hdr_path = dat_path.rsplit('.', 1)[0] + '.hdr'
            self._create_morphology_hdr_file(hdr_path, data.shape, description, data.dtype)

            print(f"ENVI output saved:")
            print(f"  Data file: {dat_path}")
            print(f"  Header file: {hdr_path}")
            print(f"  Data shape: {data.shape}, dtype: {data.dtype}")

        except Exception as e:
            raise IOError(f"Failed to save ENVI file: {e}")
    def _create_morphology_hdr_file(self, hdr_path: str, data_shape: Tuple,
                                    description: str, data_dtype: np.dtype = None) -> None:
        """
        Create the ENVI header file for a morphology result

        Parameters:
        -----------
        hdr_path : str
            header file path
        data_shape : Tuple
            data shape
        description : str
            image description
        data_dtype : np.dtype, optional
            data type; a default is used if None
        """
        try:
            # Take basic parameters from the original header, if available
            if self.header is not None:
                # Copy the key fields of the original header
                samples = self.header.get('samples', data_shape[1] if len(data_shape) > 1 else data_shape[0])
                lines = self.header.get('lines', data_shape[0] if len(data_shape) > 1 else 1)
                bands = 1  # morphology results are single-band
                interleave = 'bsq'

                # Map the dtype to the ENVI data type code
                if data_dtype is not None:
                    dtype = data_dtype
                else:
                    dtype = np.float32  # default type

                if np.issubdtype(dtype, np.float32):
                    data_type = 4   # float32
                elif np.issubdtype(dtype, np.float64):
                    data_type = 5   # float64
                elif np.issubdtype(dtype, np.int32):
                    data_type = 3   # int32
                elif np.issubdtype(dtype, np.int16):
                    data_type = 2   # int16
                elif np.issubdtype(dtype, np.uint16):
                    data_type = 12  # uint16
                else:
                    data_type = 4   # default float32

                byte_order = self.header.get('byte order', 0)
                wavelength_units = self.header.get('wavelength units', 'nm')
            else:
                # Defaults (for 2D morphology results)
                if len(data_shape) == 2:
                    lines, samples = data_shape
                else:
                    lines = data_shape[0]
                    samples = data_shape[1] if len(data_shape) > 1 else data_shape[0]
                bands = 1
                interleave = 'bsq'
                data_type = 4  # float32 (morphology usually produces floats)
                byte_order = 0
                wavelength_units = 'Unknown'

            # Write the hdr file
            with open(hdr_path, 'w') as f:
                f.write("ENVI\n")
                f.write("description = {\n")
                f.write(f"  {description}\n")
                f.write("  Single-band morphology result\n")
                f.write("}\n")
                f.write(f"samples = {samples}\n")
                f.write(f"lines = {lines}\n")
                f.write(f"bands = {bands}\n")
                f.write(f"header offset = 0\n")
                f.write(f"file type = ENVI Standard\n")
                f.write(f"data type = {data_type}\n")
                f.write(f"interleave = {interleave}\n")
                f.write(f"byte order = {byte_order}\n")

                # If the source had wavelength info, add a placeholder entry
                if self.header and 'wavelength' in self.header:
                    f.write("wavelength = {\n")
                    f.write("  Morphology result\n")
                    f.write("}\n")
                    f.write(f"wavelength units = {wavelength_units}\n")

        except Exception as e:
            raise IOError(f"Failed to create HDR file: {e}")
    def visualize_results(self, original_band: np.ndarray,
                          processed_band: np.ndarray,
                          operation_name: str,
                          save_path: Optional[str] = None) -> None:
        """
        Visualize the original and processed bands

        Parameters:
        -----------
        original_band : np.ndarray
            original band data
        processed_band : np.ndarray
            processed band data
        operation_name : str
            operation name
        save_path : str, optional
            path for saving the figure
        """
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Original image
        im1 = axes[0].imshow(original_band, cmap='gray')
        axes[0].set_title('Original band')
        axes[0].axis('off')
        plt.colorbar(im1, ax=axes[0], shrink=0.8)

        # Processed image
        im2 = axes[1].imshow(processed_band, cmap='gray')
        axes[1].set_title(f'{operation_name} result')
        axes[1].axis('off')
        plt.colorbar(im2, ax=axes[1], shrink=0.8)

        # Difference image
        diff = processed_band.astype(np.float32) - original_band.astype(np.float32)
        im3 = axes[2].imshow(diff, cmap='RdBu_r', vmin=-np.abs(diff).max(), vmax=np.abs(diff).max())
        axes[2].set_title('Difference (processed - original)')
        axes[2].axis('off')
        plt.colorbar(im3, ax=axes[2], shrink=0.8)

        plt.suptitle(f'Morphological operation: {operation_name}', fontsize=16)
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Visualization saved to: {save_path}")

        plt.show()
def main():
    """Main function: hyperspectral morphological processing CLI"""
    import argparse

    # Supported morphological operations
    morph_operations = [
        'dilation', 'erosion', 'opening', 'closing',
        'gradient', 'tophat', 'bottomhat', 'reconstruction'
    ]

    # Supported structuring element shapes
    se_shapes = ['disk', 'square', 'rectangle', 'diamond']

    parser = argparse.ArgumentParser(description='Hyperspectral image morphology tool')
    parser.add_argument('input_file', help='input hyperspectral image file path')
    parser.add_argument('--format', '-f', default='auto',
                        choices=['auto', 'bil', 'bip', 'bsq', 'dat'],
                        help='image format (default: auto)')
    parser.add_argument('--band', '-b', type=int, default=0,
                        help='band index to process (default: 0)')
    parser.add_argument('--bands', '-B', type=str, default=None,
                        help='multiple band indices, comma-separated, e.g. "0,10,20"')
    parser.add_argument('--operation', '-o', default='dilation',
                        choices=morph_operations,
                        help='morphological operation type (default: dilation)')
    parser.add_argument('--se_shape', '-s', default='disk',
                        choices=se_shapes,
                        help='structuring element shape (default: disk)')
    parser.add_argument('--se_size', '-S', type=int, default=3,
                        help='structuring element size (default: 3)')
    parser.add_argument('--output_dir', '-d', default='output',
                        help='output directory (default: output)')
    parser.add_argument('--output_name', '-n', default='morphology_result',
                        help='output file name (without extension) (default: morphology_result)')
    parser.add_argument('--visualize', '-v', action='store_true',
                        help='generate visualization output')

    args = parser.parse_args()

    try:
        # Initialize the processor
        processor = HyperspectralMorphologyProcessor()

        # Resolve the format
        format_type = None if args.format == 'auto' else args.format

        # Load the image
        print(f"Loading image: {args.input_file}")
        data, header = processor.load_hyperspectral_image(args.input_file, format_type)

        # Decide which bands to process
        if args.bands:
            band_indices = [int(b.strip()) for b in args.bands.split(',')]
            print(f"Processing bands: {band_indices}")
            single_band = False
        else:
            band_indices = [args.band]
            print(f"Processing band: {args.band}")
            single_band = True

        # Apply the morphological operation
        if single_band:
            # Single-band processing
            band_data = processor.extract_band(args.band)
            print(f"Band data shape in main: {band_data.shape}")
            result = processor.apply_morphology_operation(
                band_data, args.operation, args.se_shape, args.se_size
            )

            # Save the result
            output_path = os.path.join(args.output_dir, f"{args.output_name}_band{args.band}")
            processor.save_as_envi(result, output_path,
                                   f"{args.operation.capitalize()} result - band {args.band}")

            # Visualization
            if args.visualize:
                vis_path = os.path.join(args.output_dir, f"visualization_band{args.band}.png")
                processor.visualize_results(band_data, result, args.operation, vis_path)

            # Print statistics
            print(f"\n=== Statistics ===")
            print(f"Original band {args.band}: min={band_data.min():.4f}, max={band_data.max():.4f}, "
                  f"mean={band_data.mean():.4f}")
            print(f"Processed: min={result.min():.4f}, max={result.max():.4f}, "
                  f"mean={result.mean():.4f}")

        else:
            # Multi-band processing
            results = processor.apply_to_multiple_bands(
                band_indices, args.operation, args.se_shape, args.se_size
            )

            # Save the result for each band
            for band_idx, result in results.items():
                output_path = os.path.join(args.output_dir, f"{args.output_name}_band{band_idx}")
                processor.save_as_envi(result, output_path,
                                       f"{args.operation.capitalize()} result - band {band_idx}")

            # For a small number of bands, build a combined visualization
            if len(band_indices) <= 4 and args.visualize:
                fig, axes = plt.subplots(2, 2, figsize=(12, 10))
                axes = axes.flatten()

                for idx, band_idx in enumerate(band_indices[:4]):
                    original_band = processor.extract_band(band_idx)
                    processed_band = results[band_idx]

                    axes[idx].imshow(processed_band, cmap='gray')
                    axes[idx].set_title(f'Band {band_idx} - {args.operation}')
                    axes[idx].axis('off')

                plt.suptitle(f'Multi-band morphology: {args.operation}', fontsize=16)
                plt.tight_layout()

                vis_path = os.path.join(args.output_dir, "multi_band_visualization.png")
                plt.savefig(vis_path, dpi=300, bbox_inches='tight')
                print(f"Multi-band visualization saved to: {vis_path}")
                plt.show()

        print(f"\n✓ Morphological processing complete!")
        print(f"Output directory: {args.output_dir}")

        return 0

    except Exception as e:
        print(f"✗ Processing failed: {e}")
        import traceback
        traceback.print_exc()
        return 1
# 使用示例函数
|
||||
def run_example():
|
||||
"""运行示例"""
|
||||
# 示例1:对单个波段进行膨胀操作
|
||||
processor = HyperspectralMorphologyProcessor()
|
||||
|
||||
# 加载图像(请替换为您的实际文件路径)
|
||||
input_file = "path/to/your/hyperspectral_image.hdr" # 或 .bil, .bip, .bsq, .dat
|
||||
data, header = processor.load_hyperspectral_image(input_file, format_type='bip')
|
||||
|
||||
# 提取波段(例如,提取第50个波段)
|
||||
band_idx = 50
|
||||
band_data = processor.extract_band(band_idx)
|
||||
|
||||
# 应用不同的形态学操作
|
||||
operations = ['dilation', 'erosion', 'opening', 'closing',
|
||||
'gradient', 'tophat', 'bottomhat', 'reconstruction']
|
||||
|
||||
results = {}
|
||||
for operation in operations:
|
||||
print(f"\n应用 {operation} 操作...")
|
||||
try:
|
||||
result = processor.apply_morphology_operation(
|
||||
band_data, operation, se_shape='disk', se_size=3
|
||||
)
|
||||
results[operation] = result
|
||||
|
||||
# 保存结果
|
||||
output_path = f"output/{operation}_band{band_idx}"
|
||||
processor.save_as_envi(result, output_path,
|
||||
f"{operation.capitalize()} 处理结果")
|
||||
|
||||
# 可视化
|
||||
processor.visualize_results(band_data, result, operation,
|
||||
save_path=f"output/visualization_{operation}.png")
|
||||
|
||||
except Exception as e:
|
||||
print(f"操作 {operation} 失败: {e}")
|
||||
|
||||
return results
|
||||
if __name__ == '__main__':
|
||||
# 示例1:对单个波段进行膨胀操作
|
||||
processor = HyperspectralMorphologyProcessor()
|
||||
|
||||
# 加载图像(请替换为您的实际文件路径)
|
||||
input_file = r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip.hdr" # 或 .bil, .bip, .bsq, .dat
|
||||
data, header = processor.load_hyperspectral_image(input_file, format_type='bip')
|
||||
|
||||
# 提取波段(例如,提取第50个波段)
|
||||
band_idx = 50
|
||||
band_data = processor.extract_band(band_idx)
|
||||
|
||||
# 应用不同的形态学操作
|
||||
operations = ['dilation', 'erosion', 'opening', 'closing',
|
||||
'gradient', 'tophat', 'bottomhat', 'reconstruction']
|
||||
|
||||
results = {}
|
||||
for operation in operations:
|
||||
print(f"\n应用 {operation} 操作...")
|
||||
try:
|
||||
result = processor.apply_morphology_operation(
|
||||
band_data, operation, se_shape='disk', se_size=3
|
||||
)
|
||||
results[operation] = result
|
||||
|
||||
# 保存结果
|
||||
output_path = f"output/{operation}_band{band_idx}"
|
||||
processor.save_as_envi(result, output_path,
|
||||
f"{operation.capitalize()} 处理结果")
|
||||
except Exception as e:
|
||||
print(f"操作 {operation} 失败: {e}")
608
preprocessing_method/Preprocessing.py
Normal file
@ -0,0 +1,608 @@
import numpy as np
from scipy import signal
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import pandas as pd
import pywt
from copy import deepcopy
import joblib  # 用于保存和加载模型
import os
from pathlib import Path

try:
    import spectral
    SPECTRAL_AVAILABLE = True
except ImportError:
    SPECTRAL_AVAILABLE = False
    print("警告: spectral库不可用,将使用内置方法读取数据")


class HyperspectralPreprocessor:
    """
    高光谱数据预处理器

    支持的文件格式:
    - CSV文件: 需要指定光谱数据的起始列名
    - ENVI格式: .bil, .bsq, .bip, .dat + .hdr文件 (需要spectral库)
    """

    def __init__(self):
        self.data = None
        self.wavelengths = None
        self.data_shape = None
        self.input_format = None  # 'csv' 或 'envi'

    def load_data(self, file_path, spectral_start_col=None):
        """
        加载高光谱数据文件

        参数:
            file_path: 文件路径
            spectral_start_col: 对于CSV文件,光谱数据起始列名
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()

        if suffix == '.csv':
            if spectral_start_col is None:
                raise ValueError("对于CSV文件,必须指定spectral_start_col参数")
            self.input_format = 'csv'
            return self._load_csv_data(file_path, spectral_start_col)
        else:
            # ENVI格式
            if SPECTRAL_AVAILABLE and suffix in ['.bil', '.bsq', '.bip', '.dat', '.hdr']:
                self.input_format = 'envi'
                return self._load_envi_data(file_path)
            else:
                raise ValueError(f"不支持的文件格式: {suffix}。请使用CSV或ENVI格式文件,并确保已安装spectral库。")

    def _load_csv_data(self, file_path, spectral_start_col):
        """加载CSV格式的高光谱数据"""
        df = pd.read_csv(file_path)

        # 找到光谱起始列的索引
        if spectral_start_col not in df.columns:
            raise ValueError(f"列名 '{spectral_start_col}' 不存在于CSV文件中")

        start_idx = df.columns.get_loc(spectral_start_col)

        # 提取光谱数据
        spectral_data = df.iloc[:, start_idx:].values

        # 提取波长信息(列名)
        self.wavelengths = df.columns[start_idx:].astype(float).values

        # 保存数据
        self.data = spectral_data
        self.data_shape = self.data.shape

        print(f"CSV数据加载成功: {self.data_shape}")
        print(f"波段数量: {len(self.wavelengths)}")
        if len(self.wavelengths) > 0:
            print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")

        return self.data

    def _load_envi_data(self, file_path):
        """加载ENVI格式的高光谱数据"""
        original_path = Path(file_path)
        current_path = original_path

        # 如果是.hdr文件,找到对应的数据文件
        if current_path.suffix.lower() == '.hdr':
            data_file = current_path.with_suffix('')
            if not data_file.exists():
                for ext in ['.dat', '.bil', '.bsq', '.bip']:
                    candidate = current_path.with_suffix(ext)
                    if candidate.exists():
                        data_file = candidate
                        break
            current_path = data_file

        print(f"使用spectral库加载: {current_path}")

        # 使用spectral打开图像(open_image需要字符串形式的.hdr路径)
        img = spectral.open_image(str(file_path))

        # 读取所有波段的数据
        self.data = img.load()

        # 获取波长信息
        if hasattr(img, 'metadata') and 'wavelength' in img.metadata:
            wavelength_data = img.metadata['wavelength']
            try:
                if isinstance(wavelength_data, str):
                    wavelength_data = wavelength_data.strip('{}[]')
                    if ',' in wavelength_data:
                        self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split(',') if w.strip()])
                    else:
                        self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split() if w.strip()])
                elif isinstance(wavelength_data, list):
                    self.wavelengths = np.array([float(w) for w in wavelength_data])
                else:
                    self.wavelengths = np.array(wavelength_data, dtype=float)

                print(f"spectral解析波长: {len(self.wavelengths)} 个波段")
                print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
            except Exception as e:
                print(f"波长解析失败,使用默认值: {e}")
                self.wavelengths = np.arange(self.data.shape[2])
        else:
            self.wavelengths = np.arange(self.data.shape[2])
            print(f"警告: spectral未找到波长信息,使用默认值: 0-{self.data.shape[2]-1}")

        self.data_shape = self.data.shape

        print(f"spectral数据加载成功: {self.data_shape}")
        print(f"数据类型: {self.data.dtype}")
        print(f"值范围: [{self.data.min():.3f}, {self.data.max():.3f}]")

        return self.data

    def save_data(self, output_path, data=None, wavelengths=None):
        """
        保存预处理后的数据

        参数:
            output_path: 输出文件路径
            data: 要保存的数据 (如果为None,使用self.data)
            wavelengths: 波长信息 (如果为None,使用self.wavelengths)
        """
        if data is None:
            data = self.data
        if wavelengths is None:
            wavelengths = self.wavelengths

        output_path = Path(output_path)
        suffix = output_path.suffix.lower()

        if suffix == '.csv':
            self._save_csv_data(output_path, data, wavelengths)
        else:
            # ENVI格式
            if self.input_format == 'envi' or suffix in ['.bil', '.bsq', '.bip', '.dat']:
                self._save_envi_data(output_path, data, wavelengths)
            else:
                raise ValueError(f"不支持的输出格式: {suffix}")

    def _save_csv_data(self, output_path, data, wavelengths):
        """保存为CSV格式"""
        if wavelengths is not None:
            # 创建DataFrame,列名使用波长
            df = pd.DataFrame(data, columns=[f"{w:.1f}" for w in wavelengths])
        else:
            df = pd.DataFrame(data)

        df.to_csv(output_path, index=False)
        print(f"已保存CSV文件: {output_path}")

    def _save_envi_data(self, output_path, data, wavelengths):
        """保存为ENVI格式"""
        data_file = output_path.with_suffix('.dat')
        hdr_file = output_path.with_suffix('.hdr')

        # 保存二进制数据
        data.astype('float32').tofile(str(data_file))

        # 创建ENVI头文件
        if len(data.shape) == 1:
            lines = 1
            samples = data.shape[0]
            bands = 1
        elif len(data.shape) == 2:
            lines = data.shape[0]
            samples = data.shape[1]
            bands = 1
        else:
            lines = data.shape[0]
            samples = data.shape[1]
            bands = data.shape[2]

        header_content = f"""ENVI
samples = {samples}
lines = {lines}
bands = {bands}
header offset = 0
file type = ENVI Standard
data type = 4
interleave = bip
byte order = 0
wavelength units = nm
"""

        if wavelengths is not None and len(wavelengths) == bands:
            wavelength_str = ', '.join([f"{w:.1f}" for w in wavelengths])
            header_content += f"wavelength = {{{wavelength_str}}}\n"

        with open(hdr_file, 'w', encoding='utf-8') as f:
            f.write(header_content)

        print(f"已保存ENVI文件: {data_file} 和 {hdr_file}")

    def preprocess(self, method, output_path, **kwargs):
        """
        执行预处理并保存结果

        参数:
            method: 预处理方法名 ('MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave')
            output_path: 输出文件路径
            **kwargs: 传递给预处理方法的参数
        """
        if self.data is None:
            raise ValueError("请先加载数据")

        # 执行预处理
        processed_data = self._apply_preprocessing(method, **kwargs)

        # 保存结果
        self.save_data(output_path, processed_data, self.wavelengths)

        return processed_data

    def _apply_preprocessing(self, method, **kwargs):
        """应用预处理方法"""
        method_funcs = {
            'MMS': self._MMS,
            'SS': self._SS,
            'CT': self._CT,
            'SNV': self._SNV,
            'MA': self._MA,
            'SG': self._SG,
            'D1': self._D1,
            'D2': self._D2,
            'DT': self._DT,
            'MSC': self._MSC,
            'wave': self._wave
        }

        if method not in method_funcs:
            raise ValueError(f"不支持的预处理方法: {method}")

        return method_funcs[method](**kwargs)

    # 预处理方法实现
    def _MMS(self, **kwargs):
        """最大最小值归一化"""
        print("执行最大最小值归一化...")
        scaler = MinMaxScaler()
        if len(self.data.shape) == 2:
            # CSV格式: (samples, bands)
            return scaler.fit_transform(self.data)
        else:
            # 图像格式: (rows, cols, bands) -> 需要reshape
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            normalized = scaler.fit_transform(reshaped)
            return normalized.reshape(original_shape)

    def _SS(self, save_path=None, **kwargs):
        """标准化"""
        print("执行标准化...")
        scaler = StandardScaler()
        if len(self.data.shape) == 2:
            result = scaler.fit_transform(self.data)
        else:
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            result = scaler.fit_transform(reshaped)
            result = result.reshape(original_shape)

        if save_path:
            joblib.dump(scaler, save_path)
            print(f"Scaler参数已保存到: {save_path}")

        return result

    def _CT(self, **kwargs):
        """均值中心化"""
        print("执行均值中心化...")
        if len(self.data.shape) == 2:
            # 2D数据: (samples, bands) - 按行计算均值并中心化
            mean_vals = np.mean(self.data, axis=1, keepdims=True)
            return self.data - mean_vals
        else:
            # 3D数据: (rows, cols, bands) - 按每个像素的光谱计算均值并中心化
            mean_vals = np.mean(self.data, axis=2, keepdims=True)
            return self.data - mean_vals

    def _SNV(self, **kwargs):
        """标准正态变换"""
        print("执行标准正态变换...")
        if len(self.data.shape) != 2:
            raise ValueError("SNV方法只支持2D数据")

        # 计算每行的均值和标准差
        data_average = np.mean(self.data, axis=1, keepdims=True)
        data_std = np.std(self.data, axis=1, keepdims=True)

        # 避免除零错误
        data_std = np.where(data_std == 0, 1, data_std)

        # 标准化
        return (self.data - data_average) / data_std

    def _MA(self, WSZ=11, **kwargs):
        """移动平均平滑(WSZ应为奇数;两端用累积均值近似,保持输出长度不变)"""
        print(f"执行移动平均平滑 (窗口大小: {WSZ})...")
        output_data = deepcopy(self.data)
        if len(self.data.shape) == 2:
            for i in range(output_data.shape[0]):
                out0 = np.convolve(output_data[i], np.ones(WSZ, dtype=int), 'valid') / WSZ
                r = np.arange(1, WSZ - 1, 2)
                start = np.cumsum(output_data[i, :WSZ - 1])[::2] / r
                stop = (np.cumsum(output_data[i, :-WSZ:-1])[::2] / r)[::-1]
                output_data[i] = np.concatenate((start, out0, stop))
        else:
            for i in range(output_data.shape[0]):
                for j in range(output_data.shape[1]):
                    spectrum = output_data[i, j, :]
                    out0 = np.convolve(spectrum, np.ones(WSZ, dtype=int), 'valid') / WSZ
                    r = np.arange(1, WSZ - 1, 2)
                    start = np.cumsum(spectrum[:WSZ - 1])[::2] / r
                    stop = (np.cumsum(spectrum[:-WSZ:-1])[::2] / r)[::-1]
                    output_data[i, j, :] = np.concatenate((start, out0, stop))
        return output_data

    def _SG(self, w=15, p=2, **kwargs):
        """Savitzky-Golay平滑滤波"""
        print(f"执行Savitzky-Golay平滑 (窗口: {w}, 阶数: {p})...")
        if len(self.data.shape) == 2:
            return signal.savgol_filter(self.data, w, p)
        else:
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            filtered = signal.savgol_filter(reshaped, w, p)
            return filtered.reshape(original_shape)

    def _D1(self, **kwargs):
        """一阶导数(输出比输入少一个波段,self.wavelengths会同步截断)"""
        print("执行一阶导数...")
        if len(self.data.shape) == 2:
            n, p = self.data.shape
            output_data = np.ones((n, p - 1))
            for i in range(n):
                output_data[i] = np.diff(self.data[i])
        else:
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            n, p = reshaped.shape
            diff_data = np.ones((n, p - 1))
            for i in range(n):
                diff_data[i] = np.diff(reshaped[i])
            output_data = diff_data.reshape(original_shape[0], original_shape[1], p - 1)

        # 更新波长信息(减少一个波段)
        if self.wavelengths is not None and len(self.wavelengths) > 1:
            self.wavelengths = self.wavelengths[:-1]

        return output_data

    def _D2(self, **kwargs):
        """二阶导数(输出比输入少两个波段)"""
        print("执行二阶导数...")
        if len(self.data.shape) == 2:
            # 2D数据: (samples, bands)
            # 计算二阶导数:对原数据求两次差分
            first_diff = np.diff(self.data, axis=1)    # 一阶导数
            second_diff = np.diff(first_diff, axis=1)  # 二阶导数
            output_data = second_diff
        elif len(self.data.shape) == 3:
            # 3D数据: (rows, cols, bands) - 高光谱图像
            # 对bands维度进行差分
            first_diff = np.diff(self.data, axis=2)    # 一阶导数
            second_diff = np.diff(first_diff, axis=2)  # 二阶导数
            output_data = second_diff
        else:
            raise ValueError("不支持的数据维度")

        # 更新波长信息(减少两个波段)
        if self.wavelengths is not None and len(self.wavelengths) > 2:
            self.wavelengths = self.wavelengths[:-2]

        return output_data

    def _DT(self, **kwargs):
        """趋势校正"""
        print("执行趋势校正...")
        # 转为浮点副本,避免整型数据在减去趋势时被截断
        output_data = np.array(self.data, dtype=float)
        if len(self.data.shape) == 2:
            length = output_data.shape[1]
            x = np.asarray(range(length), dtype=np.float32)
            l = LinearRegression()
            for i in range(output_data.shape[0]):
                l.fit(x.reshape(-1, 1), output_data[i].reshape(-1, 1))
                k = l.coef_
                b = l.intercept_
                for j in range(output_data.shape[1]):
                    output_data[i][j] = output_data[i][j] - (j * k + b)
        else:
            length = output_data.shape[2]
            x = np.asarray(range(length), dtype=np.float32)
            l = LinearRegression()
            for i in range(output_data.shape[0]):
                for j in range(output_data.shape[1]):
                    spectrum = output_data[i, j, :]
                    l.fit(x.reshape(-1, 1), spectrum.reshape(-1, 1))
                    k = l.coef_
                    b = l.intercept_
                    for k_idx in range(output_data.shape[2]):
                        output_data[i, j, k_idx] = output_data[i, j, k_idx] - (k_idx * k + b)
        return output_data

    def _MSC(self, **kwargs):
        """多元散射校正"""
        print("执行多元散射校正...")
        if len(self.data.shape) == 2:
            n, p = self.data.shape
            output_data = np.ones((n, p))
            mean = np.mean(self.data, axis=0)
            for i in range(n):
                y = self.data[i, :]
                l = LinearRegression()
                l.fit(mean.reshape(-1, 1), y.reshape(-1, 1))
                k = l.coef_
                b = l.intercept_
                output_data[i, :] = (y - b) / k
        elif len(self.data.shape) == 3:
            # 3D数据: (rows, cols, bands) - 高光谱图像
            rows, cols, bands = self.data.shape
            output_data = np.zeros((rows, cols, bands), dtype=np.float32)

            # 计算整个图像的平均光谱
            mean_spectrum = np.mean(self.data, axis=(0, 1))  # 对所有像素求平均

            # 对每个像素进行MSC校正
            for i in range(rows):
                for j in range(cols):
                    y = self.data[i, j, :]
                    l = LinearRegression()
                    l.fit(mean_spectrum.reshape(-1, 1), y.reshape(-1, 1))
                    k = l.coef_
                    b = l.intercept_
                    # 避免除零错误
                    if k == 0:
                        k = 1e-10
                    output_data[i, j, :] = (y - b) / k
        else:
            raise ValueError("不支持的数据维度")
        return output_data

    def _wave(self, **kwargs):
        """小波变换(db8小波,按各层系数最大值的4%做阈值去噪)"""
        print("执行小波变换...")
        def wave_single(spectrum):
            w = pywt.Wavelet('db8')
            maxlev = pywt.dwt_max_level(len(spectrum), w.dec_len)
            coeffs = pywt.wavedec(spectrum, 'db8', level=maxlev)
            threshold = 0.04
            for i in range(1, len(coeffs)):
                coeffs[i] = pywt.threshold(coeffs[i], threshold * max(coeffs[i]))
            # 注意: 波段数为奇数时,waverec的输出可能比输入多一个点
            return pywt.waverec(coeffs, 'db8')

        if len(self.data.shape) == 2:
            output_data = None
            for i in range(self.data.shape[0]):
                processed = wave_single(self.data[i])
                if i == 0:
                    output_data = processed
                else:
                    output_data = np.vstack((output_data, processed))
        else:
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            output_data = None
            for i in range(reshaped.shape[0]):
                processed = wave_single(reshaped[i])
                if i == 0:
                    output_data = processed
                else:
                    output_data = np.vstack((output_data, processed))
            output_data = output_data.reshape(original_shape)

        return output_data
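# ----- 类内方法直接调用示意(数据为随机数组,属假设示例,仅演示调用方式) -----
# p = HyperspectralPreprocessor()
# p.data = np.random.rand(10, 224)            # (样本数, 波段数)
# denoised = p._apply_preprocessing('wave')   # 等价于 p._wave()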


# ===== 便捷函数 =====

def preprocess_file(input_file, output_file, method, spectral_start_col=None, **kwargs):
    """
    便捷函数:预处理单个文件

    参数:
        input_file: 输入文件路径
        output_file: 输出文件路径
        method: 预处理方法
        spectral_start_col: CSV文件的光谱起始列名
        **kwargs: 传递给预处理方法的参数
    """
    processor = HyperspectralPreprocessor()
    processor.load_data(input_file, spectral_start_col)
    return processor.preprocess(method, output_file, **kwargs)
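# ----- 使用示意(文件名与起始列名均为占位,属假设示例) -----
# preprocess_file("my_spectra.csv", "my_spectra_sg.csv",
#                 method='SG', spectral_start_col='400.0', w=15, p=2)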


# 保持向后兼容的函数
def MMS(input_spectrum):
    """最大最小值归一化 (向后兼容)"""
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._MMS()


def SS(input_spectrum, save_path=None):
    """标准化 (向后兼容)"""
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._SS(save_path=save_path)


def CT(input_spectrum):
    """均值中心化 (向后兼容)"""
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._CT()


def SNV(input_spectrum):
    """标准正态变换 (向后兼容)"""
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._SNV()


def MA(input_spectrum, WSZ=11):
    """移动平均平滑 (向后兼容)"""
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._MA(WSZ=WSZ)


def SG(input_spectrum, w=15, p=2):
    """Savitzky-Golay平滑滤波 (向后兼容)"""
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._SG(w=w, p=p)


def D1(input_spectrum):
    """一阶导数 (向后兼容)"""
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._D1()


def D2(input_spectrum):
    """二阶导数 (向后兼容)"""
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._D2()


def DT(input_spectrum):
    """趋势校正 (向后兼容)"""
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._DT()


def MSC(input_spectrum):
    """多元散射校正 (向后兼容)"""
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._MSC()


def wave(input_spectrum):
    """小波变换 (向后兼容)"""
    processor = HyperspectralPreprocessor()
    processor.data = input_spectrum
    return processor._wave()


# ===== 使用示例 =====
if __name__ == "__main__":
    # 示例1: 处理ENVI格式的高光谱图像
    print("=== 示例1: 处理ENVI格式图像 ===")
    processor = HyperspectralPreprocessor()
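    # ENVI 流程示意(被注释;文件路径为占位,属假设示例):
    # processor.load_data("scene.bip.hdr")
    # processor.preprocess('SG', "scene_sg.dat", w=15, p=2)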

    print("\n=== 示例2: 处理CSV格式数据 ===")
    try:
        # 假设CSV文件的格式:前面是样本信息列,从'374.285004'列开始是光谱数据
        processor_csv = HyperspectralPreprocessor()
        processor_csv.load_data(r"E:\code\spectronon\single_classsfication\tsst\train.csv", spectral_start_col='374.285004')

        # 执行预处理
        processor_csv.preprocess('SNV', r'E:\code\spectronon\single_classsfication\tsst\output_snv.csv')
        processor_csv.preprocess('MSC', r'E:\code\spectronon\single_classsfication\tsst\output_msc.csv')

    except Exception as e:
        print(f"CSV示例失败: {e}")
32
requirements.txt
Normal file
@ -0,0 +1,32 @@
# 核心科学计算和数据处理
numpy>=1.21.0
pandas>=1.3.0
scipy>=1.7.0
scikit-learn>=1.0.0

# 机器学习和深度学习
xgboost>=1.5.0
lightgbm>=3.3.0

# 可选:深度学习框架(选择一个或多个)
tensorflow>=2.8.0      # TensorFlow/Keras 模型
torch>=1.11.0          # PyTorch 模型
torchvision>=0.12.0    # PyTorch 视觉模型

# 可视化
matplotlib>=3.5.0
seaborn>=0.11.0

# 图像处理
opencv-python>=4.5.0

# 光谱数据处理(可选)
spectral>=0.22.0       # ENVI 文件处理

# 光谱预处理与监督聚类脚本的依赖(Preprocessing.py 使用 pywt,supervize_cluster.py 使用 shapely)
PyWavelets>=1.1.0      # pywt 小波变换
shapely>=1.8.0         # ROI 多边形处理

# 工具库
joblib>=1.1.0
typing-extensions>=4.0.0  # 类型提示

# 开发和测试(可选)
pytest>=6.2.0
jupyter>=1.0.0
1634
rgression_method/regression.py
Normal file
File diff suppressed because it is too large
880
spectral_index_method/spectral_index.py
Normal file
@ -0,0 +1,880 @@
import numpy as np
import pandas as pd
import os
import re
import sys
import ast
from pathlib import Path
from typing import Optional, Dict, Any, Union
from dataclasses import dataclass, field
import warnings

try:
    import spectral
    SPECTRAL_AVAILABLE = True
except ImportError:
    SPECTRAL_AVAILABLE = False
    print("警告: spectral库不可用,将使用内置方法读取数据")

warnings.filterwarnings('ignore')


@dataclass
class DataConfig:
    """数据配置类"""
    data_file_path: str = ""
    data_format: str = "auto"  # 'csv', 'envi', 'auto'
    wavelength_column: Optional[int] = None  # CSV格式中波长列的索引


@dataclass
class IndexConfig:
    """光谱指数配置类"""
    spectral_index_csv: str = r"E:\code\spectronon\spectral_index.csv"
    formula_csv: str = r"E:\code\spectronon\famula.csv"
    indices_to_calculate: Optional[list] = None  # 要计算的指数列表,None表示全部


@dataclass
class OutputConfig:
    """输出配置类"""
    output_dir: str = "results"
    save_individual_indices: bool = True
    save_combined_results: bool = True
    output_format: str = "csv"  # 'csv', 'excel', 'both'


@dataclass
class SpectralIndexConfig:
    """光谱指数计算完整配置类 - 为GUI对接设计的标准化接口"""
    data: DataConfig = field(default_factory=DataConfig)
    indices: IndexConfig = field(default_factory=IndexConfig)
    output: OutputConfig = field(default_factory=OutputConfig)

    def __post_init__(self):
        """参数校验和默认值设置"""
        self._validate_parameters()

    def _validate_parameters(self):
        """参数校验"""
        # 数据参数校验
        if self.data.data_format not in ['csv', 'envi', 'auto']:
            raise ValueError("Data format must be 'csv', 'envi', or 'auto'")

        # 输出参数校验
        if self.output.output_format not in ['csv', 'excel', 'both']:
            raise ValueError("Output format must be 'csv', 'excel', or 'both'")

        # 文件路径校验
        if not os.path.exists(self.indices.spectral_index_csv):
            print(f"Warning: Spectral index CSV file not found: {self.indices.spectral_index_csv}")
        if not os.path.exists(self.indices.formula_csv):
            print(f"Warning: Formula CSV file not found: {self.indices.formula_csv}")

    @classmethod
    def create_default(cls) -> 'SpectralIndexConfig':
        """创建默认配置"""
        return cls()

    @classmethod
    def create_quick_analysis(cls, data_file_path: str, indices_to_calculate: Optional[list] = None) -> 'SpectralIndexConfig':
        """创建快速分析配置"""
        config = cls()
        config.data.data_file_path = data_file_path
        if indices_to_calculate:
            config.indices.indices_to_calculate = indices_to_calculate
        config.output.save_individual_indices = False  # 快速分析不保存单个指数
        return config


class HyperspectralIndexCalculator:
    """
    高光谱指数计算器 - 支持GUI对接的标准化接口
    支持多种光谱指数的自动计算
    """

    def __init__(self, config: Optional[SpectralIndexConfig] = None,
                 spectral_index_csv=None, formula_csv=None):
        """
        初始化光谱指数计算器

        Parameters:
            config (SpectralIndexConfig, optional): 配置对象,如果为None则使用默认配置
            spectral_index_csv (str, optional): 波段定义文件路径(向后兼容)
            formula_csv (str, optional): 光谱指数公式文件路径(向后兼容)
        """
        # 处理向后兼容性
        if config is None and (spectral_index_csv is not None or formula_csv is not None):
            # 使用传统参数方式
            config = SpectralIndexConfig()
            if spectral_index_csv is not None:
                config.indices.spectral_index_csv = spectral_index_csv
            if formula_csv is not None:
                config.indices.formula_csv = formula_csv

        self.config = config or SpectralIndexConfig()
        self._validate_config()

        # 加载波段定义
        self.band_info = self.load_band_info(self.config.indices.spectral_index_csv)

        # 加载光谱指数公式
        self.formulas = self.load_formulas(self.config.indices.formula_csv)

        # 存储当前加载的数据
        self.data = None
        self.wavelengths = None
        self.data_shape = None
        self.band_mapping = None

    def update_config(self, config: SpectralIndexConfig):
        """
        更新配置 - 为GUI动态配置预留接口

        Parameters:
            config (SpectralIndexConfig): 新的配置对象
        """
        self.config = config
        self._validate_config()

    def _validate_config(self):
        """配置校验"""
        try:
            self.config._validate_parameters()
        except ValueError as e:
            raise ValueError(f"Configuration validation failed: {e}")

    def load_band_info(self, csv_path):
        """加载波段定义信息"""
        df = pd.read_csv(csv_path)
        band_info = {}

        for _, row in df.iterrows():
            name = row['name']
            band_info[name] = {
                'min': float(row['min']),
                'center': float(row['center']),
                'max': float(row['max'])
            }

            # 为常用波段名称创建别名映射
            if name in ['nir', 'red', 'green', 'blue', 'swir1', 'swir2']:
                band_info[name.upper()] = band_info[name]

        return band_info

    def load_formulas(self, csv_path):
        """加载光谱指数公式"""
        df = pd.read_csv(csv_path)
        formulas = {}

        for _, row in df.iterrows():
            index = row['Index']
            formulas[index] = {
                'name': row['Name'],
                'type': ast.literal_eval(row['Type']),    # 将字符串形式的列表安全解析为列表
                'equation': row['Equation'],
                'bands': ast.literal_eval(row['Bands'])   # 同上;literal_eval只接受字面量,比eval更安全
            }
        return formulas
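    # famula.csv 行格式示意(列名取自上文代码,具体取值为假设示例):
    # Index,Name,Type,Equation,Bands
    # NDVI,Normalized Difference Vegetation Index,"['Vegetation']",(nir-red)/(nir+red),"['nir', 'red']"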

    def load_hyperspectral_data(self, file_path):
        """
        加载高光谱数据文件

        支持格式:
        - CSV: 第一行为波长,后续为数据
        - ENVI格式: .bil, .bsq, .bip, .dat + .hdr文件 (使用spectral库)
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()

        if suffix == '.csv':
            return self.load_csv_data(file_path)
        else:
            # 使用spectral库
            if SPECTRAL_AVAILABLE and suffix in ['.bil', '.bsq', '.bip', '.dat', '.hdr']:
                return self.load_with_spectral(file_path)
            else:
                raise ValueError(f"不支持的文件格式: {suffix}。请使用CSV或ENVI格式文件,并确保已安装spectral库。")

    def load_csv_data(self, file_path):
        """加载CSV格式的高光谱数据"""
        df = pd.read_csv(file_path)

        # 假设第一行是波长
        self.wavelengths = df.columns.astype(float).values
        self.data = df.values
        self.data_shape = self.data.shape

        print(f"CSV数据加载成功: {self.data_shape}")
        print(f"波段数量: {len(self.wavelengths)}")
        print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")

        return self.data

    def load_with_spectral(self, file_path):
        """
        使用spectral库加载高光谱数据
        支持ENVI格式: .bil, .bsq, .bip, .dat + .hdr文件
        """
        original_path = Path(file_path)
        current_path = original_path

        # 如果是.hdr文件,找到对应的数据文件
        if current_path.suffix.lower() == '.hdr':
            data_file = current_path.with_suffix('')
            if not data_file.exists():
                # 尝试其他常见扩展名
                for ext in ['.dat', '.bil', '.bsq', '.bip']:
                    candidate = current_path.with_suffix(ext)
                    if candidate.exists():
                        data_file = candidate
                        break
            current_path = data_file

        print(f"使用spectral库加载: {current_path}")

        # 使用spectral打开图像(open_image需要字符串形式的.hdr路径)
        img = spectral.open_image(str(file_path))

        # 读取所有波段的数据
        # spectral默认返回(行, 列, 波段)的numpy数组
        self.data = img.load()

        # 获取波长信息
        if hasattr(img, 'metadata') and 'wavelength' in img.metadata:
            wavelength_data = img.metadata['wavelength']
            try:
                # 处理波长数据,可能需要转换为float数组
                if isinstance(wavelength_data, str):
                    # 如果是字符串,尝试解析
                    wavelength_data = wavelength_data.strip('{}[]')
                    if ',' in wavelength_data:
                        self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split(',') if w.strip()])
                    else:
                        self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split() if w.strip()])
                elif isinstance(wavelength_data, list):
                    self.wavelengths = np.array([float(w) for w in wavelength_data])
                else:
                    self.wavelengths = np.array(wavelength_data, dtype=float)

                print(f"spectral解析波长: {len(self.wavelengths)} 个波段")
                print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
            except Exception as e:
                print(f"波长解析失败,使用默认值: {e}")
                self.wavelengths = np.arange(self.data.shape[2])
        else:
            # 如果没有波长信息,创建默认值
            self.wavelengths = np.arange(self.data.shape[2])
            print(f"警告: spectral未找到波长信息,使用默认值: 0-{self.data.shape[2]-1}")

        self.data_shape = self.data.shape

        print(f"spectral数据加载成功: {self.data_shape}")
        print(f"数据类型: {self.data.dtype}")
        print(f"值范围: [{self.data.min():.3f}, {self.data.max():.3f}]")

        return self.data

    def find_band_index(self, band_name, tolerance=5):
        """
        根据波段名称找到对应的波段索引

        参数:
            band_name: 波段名称 (如 'w550', 'nir', 'red')
            tolerance: 容差范围 (nm)

        返回:
            波段索引 (从0开始)
        """
        if band_name not in self.band_info:
            # 尝试匹配标准波段名称
            std_names = {
                'nir': 'w860', 'red': 'w680', 'green': 'w550',
                'blue': 'w470', 'swir1': 'w1650', 'swir2': 'w2200',
                'tm1': 'w485', 'tm2': 'w569', 'tm3': 'w660',
                'tm4': 'w833', 'tm5': 'w1676', 'tm7': 'w2223'
            }

            if band_name in std_names:
                band_name = std_names[band_name]
            else:
                raise ValueError(f"未定义的波段名称: {band_name}")

        target_info = self.band_info[band_name]
        target_center = target_info['center']
        target_min = target_info['min']
        target_max = target_info['max']

        # 检查当前数据的波长范围是否包含目标波段
        data_min = self.wavelengths[0]
        data_max = self.wavelengths[-1]

        if target_min < data_min or target_max > data_max:
            raise ValueError(
                f"波段 {band_name} 超出数据波长范围: "
                f"目标波段 [{target_min:.1f}-{target_max:.1f}nm], "
                f"数据范围 [{data_min:.1f}-{data_max:.1f}nm]"
            )

        # 找到最接近的波长
        diffs = np.abs(self.wavelengths - target_center)
        min_idx = np.argmin(diffs)

        if diffs[min_idx] > tolerance:
            print(
                f"警告: 波段 {band_name} (目标: {target_center}nm) 匹配到 {self.wavelengths[min_idx]:.1f}nm (差异: {diffs[min_idx]:.1f}nm)")

        return min_idx

    def calculate_index(self, index_name, output_type='image'):
        """
        计算指定光谱指数

        参数:
            index_name: 光谱指数简称 (如 'NDVI', 'EVI')
            output_type: 输出类型 ('image'或'values')

        返回:
            光谱指数图像或数值
        """
        if index_name not in self.formulas:
            available = list(self.formulas.keys())
            raise ValueError(f"未找到指数: {index_name}。可用指数: {available}")

        formula_info = self.formulas[index_name]
        equation = formula_info['equation']
        required_bands = formula_info['bands']

        print(f"计算指数: {formula_info['name']} ({index_name})")
        print(f"所需波段: {required_bands}")
        print(f"公式: {equation}")

        # 检查公式中是否包含临时变量
        if ';' in equation:
            # 分离主公式和临时变量定义
            parts = equation.split(';')
            main_eq = parts[0].strip()
            temp_vars = parts[1:]

            # 创建局部变量字典
            local_vars = {}

            # 先计算临时变量
            for temp_eq in temp_vars:
                if '=' in temp_eq:
                    var_name, expr = temp_eq.split('=', 1)
                    var_name = var_name.strip()
                    expr = expr.strip()

                    # 计算临时变量
                    temp_value = self.evaluate_expression(expr, required_bands, local_vars)
                    local_vars[var_name] = temp_value
        else:
            main_eq = equation.strip()
            local_vars = {}

        # 计算主公式
        index_result = self.evaluate_expression(main_eq, required_bands, local_vars)

        # 根据输出类型返回结果
        if output_type == 'image':
            return index_result
        elif output_type == 'values':
            # 展平为一维数组
            return index_result.flatten()
        else:
            return index_result

    def evaluate_expression(self, expression, required_bands, local_vars):
        """
        评估表达式

        参数:
            expression: 表达式字符串
            required_bands: 所需波段列表
            local_vars: 局部变量字典
        """
        # 提取所有波段变量
        pattern = r'\b(w\d+|tm\d+|nir|red|green|blue|swir1|swir2|thermal)\b'
        band_vars = re.findall(pattern, expression)

        # 创建波段数据字典
        band_data = {}
        for band_var in set(band_vars):
            # 检查是否已在局部变量中
            if band_var in local_vars:
                continue

            # 查找波段索引
            band_idx = self.find_band_index(band_var)

            # 获取波段数据
            if len(self.data_shape) == 2:  # CSV格式
                band_data[band_var] = self.data[:, band_idx]
            else:  # 图像格式
                band_data[band_var] = self.data[:, :, band_idx]

        # 合并局部变量
        all_vars = {**band_data, **local_vars}

        # 添加数学函数
        math_funcs = {
            'sqrt': np.sqrt,
            'log': np.log,
            'log10': np.log10,
            'exp': np.exp,
            'sin': np.sin,
            'cos': np.cos,
            'tan': np.tan,
            'abs': np.abs,
            'pow': np.power
        }

        # 安全评估表达式
        try:
            # 替换表达式中的^为**
            expression = expression.replace('^', '**')

            # 创建安全环境
            safe_env = {**all_vars, **math_funcs}

            # 添加numpy函数
            safe_env['np'] = np

            # 执行表达式(屏蔽内置函数,只暴露波段数据和数学函数)
            result = eval(expression, {"__builtins__": {}}, safe_env)

            # 处理可能的异常值
            if isinstance(result, np.ndarray):
                result = np.nan_to_num(result, nan=0.0, posinf=1.0, neginf=-1.0)

            return result

        except Exception as e:
            raise ValueError(f"表达式评估失败: {expression}\n错误: {str(e)}")

    def calculate_multiple_indices(self, index_list, output_format='separate'):
        """
        计算多个光谱指数

        参数:
            index_list: 指数名称列表
            output_format: 输出格式 ('separate'或'stack')

        返回:
            光谱指数结果
        """
        results = {}

        for index_name in index_list:
            try:
                result = self.calculate_index(index_name, output_type='image')
                results[index_name] = result
                print(f"✓ {index_name}: 计算完成 (形状: {result.shape})")
            except Exception as e:
                print(f"✗ {index_name}: 计算失败 - {str(e)}")
                results[index_name] = None

        if output_format == 'stack' and results:
            # 将所有指数堆叠成一个多波段图像
            valid_results = [r for r in results.values() if r is not None]
            if valid_results:
                return np.stack(valid_results, axis=-1)

        return results

    def save_results(self, results, output_prefix='output'):
        """
        保存计算结果

        参数:
            results: 计算结果字典或数组
            output_prefix: 输出文件前缀
        """
        # 根据输入数据形状判断文件格式
        is_csv_format = len(self.data_shape) == 2

        if isinstance(results, dict):
            if is_csv_format:
                # CSV输入: 合并所有指数为单个CSV文件
                self.save_merged_csv(results, output_prefix)
            else:
                # 图像输入: 将多个指数堆叠为多波段dat文件
                self.save_multiband_dat(results, output_prefix)
        elif isinstance(results, np.ndarray):
            # 保存堆叠的结果
            output_file = f"{output_prefix}_indices_stack"
            self.save_single_result(results, "multiple_indices", output_file)

    def save_single_result(self, data, index_name, output_file):
        """保存单个结果"""
        # 确定文件格式
        if output_file.endswith('.csv'):
            # 保存为CSV
            if len(data.shape) == 2:
                df = pd.DataFrame(data)
            else:
                # 3D数据展平
                flattened = data.reshape(-1, data.shape[-1]) if len(data.shape) == 3 else data.flatten()
                df = pd.DataFrame(flattened)

            df.to_csv(output_file, index=False)
            print(f"已保存: {output_file} (CSV格式)")

        else:
            # 保存为二进制文件 + 头文件
            data_file = output_file + '.dat'
            hdr_file = output_file + '.hdr'

            # 保存二进制数据
            data.astype('float32').tofile(data_file)

            # 创建ENVI头文件 - 符合ENVI标准格式
            if len(data.shape) == 1:
                # 1D数据 (CSV格式)
                lines = 1
                samples = data.shape[0]
                bands = 1
            elif len(data.shape) == 2:
                # 2D数据 (图像格式)
                lines = data.shape[0]
                samples = data.shape[1]
                bands = 1
            else:
                # 3D数据 (多波段图像)
                lines = data.shape[0]
                samples = data.shape[1]
                bands = data.shape[2]

            header_content = f"""ENVI
samples = {samples}
lines = {lines}
bands = {bands}
header offset = 0
file type = ENVI Standard
data type = 4
interleave = bsq
byte order = 0
wavelength units = Unknown
data ignore value = 0
"""

            with open(hdr_file, 'w', encoding='utf-8') as f:
                f.write(header_content)

            print(f"已保存: {data_file} 和 {hdr_file} (ENVI格式)")
            print(f"数据形状: {data.shape}")

    def save_merged_csv(self, results_dict, output_prefix):
        """
        将多个光谱指数结果合并保存为单个CSV文件

        参数:
            results_dict: 包含多个指数结果的字典
            output_prefix: 输出文件前缀
        """
        # 创建输出文件名
        output_file = f"{output_prefix}_merged_indices.csv"

        # 过滤掉None值的结果
        valid_results = {name: data for name, data in results_dict.items() if data is not None}

        if not valid_results:
            print("没有有效的指数结果可保存")
            return

        # 获取第一个有效结果的形状作为参考
        first_result = next(iter(valid_results.values()))
        if len(first_result.shape) != 1:
            print("合并CSV仅支持1D数据")
            return

        # 创建合并的数据框
        merged_data = {}

        # 添加每个指数的值
        for index_name, data in valid_results.items():
            merged_data[index_name] = data

        # 创建DataFrame并保存
        df = pd.DataFrame(merged_data)
        df.to_csv(output_file, index=False)
        print(f"已合并保存 {len(valid_results)} 个指数到: {output_file}")
        print(f"总样本数: {len(df)}")

    def save_multiband_dat(self, results_dict, output_prefix):
        """
        将多个光谱指数结果保存为多波段的dat文件

        参数:
            results_dict: 包含多个指数结果的字典
            output_prefix: 输出文件前缀
        """
        # 过滤掉None值的结果
        valid_results = {name: data for name, data in results_dict.items() if data is not None}

        if not valid_results:
            print("没有有效的指数结果可保存")
            return

        # 移除单维度(如果存在)
        index_names = []
        index_data_list = []

        for index_name, data in valid_results.items():
            index_names.append(index_name)
            # 如果数据有单维度(例如形状为(1066, 1148, 1)),则压缩它
            if data.ndim == 3 and data.shape[2] == 1:
                data = np.squeeze(data, axis=2)
            index_data_list.append(data)

        # 检查所有数据是否都是二维的
        for i, data in enumerate(index_data_list):
            if data.ndim != 2:
                print(f"警告: 指数 {index_names[i]} 的形状为 {data.shape},不是二维数组")
                # 尝试展平为二维
                if data.ndim > 2:
                    data = data.reshape(data.shape[0], data.shape[1])
                else:
                    data = data.reshape(1, -1)
                index_data_list[i] = data

        # 堆叠为多波段数据 (行, 列, 波段)
        stacked_data = np.stack(index_data_list, axis=-1)

        # 保存为dat文件
        output_file = f"{output_prefix}_indices"
        data_file = output_file + '.dat'
        hdr_file = output_file + '.hdr'

        # 保存二进制数据
        stacked_data.astype('float32').tofile(data_file)

        # 创建ENVI头文件 - 符合ENVI标准格式
        lines = stacked_data.shape[0]
        samples = stacked_data.shape[1]
        bands = stacked_data.shape[2]

        header_content = f"""ENVI
samples = {samples}
lines = {lines}
bands = {bands}
header offset = 0
file type = ENVI Standard
data type = 4
interleave = bip
byte order = 0
wavelength units = Unknown
data ignore value = 0
band names = {{{', '.join(index_names)}}}
"""

        with open(hdr_file, 'w', encoding='utf-8') as f:
            f.write(header_content)

        print(f"已保存多波段数据: {data_file} 和 {hdr_file}")
        print(f"包含 {bands} 个指数: {', '.join(index_names)}")
        print(f"数据形状: {stacked_data.shape}")

    def list_available_indices(self, category=None):
        """
        列出所有可用的光谱指数

        参数:
            category: 按类别过滤 (如 'Vegetation', 'Mineral')
        """
        print("\n=== 可用光谱指数 ===")

        indices_by_category = {}

        for idx, info in self.formulas.items():
            categories = info['type']

            for cat in categories:
                if cat not in indices_by_category:
                    indices_by_category[cat] = []
                indices_by_category[cat].append((idx, info['name']))

        # 显示所有类别或特定类别
        if category:
            if category in indices_by_category:
                print(f"\n{category} 指数:")
                for idx, name in sorted(indices_by_category[category]):
                    print(f"  {idx:10} - {name}")
            else:
                print(f"未找到类别: {category}")
        else:
            for cat in sorted(indices_by_category.keys()):
                print(f"\n{cat} ({len(indices_by_category[cat])}个指数):")
                for idx, name in sorted(indices_by_category[cat])[:10]:  # 只显示前10个
                    print(f"  {idx:10} - {name}")
                if len(indices_by_category[cat]) > 10:
                    print(f"  ... 还有 {len(indices_by_category[cat]) - 10} 个")

        # 汇总信息
        total = len(self.formulas)
        print(f"\n总计: {total} 个光谱指数")

    def run_analysis_from_config(self) -> Dict[str, Any]:
        """
        基于配置对象运行完整分析流程 - 推荐用于GUI对接

        Returns:
            Dict[str, Any]: 分析结果字典
        """
        print("Starting spectral index analysis from configuration...")

        # 1. 加载数据
        if not self.config.data.data_file_path:
            raise ValueError("Data file path must be specified in configuration")

        print(f"Loading data from: {self.config.data.data_file_path}")
        self.load_hyperspectral_data(self.config.data.data_file_path)

        # 2. 确定要计算的指数
        if self.config.indices.indices_to_calculate is None:
            # 计算所有可用指数
            indices_to_calculate = list(self.formulas.keys())
            print(f"Calculating all {len(indices_to_calculate)} available indices")
        else:
            # 计算指定的指数
            indices_to_calculate = self.config.indices.indices_to_calculate
            invalid_indices = [idx for idx in indices_to_calculate if idx not in self.formulas]
            if invalid_indices:
                print(f"Warning: The following indices do not exist: {invalid_indices}")
                indices_to_calculate = [idx for idx in indices_to_calculate if idx in self.formulas]
            print(f"Calculating {len(indices_to_calculate)} specified indices: {indices_to_calculate}")

        # 3. 计算指数
        results = self.calculate_multiple_indices(indices_to_calculate, output_format='separate')

        # 4. 保存结果
        if self.config.output.save_combined_results or self.config.output.save_individual_indices:
            self.save_results(results, 'config_results')

        print("Analysis completed!")
        return results


def main():
    """主函数 - 命令行接口"""
    import argparse

    parser = argparse.ArgumentParser(description='高光谱光谱指数计算工具')
    parser.add_argument('input_file', help='输入的高光谱文件路径 (CSV, BIL, BSQ, BIP, DAT)')
    parser.add_argument('-i', '--indices', nargs='+', help='要计算的指数列表 (如 NDVI EVI)')
    parser.add_argument('-a', '--all', action='store_true', help='计算所有植被指数')
    parser.add_argument('-c', '--category', help='计算特定类别的所有指数 (Vegetation, Mineral等)')
    parser.add_argument('-o', '--output', default='output', help='输出文件前缀')
    parser.add_argument('-f', '--format', choices=['separate', 'stack'], default='separate',
                        help='输出格式: separate(分开) 或 stack(堆叠)')
    parser.add_argument('-l', '--list', action='store_true', help='列出所有可用指数')

    args = parser.parse_args()

    # 初始化计算器
    calculator = HyperspectralIndexCalculator()

    # 如果只是列出指数
    if args.list:
        calculator.list_available_indices()
        return

    # 加载数据
    print(f"加载数据: {args.input_file}")
    calculator.load_hyperspectral_data(args.input_file)

    # 确定要计算的指数
    indices_to_calculate = []

    if args.indices:
        indices_to_calculate = args.indices
    elif args.all:
        # 计算所有植被指数
        for idx, info in calculator.formulas.items():
            if 'Vegetation' in info['type']:
                indices_to_calculate.append(idx)
    elif args.category:
        # 计算特定类别的所有指数
        for idx, info in calculator.formulas.items():
            if args.category in info['type']:
                indices_to_calculate.append(idx)
    else:
        # 默认计算常用植被指数
        default_indices = ['NDVI', 'EVI', 'NDWI', 'NDII', 'MSI', 'GNDVI']
        indices_to_calculate = [idx for idx in default_indices if idx in calculator.formulas]

    if not indices_to_calculate:
        print("错误: 未指定要计算的指数")
        calculator.list_available_indices()
        return

    print(f"\n将计算 {len(indices_to_calculate)} 个指数:")
    for idx in indices_to_calculate[:10]:
        if idx in calculator.formulas:
            print(f"  - {idx}: {calculator.formulas[idx]['name']}")
    if len(indices_to_calculate) > 10:
        print(f"  ... 还有 {len(indices_to_calculate) - 10} 个")

    # 计算指数
    results = calculator.calculate_multiple_indices(
        indices_to_calculate,
        output_format=args.format
    )

    # 保存结果
    calculator.save_results(results, args.output)

    print("\n计算完成!")


# ===== 主要示例 - 展示配置驱动和向后兼容两种方式 =====
if __name__ == "__main__":
    print("="*60)
    print("Spectral Index Calculator - Configuration-Driven Interface")
    print("="*60)

    # 方法1:配置驱动方式(推荐用于GUI对接)
    print("\n--- Method 1: Configuration-Driven (Recommended for GUI) ---")

    # 创建配置对象
    config = SpectralIndexConfig.create_quick_analysis(
        data_file_path=r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip.hdr",
        indices_to_calculate=['NDVI', 'EVI', 'NDWI']  # 指定要计算的指数
    )
    # 可选:自定义配置
    config.output.output_dir = "config_results"
    config.output.save_individual_indices = False  # 只保存合并结果

    # 创建计算器并传入配置
    calculator = HyperspectralIndexCalculator(config)

    # 运行配置驱动的分析
    results = calculator.run_analysis_from_config()
    print(f"Configuration-driven analysis completed! Calculated {len(results)} indices.")

    # 方法2:向后兼容方式(传统参数传递)
    print("\n--- Method 2: Backward Compatible (Legacy Direct Usage) ---")

    calculator2 = HyperspectralIndexCalculator()  # 使用默认配置

    # 使用传统方式加载数据和计算
    calculator2.load_hyperspectral_data(r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip.hdr")

    # 计算多个指数
    results2 = calculator2.calculate_multiple_indices(['NDVI', 'EVI', 'NDWI'])
    print(f"Legacy analysis completed! Calculated {len(results2)} indices.")

    # 保存结果
    calculator2.save_results(results2, 'legacy_results')

    # 方法3:命令行方式(如果需要)
    print("\n--- Method 3: Command Line (if needed) ---")
    print("Run: python spectral_index.py input_file.dat -i NDVI EVI -o results")

    print("\n" + "="*60)
    print("Both configuration-driven and legacy methods are supported.")
    print("Configuration-driven is recommended for GUI integration.")
    print("="*60)
846
supervize_cluster_method/supervize_cluster.py
Normal file
@ -0,0 +1,846 @@
|
||||
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
from scipy.spatial.distance import pdist, squareform
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import spectral as spy
|
||||
from spectral import envi
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
from typing import Tuple, List, Dict, Optional, Union
|
||||
import warnings
|
||||
import xml.etree.ElementTree as ET
|
||||
from shapely.geometry import Polygon, Point
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
class HyperspectralDistanceMetrics:
|
||||
"""高光谱数据距离度量类"""
|
||||
|
||||
@staticmethod
|
||||
def cosine_similarity(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
|
||||
"""余弦相似度"""
|
||||
dot_product = np.dot(spectrum1, spectrum2)
|
||||
norm1 = np.linalg.norm(spectrum1)
|
||||
norm2 = np.linalg.norm(spectrum2)
|
||||
return dot_product / (norm1 * norm2) if norm1 != 0 and norm2 != 0 else 0
|
||||
|
||||
@staticmethod
|
||||
def euclidean_distance(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
|
||||
"""欧氏距离"""
|
||||
return np.linalg.norm(spectrum1 - spectrum2)
|
||||
|
||||
@staticmethod
|
||||
def information_divergence(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
|
||||
"""信息散度 (Kullback-Leibler divergence)"""
|
||||
# 添加小常数避免除零
|
||||
eps = 1e-10
|
||||
spectrum1 = spectrum1 + eps
|
||||
spectrum2 = spectrum2 + eps
|
||||
|
||||
# 归一化为概率分布
|
||||
p = spectrum1 / np.sum(spectrum1)
|
||||
q = spectrum2 / np.sum(spectrum2)
|
||||
|
||||
# 计算KL散度
|
||||
kl_pq = np.sum(p * np.log(p / q))
|
||||
kl_qp = np.sum(q * np.log(q / p))
|
||||
|
||||
return (kl_pq + kl_qp) / 2
|
||||
|
||||
@staticmethod
|
||||
def correlation_coefficient(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
|
||||
"""相关系数"""
|
||||
return np.corrcoef(spectrum1, spectrum2)[0, 1]
|
||||
|
||||
@staticmethod
|
||||
def jeffreys_matusita_distance(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
|
||||
"""J-M距离 (Jeffreys-Matusita Distance)"""
|
||||
eps = 1e-10
|
||||
# 转换为概率分布
|
||||
p = (spectrum1 + eps) / np.sum(spectrum1 + eps)
|
||||
q = (spectrum2 + eps) / np.sum(spectrum2 + eps)
|
||||
|
||||
# Bhattacharyya系数
|
||||
bc = np.sum(np.sqrt(p * q))
|
||||
|
||||
# J-M距离 = sqrt(2 * (1 - bc))
|
||||
jm_distance = np.sqrt(2 * (1 - bc))
|
||||
return jm_distance
|
||||
|
||||
@staticmethod
|
||||
def spectral_angle_mapper(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
|
||||
"""光谱角映射器 (Spectral Angle Mapper)"""
|
||||
cos_sim = HyperspectralDistanceMetrics.cosine_similarity(spectrum1, spectrum2)
|
||||
# 转换为角度(弧度)
|
||||
angle = np.arccos(np.clip(cos_sim, -1, 1))
|
||||
return angle
|
||||
|
||||
@staticmethod
|
||||
def sid_sa_combined(spectrum1: np.ndarray, spectrum2: np.ndarray,
|
||||
alpha: float = 0.5) -> float:
|
||||
"""SID_SA结合法 (Spectral Information Divergence + Spectral Angle Mapper)"""
|
||||
sid = HyperspectralDistanceMetrics.information_divergence(spectrum1, spectrum2)
|
||||
sam = HyperspectralDistanceMetrics.spectral_angle_mapper(spectrum1, spectrum2)
|
||||
|
||||
# 标准化并结合
|
||||
sid_norm = sid / (1 + sid) # 归一化到[0,1)
|
||||
sam_norm = sam / (np.pi / 2) # 归一化到[0,1]
|
||||
|
||||
return alpha * sid_norm + (1 - alpha) * sam_norm
|
||||
|
||||
# ===== 向量化距离计算函数 =====
|
||||
|
||||
@staticmethod
|
||||
def vectorized_cosine_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
|
||||
"""向量化余弦距离计算"""
|
||||
# X: (n_samples, n_features), centers: (n_clusters, n_features)
|
||||
# 返回: (n_samples, n_clusters)
|
||||
X_norm = np.linalg.norm(X, axis=1, keepdims=True)
|
||||
centers_norm = np.linalg.norm(centers, axis=1, keepdims=True)
|
||||
|
||||
# 避免除零
|
||||
X_norm = np.where(X_norm == 0, 1, X_norm)
|
||||
centers_norm = np.where(centers_norm == 0, 1, centers_norm)
|
||||
|
||||
# 计算余弦相似度
|
||||
similarity = np.dot(X, centers.T) / (X_norm * centers_norm.T)
|
||||
# 转换为距离 (1 - 相似度)
|
||||
return 1 - np.clip(similarity, -1, 1)
|
||||
|
||||
@staticmethod
|
||||
def vectorized_euclidean_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
|
||||
"""向量化欧氏距离计算"""
|
||||
# 使用广播计算欧氏距离
|
||||
# X: (n_samples, n_features), centers: (n_clusters, n_features)
|
||||
# 返回: (n_samples, n_clusters)
|
||||
diff = X[:, np.newaxis, :] - centers[np.newaxis, :, :]
|
||||
return np.linalg.norm(diff, axis=2)

    @staticmethod
    def vectorized_correlation_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Vectorized correlation distance."""
        # np.corrcoef treats each row as one variable, so stack the samples
        # and centers row-wise and slice out the (n_samples, n_clusters)
        # cross-correlation block.
        corr_matrix = np.corrcoef(X, centers)[:len(X), len(X):]
        # Convert to a distance (1 - |correlation|)
        return 1 - np.abs(corr_matrix)

    @staticmethod
    def vectorized_information_divergence(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Vectorized spectral information divergence (symmetric KL)."""
        eps = 1e-10
        X_norm = X + eps
        centers_norm = centers + eps

        # Normalize each spectrum into a probability distribution
        X_prob = X_norm / np.sum(X_norm, axis=1, keepdims=True)
        centers_prob = centers_norm / np.sum(centers_norm, axis=1, keepdims=True)

        # KL divergences in both directions, broadcast over all pairs
        kl_pq = np.sum(X_prob[:, np.newaxis, :] * np.log(X_prob[:, np.newaxis, :] / centers_prob[np.newaxis, :, :]), axis=2)
        kl_qp = np.sum(centers_prob[np.newaxis, :, :] * np.log(centers_prob[np.newaxis, :, :] / X_prob[:, np.newaxis, :]), axis=2)

        return (kl_pq + kl_qp) / 2

    @staticmethod
    def vectorized_jeffreys_matusita_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Vectorized Jeffreys-Matusita (J-M) distance."""
        eps = 1e-10

        # Convert to probability distributions
        X_prob = (X + eps) / np.sum(X + eps, axis=1, keepdims=True)
        centers_prob = (centers + eps) / np.sum(centers + eps, axis=1, keepdims=True)

        # Bhattacharyya coefficient for every sample/center pair
        bc = np.sum(np.sqrt(X_prob[:, np.newaxis, :] * centers_prob[np.newaxis, :, :]), axis=2)

        # J-M distance; the clamp guards against bc drifting above 1 by round-off
        return np.sqrt(np.maximum(2 * (1 - bc), 0))

    @staticmethod
    def vectorized_sid_sa_combined(X: np.ndarray, centers: np.ndarray, alpha: float = 0.5) -> np.ndarray:
        """Vectorized combined SID-SA distance."""
        # SID component
        sid = HyperspectralDistanceMetrics.vectorized_information_divergence(X, centers)

        # SAM component
        sam = HyperspectralDistanceMetrics.vectorized_spectral_angle(X, centers)

        # Normalize both to comparable ranges
        sid_norm = sid / (1 + sid)  # [0, 1)
        sam_norm = sam / (np.pi / 2)  # [0, 1]

        # Weighted combination
        return alpha * sid_norm + (1 - alpha) * sam_norm

    @staticmethod
    def vectorized_spectral_angle(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Vectorized spectral angle (SAM) distance."""
        # Cosine similarity with zero-norm guards
        X_norm = np.linalg.norm(X, axis=1, keepdims=True)
        centers_norm = np.linalg.norm(centers, axis=1, keepdims=True)

        X_norm = np.where(X_norm == 0, 1, X_norm)
        centers_norm = np.where(centers_norm == 0, 1, centers_norm)

        similarity = np.dot(X, centers.T) / (X_norm * centers_norm.T)
        similarity = np.clip(similarity, -1, 1)

        # Convert to an angle in radians
        return np.arccos(similarity)
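

# Illustrative sketch (editorial, not part of the original module): every
# vectorized metric above shares one contract -- inputs of shape
# (n_samples, n_features) and (n_clusters, n_features), output of shape
# (n_samples, n_clusters) -- so the metrics are interchangeable on toy data:
#
#   X = np.random.default_rng(0).random((100, 50))
#   centers = np.random.default_rng(1).random((4, 50))
#   d_cos = HyperspectralDistanceMetrics.vectorized_cosine_distance(X, centers)
#   d_sam = HyperspectralDistanceMetrics.vectorized_spectral_angle(X, centers)
#   assert d_cos.shape == d_sam.shape == (100, 4)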


class HyperspectralClassification:
    """Supervised classifier for hyperspectral images."""

    def __init__(self, random_state: int = 42, distance_params: Optional[Dict[str, Dict]] = None):
        self.random_state = random_state
        self.distance_metrics = HyperspectralDistanceMetrics()
        self.reference_spectra_ = None
        self.labels_ = None
        self.class_names_ = None

        # Default hyperparameters per distance metric
        self.default_distance_params = {
            'cosine': {},
            'euclidean': {},
            'correlation': {},
            'information_divergence': {},
            'sam': {},
            'jm_distance': {'alpha': 0.5},  # Bhattacharyya coefficient weight
            'sid_sa': {'alpha': 0.5}  # weighting between SID and SAM
        }

        # Merge in user-provided parameters
        self.distance_params = self.default_distance_params.copy()
        if distance_params:
            self.distance_params.update(distance_params)

        # Vectorized distance functions, including the newer metrics
        self.vectorized_distance_functions = {
            'cosine': self.distance_metrics.vectorized_cosine_distance,
            'euclidean': self.distance_metrics.vectorized_euclidean_distance,
            'correlation': self.distance_metrics.vectorized_correlation_distance,
            'information_divergence': self.distance_metrics.vectorized_information_divergence,
            'sam': self.distance_metrics.vectorized_spectral_angle,
            'jm_distance': self.distance_metrics.vectorized_jeffreys_matusita_distance,
            'sid_sa': self.distance_metrics.vectorized_sid_sa_combined
        }
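
    # Example (illustrative): user-supplied hyperparameters are merged over
    # the defaults at construction time:
    #
    #   clf = HyperspectralClassification(
    #       distance_params={'sid_sa': {'alpha': 0.3}})
    #   clf.distance_params['sid_sa']   # -> {'alpha': 0.3}
    #   clf.distance_params['cosine']   # -> {} (default kept)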

    def fit_predict_with_references(self, X: np.ndarray,
                                    reference_spectra: Dict[str, np.ndarray],
                                    method: str = 'euclidean') -> np.ndarray:
        """
        Supervised classification against a set of reference spectra.

        Parameters:
            X: hyperspectral data, shape (n_samples, n_bands) or (rows, cols, n_bands)
            reference_spectra: dict mapping class name to a spectrum of shape (bands,)
            method: distance metric name

        Returns:
            Label array: (rows, cols) for cube input, else (n_samples,);
            invalid pixels are marked with -1
        """
        if method not in self.vectorized_distance_functions:
            raise ValueError(f"Unsupported distance metric: {method}")

        if not reference_spectra:
            raise ValueError("Reference spectra must be provided")

        print(f"DEBUG fit_predict_with_references: X.shape = {X.shape}")
        print(f"DEBUG fit_predict_with_references: reference_spectra keys = {list(reference_spectra.keys())}")
        for name, spectrum in reference_spectra.items():
            print(f"DEBUG fit_predict_with_references: {name}.shape = {spectrum.shape}")

        # Handle 3D input (image cubes)
        if X.ndim == 3:
            rows, cols, bands = X.shape
            X_reshaped = X.reshape(-1, bands)
            # Drop pixels containing NaN or infinite values
            valid_mask = np.isfinite(X_reshaped).all(axis=1)
            X_valid = X_reshaped[valid_mask]
        else:
            X_reshaped = X
            X_valid = X_reshaped
            rows, cols = None, None

        if len(X_valid) == 0:
            raise ValueError("No valid spectral data")

        # Build the reference array as a 2D (n_classes, n_bands) matrix
        class_names = list(reference_spectra.keys())

        # Make sure every reference spectrum is one-dimensional first
        cleaned_reference_spectra = []
        for name in class_names:
            spectrum = reference_spectra[name]
            if spectrum.ndim > 1:
                spectrum = spectrum.flatten()
            cleaned_reference_spectra.append(spectrum)

        reference_array = np.array(cleaned_reference_spectra)  # (n_classes, n_bands)
        print(f"DEBUG fit_predict_with_references: reference_array.shape = {reference_array.shape}")

        # Classify
        labels = self._classify_pixels(X_valid, reference_array, method)

        # For 3D input, reshape back to the image grid; invalid pixels get -1
        if X.ndim == 3:
            result = np.full((rows * cols,), -1, dtype=int)
            result[valid_mask] = labels
            result = result.reshape(rows, cols)
        else:
            result = labels

        self.reference_spectra_ = reference_spectra
        self.class_names_ = class_names
        self.labels_ = result
        return result
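
    # Usage sketch (illustrative; names and shapes are made up):
    #
    #   clf = HyperspectralClassification()
    #   refs = {'grass': np.ones(120), 'soil': np.linspace(1, 2, 120)}
    #   labels = clf.fit_predict_with_references(cube, refs, method='sam')
    #   # for a (rows, cols, 120) cube: labels.shape == (rows, cols),
    #   # with -1 marking pixels that contained NaN/inf values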

    def _classify_pixels(self, X: np.ndarray, reference_spectra: np.ndarray, method: str) -> np.ndarray:
        """
        Assign each pixel to the nearest reference spectrum.

        Parameters:
            X: pixel spectra, shape (n_samples, n_bands)
            reference_spectra: reference spectra, shape (n_classes, n_bands)
            method: distance metric name

        Returns:
            Class labels, shape (n_samples,)
        """
        print(f"DEBUG: X.shape = {X.shape}")
        print(f"DEBUG: reference_spectra.shape = {reference_spectra.shape}")
        print(f"DEBUG: method = {method}")

        # Make sure a vectorized distance function is available
        if method not in self.vectorized_distance_functions or self.vectorized_distance_functions[method] is None:
            raise ValueError(f"Distance metric has no vectorized implementation: {method}")

        # Vectorized distance computation: (n_samples, n_classes)
        distances = self.vectorized_distance_functions[method](X, reference_spectra)

        # Assign each pixel to the closest reference spectrum
        labels = np.argmin(distances, axis=1)

        return labels
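
    # The argmin step above is minimum-distance classification. With a
    # distance matrix of shape (n_samples, n_classes), for example:
    #
    #   distances = np.array([[0.2, 0.5],
    #                         [0.9, 0.1]])
    #   np.argmin(distances, axis=1)  # -> array([0, 1])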

    def _classify_with_complex_distance(self, X: np.ndarray, reference_spectra: np.ndarray, method: str) -> np.ndarray:
        """
        Classification with the non-vectorized (per-pixel) distance metrics.
        """
        print(f"Classifying with the {method} distance metric...")

        n_samples = X.shape[0]
        n_classes = reference_spectra.shape[0]
        labels = np.zeros(n_samples, dtype=int)

        # Compare every pixel against every reference spectrum
        for i in range(n_samples):
            pixel = X[i]
            min_distance = float('inf')
            best_class = 0

            for j in range(n_classes):
                reference = reference_spectra[j]

                if method == 'jm_distance':
                    distance = self.distance_metrics.jeffreys_matusita_distance(pixel, reference)
                elif method == 'sid_sa':
                    alpha = self.distance_params[method].get('alpha', 0.5)
                    distance = self.distance_metrics.sid_sa_combined(pixel, reference, alpha)
                else:
                    # Fall back to Euclidean distance
                    distance = self.distance_metrics.euclidean_distance(pixel, reference)

                if distance < min_distance:
                    min_distance = distance
                    best_class = j

            labels[i] = best_class

        return labels

    def fit_predict_all_methods(self, X: np.ndarray, reference_spectra: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run the classification with every available distance metric."""
        results = {}
        methods = list(self.vectorized_distance_functions.keys())

        print("Classifying with each distance metric...")
        for method in methods:
            print(f"Running {method}...")
            try:
                results[method] = self.fit_predict_with_references(X, reference_spectra, method)
                print(f"✓ {method} finished")
            except Exception as e:
                print(f"✗ {method} failed: {e}")
                results[method] = None

        return results
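
    # Sketch (illustrative): comparing two metrics pixel by pixel after
    # running everything:
    #
    #   results = clf.fit_predict_all_methods(cube, refs)
    #   n_diff = np.sum(results['sam'] != results['euclidean'])
    #   print(f"SAM vs Euclidean disagree on {n_diff} pixels")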


class HyperspectralImageProcessor:
    """Hyperspectral image I/O and processing."""

    def __init__(self):
        self.data = None
        self.header = None
        self.wavelengths = None

    def load_image(self, hdr_path: str) -> Tuple[np.ndarray, Dict]:
        """Load an ENVI-format hyperspectral image."""
        try:
            # Open the ENVI file
            img = envi.open(hdr_path)
            data = img.load()
            header = dict(img.metadata)

            # Extract wavelength information if present
            if 'wavelength' in header:
                wavelengths = np.array([float(w) for w in header['wavelength']])
            else:
                wavelengths = None

            self.data = data
            self.header = header
            self.wavelengths = wavelengths

            print(f"Image loaded: shape={data.shape}, dtype={data.dtype}")
            if wavelengths is not None:
                print(f"Wavelength range: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")

            return data, header

        except Exception as e:
            raise IOError(f"Failed to load image: {e}")

    def parse_roi_xml(self, xml_path: str) -> Dict[str, List[Tuple[float, float]]]:
        """
        Parse an ENVI ROI XML file and extract the vertex coordinates of each region.

        Parameters:
            xml_path: path to the XML file

        Returns:
            Dict mapping ROI name to a list of coordinates [(x1, y1), (x2, y2), ...]
        """
        try:
            tree = ET.parse(xml_path)
            root = tree.getroot()

            roi_coordinates = {}

            for region in root.findall('Region'):
                name = region.get('name')
                coordinates_text = None

                # Locate the Coordinates element
                coord_elem = region.find('.//Coordinates')
                if coord_elem is not None and coord_elem.text:
                    coordinates_text = coord_elem.text.strip()

                if coordinates_text:
                    # Parse the coordinate string; values come in (x y x y ...) pairs
                    coords = []
                    parts = coordinates_text.split()
                    for i in range(0, len(parts), 2):
                        if i + 1 < len(parts):
                            x = float(parts[i])
                            y = float(parts[i + 1])
                            coords.append((x, y))

                    roi_coordinates[name] = coords

            return roi_coordinates

        except Exception as e:
            raise IOError(f"Failed to parse XML file: {e}")
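
    # The parser above assumes roughly this document shape (illustrative,
    # reconstructed from the parsing logic rather than from an ENVI spec):
    #
    #   <RegionsOfInterest>
    #     <Region name="grass">
    #       <GeometryDef>
    #         <Coordinates>10.0 20.0 30.0 20.0 30.0 40.0</Coordinates>
    #       </GeometryDef>
    #     </Region>
    #   </RegionsOfInterest>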

    def extract_reference_spectra(self, data: np.ndarray,
                                  roi_coordinates: Dict[str, List[Tuple[float, float]]]) -> Dict[str, np.ndarray]:
        """
        Extract reference spectra from the image.

        Parameters:
            data: hyperspectral image data, shape (rows, cols, bands)
            roi_coordinates: dict of ROI coordinates

        Returns:
            Dict mapping ROI name to its mean spectrum, shape (bands,)
        """
        reference_spectra = {}
        rows, cols, bands = data.shape

        for roi_name, coords in roi_coordinates.items():
            # Build the ROI polygon
            polygon = Polygon(coords)

            # Collect every pixel that falls inside the polygon
            pixels_in_roi = []
            for i in range(rows):
                for j in range(cols):
                    point = Point(j, i)  # in ENVI coordinates, x is the column and y is the row
                    if polygon.contains(point):
                        pixels_in_roi.append((i, j))

            if len(pixels_in_roi) == 0:
                print(f"Warning: no pixels found in ROI '{roi_name}'")
                continue

            # Extract the pixel spectra
            spectra = []
            for i, j in pixels_in_roi:
                # Make sure the extracted spectrum is one-dimensional
                spectrum = data[i, j, :]
                if spectrum.ndim > 1:
                    spectrum = spectrum.flatten()

                if np.isfinite(spectrum).all() and spectrum.size == bands:  # keep valid spectra only
                    spectra.append(spectrum)

            if len(spectra) == 0:
                print(f"Warning: no valid spectra in ROI '{roi_name}'")
                continue

            # Mean spectrum, flattened to one dimension
            avg_spectrum = np.mean(spectra, axis=0)
            if avg_spectrum.ndim > 1:
                avg_spectrum = avg_spectrum.flatten()

            reference_spectra[roi_name] = avg_spectrum

            print(f"ROI '{roi_name}': {len(spectra)} valid pixels, spectrum shape {avg_spectrum.shape}")

        return reference_spectra
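
    # Performance note (editorial sketch, not the original implementation):
    # the double loop above tests every pixel in the image against every
    # polygon. Restricting the scan to the polygon's bounding box returns the
    # same pixels with far fewer containment tests:
    #
    #   min_x, min_y, max_x, max_y = polygon.bounds
    #   for i in range(max(0, int(min_y)), min(rows, int(max_y) + 1)):
    #       for j in range(max(0, int(min_x)), min(cols, int(max_x) + 1)):
    #           if polygon.contains(Point(j, i)):
    #               pixels_in_roi.append((i, j))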
    def save_single_band_image(self, data: np.ndarray, output_path: str,
                               method_name: str = "", original_header: Dict = None) -> None:
        """Save a single-band classification result as .dat and .hdr files."""
        try:
            # Make sure the output directory exists ('' means the current dir
            # when output_path is a bare filename)
            os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

            # Force a .dat extension on the output file
            if not output_path.lower().endswith('.dat'):
                output_path = output_path.rsplit('.', 1)[0] + '.dat'

            # Cast to a suitable dtype (class labels are stored as int32)
            if data.dtype != np.int32:
                data_to_save = data.astype(np.int32)
            else:
                data_to_save = data

            # Write the raw binary .dat file
            data_to_save.tofile(output_path)

            # Write the matching .hdr file
            hdr_path = output_path.rsplit('.', 1)[0] + '.hdr'
            self._create_cluster_hdr_file(hdr_path, data_to_save.shape, method_name, original_header)

            print("Result saved to:")
            print(f"  data file: {output_path}")
            print(f"  header file: {hdr_path}")
            print(f"dtype: {data_to_save.dtype}, shape: {data_to_save.shape}")

        except Exception as e:
            raise IOError(f"Failed to save image: {e}")

    def _create_cluster_hdr_file(self, hdr_path: str, data_shape: Tuple,
                                 method_name: str, original_header: Dict = None) -> None:
        """Write an ENVI header file for the single-band result."""
        try:
            # Reuse key fields from the original header when available
            if original_header is not None:
                samples = original_header.get('samples', data_shape[1] if len(data_shape) > 1 else data_shape[0])
                lines = original_header.get('lines', data_shape[0] if len(data_shape) > 1 else 1)
                bands = 1  # the result is a single band
                interleave = 'bip'
                data_type = 3  # ENVI data types: 3=int32, 4=float32
                byte_order = original_header.get('byte order', 0)
                wavelength_units = original_header.get('wavelength units', 'nm')
            else:
                # Defaults for a 2D result
                if len(data_shape) == 2:
                    lines, samples = data_shape
                else:
                    lines = data_shape[0]
                    samples = data_shape[1] if len(data_shape) > 1 else data_shape[0]
                bands = 1
                interleave = 'bsq'
                data_type = 3  # int32
                byte_order = 0
                wavelength_units = 'Unknown'

            # Write the .hdr file
            with open(hdr_path, 'w') as f:
                f.write("ENVI\n")
                f.write("description = {\n")
                f.write(f"  Classification result - {method_name}\n")
                f.write("  single-band label image\n")
                f.write("}\n")
                f.write(f"samples = {samples}\n")
                f.write(f"lines = {lines}\n")
                f.write(f"bands = {bands}\n")
                f.write("header offset = 0\n")
                f.write("file type = ENVI Standard\n")
                f.write(f"data type = {data_type}\n")
                f.write(f"interleave = {interleave}\n")
                f.write(f"byte order = {byte_order}\n")

                # Write a dummy wavelength block when the source image had one
                if original_header and 'wavelength' in original_header:
                    f.write("wavelength = {\n")
                    f.write("  class labels\n")
                    f.write("}\n")
                    f.write(f"wavelength units = {wavelength_units}\n")

                # Band name annotation
                f.write("band names = {\n")
                f.write(f"  classification_{method_name}\n")
                f.write("}\n")

            print(f"Header file created: {hdr_path}")

        except Exception as e:
            print(f"Failed to create header file: {e}")

    def visualize_clusters(self, cluster_results: Dict[str, np.ndarray],
                           save_path: Optional[str] = None) -> None:
        """Visualize the classification results."""
        n_methods = len(cluster_results)
        fig, axes = plt.subplots(2, (n_methods + 1) // 2, figsize=(15, 10))
        axes = axes.flatten()

        i = -1  # keeps the cleanup loop below safe when there are no results
        for i, (method, result) in enumerate(cluster_results.items()):
            if result is not None and i < len(axes):
                im = axes[i].imshow(result, cmap='tab10')
                axes[i].set_title(f'{method.replace("_", " ").title()}')
                axes[i].axis('off')
                plt.colorbar(im, ax=axes[i], shrink=0.8)

        # Hide any unused subplots
        for j in range(i + 1, len(axes)):
            axes[j].axis('off')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Visualization saved to: {save_path}")

        plt.show()


def main():
    """Entry point: supervised classification of a hyperspectral image."""
    import argparse

    parser = argparse.ArgumentParser(description='Supervised hyperspectral image classification')
    parser.add_argument('input_file', help='input ENVI-format hyperspectral image (.hdr)')
    parser.add_argument('xml_file', help='path to the ENVI ROI XML file')
    parser.add_argument('--output_dir', '-o', default='output',
                        help='output directory (default: output)')
    parser.add_argument('--method', '-m', default='all',
                        choices=['all', 'euclidean', 'cosine', 'correlation',
                                 'information_divergence', 'sam', 'jm_distance', 'sid_sa'],
                        help='distance metric for classification (default: all, i.e. every method)')
    parser.add_argument('--distance_params', '-p', type=str, default=None,
                        help='metric hyperparameters as a JSON string, e.g.: {"jm_distance": {"alpha": 0.7}}')
    parser.add_argument('--visualize', '-v', action='store_true',
                        help='generate visualization output')

    args = parser.parse_args()

    # Parse the distance hyperparameters
    distance_params = None
    if args.distance_params:
        import json
        try:
            distance_params = json.loads(args.distance_params)
        except json.JSONDecodeError:
            print(f"Warning: could not parse distance parameters '{args.distance_params}'; using defaults")
            distance_params = None

    try:
        # Run the classification pipeline
        return run_hsi_classification(
            input_file=args.input_file,
            xml_file=args.xml_file,
            output_dir=args.output_dir,
            method=args.method,
            distance_params=distance_params,
            visualize=args.visualize
        )

    except Exception as e:
        print(f"✗ Processing failed: {e}")
        import traceback
        traceback.print_exc()
        return 1
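

# Example CLI invocations (illustrative; the script name and data paths are
# placeholders):
#
#   python classification.py image.hdr roi.xml -o results -m sam -v
#   python classification.py image.hdr roi.xml -m jm_distance \
#       -p '{"jm_distance": {"alpha": 0.7}}'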


def run_hsi_classification(input_file, xml_file, output_dir='output', method='all',
                           distance_params=None, visualize=False):
    """
    Run the supervised hyperspectral classification pipeline.

    Parameters:
        input_file: input ENVI-format hyperspectral image (.hdr)
        xml_file: path to the ENVI ROI XML file
        output_dir: output directory (default: output)
        method: distance metric ('all' or a specific metric name; default: 'all')
        distance_params: dict of metric hyperparameters (default: None)
        visualize: whether to generate visualization output (default: False)

    Returns:
        0 on success, 1 on failure
    """
    try:
        # Create the output directory
        os.makedirs(output_dir, exist_ok=True)

        # Initialize the processor
        processor = HyperspectralImageProcessor()

        # Load the image
        print(f"Loading hyperspectral image: {input_file}")
        data, header = processor.load_image(input_file)

        # Parse the XML file and extract the ROI definitions
        print(f"\nParsing ROI XML file: {xml_file}")
        roi_coordinates = processor.parse_roi_xml(xml_file)

        if not roi_coordinates:
            raise ValueError("No valid ROI regions found in the XML file")

        print(f"Found {len(roi_coordinates)} ROI regions")
        for roi_name, coords in roi_coordinates.items():
            print(f"  {roi_name}: {len(coords)} vertices")

        # Extract the reference spectra from the image
        print(f"\nImage data shape: {data.shape}")
        print("Extracting reference spectra...")
        reference_spectra = processor.extract_reference_spectra(data, roi_coordinates)

        if not reference_spectra:
            raise ValueError("Could not extract valid reference spectra from the ROI regions")

        print(f"Extracted {len(reference_spectra)} reference spectra:")
        for roi_name, spectrum in reference_spectra.items():
            print(f"  {roi_name}: shape {spectrum.shape}, dtype {spectrum.dtype}")

        # Initialize the classifier
        classifier = HyperspectralClassification(distance_params=distance_params)

        # Classify with the requested method(s)
        print(f"\nStarting classification (classes: {len(reference_spectra)}, method: {method})...")
        if method == 'all':
            results = classifier.fit_predict_all_methods(data, reference_spectra)
        else:
            # Single named method
            result = classifier.fit_predict_with_references(data, reference_spectra, method)
            results = {method: result}

        # Save the results
        print("\nSaving classification results...")
        saved_files = []
        for method, result in results.items():
            if result is not None:
                output_file = os.path.join(output_dir, f'classification_{method}.dat')
                processor.save_single_band_image(result, output_file, method, header)
                saved_files.append((method, output_file))

        # Report statistics
        print("\n=== Classification statistics ===")
        for method, result in results.items():
            if result is not None:
                unique_labels = np.unique(result[result >= 0])
                n_classes_found = len(unique_labels)
                print(f"\n{method}: {n_classes_found} classes found")

                # Pixel count per class
                for class_idx, class_name in enumerate(classifier.class_names_):
                    pixel_count = np.sum(result == class_idx)
                    print(f"  {class_name}: {pixel_count} pixels")

        # Optional visualization of all results (the overview file name is a
        # choice made here; adjust as needed)
        if visualize:
            processor.visualize_clusters(
                results, save_path=os.path.join(output_dir, 'classification_overview.png'))

        print("\n✓ Classification finished!")
        print(f"Output directory: {output_dir}")
        print(f"Files written: {len(saved_files)}")
        return 0

    except Exception as e:
        print(f"✗ Processing failed: {e}")
        import traceback
        traceback.print_exc()
        return 1


# Direct invocation example
if __name__ == '__main__':
    # Example 1: run every classification method
    result = run_hsi_classification(
        input_file=r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip.hdr",  # your hyperspectral image
        xml_file=r"E:\code\spectronon\single_classsfication\data\roi.xml",  # ROI XML file
        output_dir=r'E:\code\spectronon\single_classsfication\tsst',  # output directory
        method='all',  # use every distance metric
        distance_params=None,  # default hyperparameters
        visualize=False  # no visualization
    )

    # Example 2: a specific method with custom hyperparameters
    # result = run_hsi_classification(
    #     input_file='your_image.hdr',
    #     xml_file='roi_regions.xml',
    #     method='jm_distance',  # J-M distance only
    #     distance_params={'jm_distance': {'alpha': 0.7}},
    #     output_dir='output_jm'
    # )

    # Example 3: compare all methods
    # result = run_hsi_classification(
    #     input_file='data/image.hdr',
    #     xml_file='data/roi.xml',
    #     method='all'  # every distance metric
    # )

    # Exit status from the return value
    if result == 0:
        print("Program finished successfully!")
    else:
        print("Program failed!")