# Source listing metadata (from the code viewer): Python, 1108 lines, 40 KiB.
import ast
import os
import re
import sys
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Dict, Any, Union

import numpy as np
import pandas as pd

# Optional dependency: the `spectral` package is only needed for ENVI images.
# SPECTRAL_AVAILABLE records whether the import succeeded so that loaders can
# check it before touching the library.
try:
    import spectral
except ImportError:
    SPECTRAL_AVAILABLE = False
    print("警告: spectral库不可用,将使用内置方法读取数据")
else:
    SPECTRAL_AVAILABLE = True

# Silence every warning category process-wide (module import side effect).
warnings.filterwarnings('ignore')
@dataclass
class DataConfig:
    """Settings that describe the input data file."""

    # Path of the data file to load; empty means "not specified".
    data_file_path: str = ""
    # Expected format: 'csv', 'envi' or 'auto'.
    data_format: str = "auto"
    # Index of the wavelength column for CSV input.
    wavelength_column: Optional[int] = None
    # Name of the label column for CSV input.
    label_column: Optional[str] = None
    # First spectral column, given as a column name or a numeric index string.
    spectral_start: Optional[str] = None
    # Last spectral column, given as a column name or a numeric index string.
    spectral_end: Optional[str] = None
@dataclass
class IndexConfig:
    """Settings that select the index definitions and which indices to compute."""

    # CSV file holding the band definitions.
    spectral_index_csv: str = r"E:\code\spectronon\spectral_index.csv"
    # CSV file holding the spectral-index formulas.
    formula_csv: str = r"E:\code\spectronon\famula.csv"
    # Names of the indices to compute; None means all of them.
    indices_to_calculate: Optional[list] = None
@dataclass
class OutputConfig:
    """Settings that control where and how results are written."""

    # Directory used for generated output.
    output_dir: str = "results"
    # Whether each index should be saved on its own.
    save_individual_indices: bool = True
    # Whether the combined result set should be saved.
    save_combined_results: bool = True
    # Output format: 'csv', 'excel' or 'both'.
    output_format: str = "csv"
    # Whether a PNG overview of the indices should be saved.
    save_png_visualization: bool = False
    # File name of the PNG overview.
    png_filename: str = "spectral_indices_overview.png"
@dataclass
class SpectralIndexConfig:
    """Complete configuration for spectral-index calculation.

    Bundles the data, index and output settings into one object intended as
    the standard interface for a GUI.  The configuration is validated as soon
    as it is constructed.
    """

    data: DataConfig = field(default_factory=DataConfig)
    indices: IndexConfig = field(default_factory=IndexConfig)
    output: OutputConfig = field(default_factory=OutputConfig)

    def __post_init__(self):
        """Validate the freshly built configuration."""
        self._validate_parameters()

    def _validate_parameters(self):
        """Check the format fields and warn about missing definition files.

        Raises:
            ValueError: If the data format or output format is not one of the
                supported values.
        """
        # Data settings.
        if self.data.data_format not in ('csv', 'envi', 'auto'):
            raise ValueError("Data format must be 'csv', 'envi', or 'auto'")

        # Output settings.
        if self.output.output_format not in ('csv', 'excel', 'both'):
            raise ValueError("Output format must be 'csv', 'excel', or 'both'")

        # Missing definition files only produce a warning, not an error.
        definition_files = (
            ("Spectral index CSV", self.indices.spectral_index_csv),
            ("Formula CSV", self.indices.formula_csv),
        )
        for description, path in definition_files:
            if not os.path.exists(path):
                print(f"Warning: {description} file not found: {path}")

    @classmethod
    def create_default(cls) -> 'SpectralIndexConfig':
        """Return a configuration with every setting at its default."""
        return cls()

    @classmethod
    def create_quick_analysis(cls, data_file_path: str, indices_to_calculate: Optional[list] = None) -> 'SpectralIndexConfig':
        """Return a configuration for a quick analysis of one data file.

        Args:
            data_file_path: Path of the data file to analyse.
            indices_to_calculate: Optional list of index names; when empty or
                None the default (all indices) is kept.

        Returns:
            A configuration that does not save individual indices.
        """
        quick = cls()
        quick.data.data_file_path = data_file_path
        if indices_to_calculate:
            quick.indices.indices_to_calculate = indices_to_calculate
        # Quick analyses skip the per-index output.
        quick.output.save_individual_indices = False
        return quick
class HyperspectralIndexCalculator:
    """Hyperspectral index calculator with a config-driven interface for GUI use.

    Band definitions and index formulas are read from two CSV files when the
    calculator is constructed.  Data is then loaded from a CSV table or an
    ENVI image, after which spectral indices can be computed, saved and
    visualised.
    """

    # Fallback keys tried when a band name is missing from the band table.
    _STANDARD_BAND_ALIASES = {
        'nir': 'w860', 'red': 'w680', 'green': 'w550',
        'blue': 'w470', 'swir1': 'w1650', 'swir2': 'w2200',
        'tm1': 'w485', 'tm2': 'w569', 'tm3': 'w660',
        'tm4': 'w833', 'tm5': 'w1676', 'tm7': 'w2223',
    }

    # Tokens inside a formula that are treated as band references.
    _BAND_PATTERN = re.compile(r'\b(w\d+|tm\d+|nir|red|green|blue|swir1|swir2|thermal)\b')

    # File suffixes routed to the spectral-library loader.
    _ENVI_SUFFIXES = ('.bil', '.bsq', '.bip', '.dat', '.hdr')

    def __init__(self, config: Optional['SpectralIndexConfig'] = None,
                 spectral_index_csv=None, formula_csv=None):
        """Create the calculator and load band definitions and formulas.

        Parameters:
            config (SpectralIndexConfig, optional): Configuration object; a
                default configuration is used when None.
            spectral_index_csv (str, optional): Band-definition file path
                (legacy argument, only used when ``config`` is None).
            formula_csv (str, optional): Formula file path (legacy argument,
                only used when ``config`` is None).
        """
        # Backward compatibility: build a config from the legacy path arguments.
        if config is None and (spectral_index_csv is not None or formula_csv is not None):
            config = SpectralIndexConfig()
            if spectral_index_csv is not None:
                config.indices.spectral_index_csv = spectral_index_csv
            if formula_csv is not None:
                config.indices.formula_csv = formula_csv

        self.config = config if config is not None else SpectralIndexConfig()
        self._validate_config()

        # Band definitions and index formulas.
        self.band_info = self.load_band_info(self.config.indices.spectral_index_csv)
        self.formulas = self.load_formulas(self.config.indices.formula_csv)

        # State describing the currently loaded data set.
        self.data = None
        self.wavelengths = None
        self.data_shape = None
        self.band_mapping = None
        # Initialised here so the attribute exists before any CSV is loaded.
        self.labels = None

    def update_config(self, config: 'SpectralIndexConfig'):
        """Replace the configuration (hook for dynamic GUI configuration).

        Only the configuration object is swapped and validated; band
        definitions and formulas that were already loaded are not re-read.

        Parameters:
            config (SpectralIndexConfig): The new configuration object.
        """
        self.config = config
        self._validate_config()

    def _validate_config(self):
        """Validate ``self.config``.

        Raises:
            ValueError: If the configuration reports invalid parameters.
        """
        try:
            self.config._validate_parameters()
        except ValueError as e:
            raise ValueError(f"Configuration validation failed: {e}") from e

    def load_band_info(self, csv_path):
        """Load band definitions from a CSV with name/min/center/max columns.

        Returns:
            dict: Band name -> {'min', 'center', 'max'} as floats.  The common
            names nir/red/green/blue/swir1/swir2 are also available in upper case.
        """
        df = pd.read_csv(csv_path)
        band_info = {}

        for _, row in df.iterrows():
            name = row['name']
            band_info[name] = {
                'min': float(row['min']),
                'center': float(row['center']),
                'max': float(row['max']),
            }

            # Upper-case alias shares the same entry as the lower-case name.
            if name in ('nir', 'red', 'green', 'blue', 'swir1', 'swir2'):
                band_info[name.upper()] = band_info[name]

        return band_info

    def load_formulas(self, csv_path):
        """Load index formulas from a CSV with Index/Name/Type/Equation/Bands columns.

        The ``Type`` and ``Bands`` cells hold Python list literals stored as
        text.  They are parsed with ``ast.literal_eval`` so that a formula
        file cannot execute arbitrary code (the cells were previously passed
        to ``eval``).

        Returns:
            dict: Index key -> {'name', 'type', 'equation', 'bands'}.
        """
        df = pd.read_csv(csv_path)
        formulas = {}

        for _, row in df.iterrows():
            formulas[row['Index']] = {
                'name': row['Name'],
                'type': ast.literal_eval(row['Type']),
                'equation': row['Equation'],
                'bands': ast.literal_eval(row['Bands']),
            }
        return formulas

    def load_hyperspectral_data(self, file_path=None):
        """Load a hyperspectral data file, dispatching on its suffix.

        ``.csv`` files go to :meth:`load_csv_data`; ``.bil/.bsq/.bip/.dat/.hdr``
        files go to :meth:`load_with_spectral` when the spectral library is
        available.

        Args:
            file_path: Path of the file; taken from the configuration when None.

        Raises:
            ValueError: For any other suffix, or when the spectral library is
                not available for an ENVI file.
        """
        if file_path is None:
            file_path = self.config.data.data_file_path

        file_path = Path(file_path)
        suffix = file_path.suffix.lower()

        if suffix == '.csv':
            return self.load_csv_data(file_path)
        if SPECTRAL_AVAILABLE and suffix in self._ENVI_SUFFIXES:
            return self.load_with_spectral(file_path)
        raise ValueError(f"不支持的文件格式: {suffix}。请使用CSV或ENVI格式文件,并确保已安装spectral库。")

    def load_csv_data(self, file_path):
        """Load hyperspectral data from a CSV table (one sample per row).

        The spectral columns are either the configured start..end range or,
        without a range, every column except the configured label column.
        Column names are used as wavelengths when they are numeric; otherwise
        the wavelengths default to 0..n-1.

        NOTE: a range boundary made only of digits is treated as a column
        position (start inclusive, end exclusive); any other boundary is a
        column name (both ends inclusive).

        Raises:
            ValueError: If the configured column range cannot be resolved.
        """
        df = pd.read_csv(file_path)

        label_column = self.config.data.label_column
        spectral_start = self.config.data.spectral_start
        spectral_end = self.config.data.spectral_end
        has_label = bool(label_column) and label_column in df.columns

        if spectral_start is not None and spectral_end is not None:
            # str() lets integer positions work as well as strings; integers
            # used to fail with AttributeError on .isdigit().
            start_key, end_key = str(spectral_start), str(spectral_end)
            try:
                start_idx = int(start_key) if start_key.isdigit() else df.columns.get_loc(start_key)
                end_idx = int(end_key) if end_key.isdigit() else df.columns.get_loc(end_key) + 1
                spectral_columns = df.columns[start_idx:end_idx]
            except (ValueError, KeyError) as err:
                raise ValueError(f"无法找到指定的光谱列范围: {spectral_start} 到 {spectral_end}") from err
        elif has_label:
            # Automatic detection: everything except the label column.
            spectral_columns = df.columns.drop(label_column)
        else:
            # No label column: every column is spectral data.
            spectral_columns = df.columns

        spectral_data = df[spectral_columns]

        if has_label:
            self.labels = df[label_column].values
            print(f"检测到标签列 '{label_column}',包含 {len(self.labels)} 个样本")
        else:
            self.labels = None

        try:
            self.wavelengths = spectral_columns.astype(float).values
            print("成功将列名转换为波长")
        except (ValueError, TypeError):
            # Non-numeric column names: fall back to positional wavelengths.
            self.wavelengths = np.arange(len(spectral_columns), dtype=float)
            print(f"列名不是数字波长,使用默认波长索引 0-{len(spectral_columns)-1}")

        self.data = spectral_data.values
        self.data_shape = self.data.shape

        print(f"CSV数据加载成功: {self.data_shape}")
        print(f"波段数量: {len(self.wavelengths)}")
        # Guard against an empty column selection before indexing the ends.
        if len(self.wavelengths) > 0:
            print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")

        return self.data

    def load_with_spectral(self, file_path):
        """Load an ENVI image (.bil/.bsq/.bip/.dat plus .hdr) with the spectral library.

        The header and the data file are located from whichever of the two
        was given.  When both are found they are opened explicitly as a pair;
        otherwise the given path is handed to ``spectral.open_image`` as
        before.  Wavelengths come from the header metadata, falling back to
        0..n-1 when absent or unparsable.

        Returns:
            The loaded image array, as returned by ``img.load()``.
        """
        source_path = Path(file_path)
        header_path = None

        if source_path.suffix.lower() == '.hdr':
            # Header given: look for the matching data file.
            header_path = source_path
            data_path = source_path.with_suffix('')
            if not data_path.exists():
                for ext in ('.dat', '.bil', '.bsq', '.bip'):
                    candidate = source_path.with_suffix(ext)
                    if candidate.exists():
                        data_path = candidate
                        break
        else:
            # Data file given: look for "name.hdr" or "name.ext.hdr".
            data_path = source_path
            for candidate in (source_path.with_suffix('.hdr'), Path(str(source_path) + '.hdr')):
                if candidate.exists():
                    header_path = candidate
                    break

        print(f"使用spectral库加载: {data_path}")

        if header_path is not None and data_path.exists():
            from spectral.io import envi
            img = envi.open(str(header_path), str(data_path))
        else:
            # Let the library resolve the file on its own.
            img = spectral.open_image(str(source_path))

        # Read every band into memory.
        self.data = img.load()

        if hasattr(img, 'metadata') and 'wavelength' in img.metadata:
            raw_wavelengths = img.metadata['wavelength']
            try:
                if isinstance(raw_wavelengths, str):
                    # Text form: strip brackets, split on commas or whitespace.
                    text = raw_wavelengths.strip('{}[]')
                    tokens = text.split(',') if ',' in text else text.split()
                    self.wavelengths = np.array([float(t.strip()) for t in tokens if t.strip()])
                elif isinstance(raw_wavelengths, list):
                    self.wavelengths = np.array([float(w) for w in raw_wavelengths])
                else:
                    self.wavelengths = np.array(raw_wavelengths, dtype=float)

                print(f"spectral解析波长: {len(self.wavelengths)} 个波段")
                print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
            except (ValueError, TypeError, IndexError) as e:
                print(f"波长解析失败,使用默认值: {e}")
                self.wavelengths = np.arange(self.data.shape[2])
        else:
            self.wavelengths = np.arange(self.data.shape[2])
            print(f"警告: spectral未找到波长信息,使用默认值: 0-{self.data.shape[2]-1}")

        self.data_shape = self.data.shape

        print(f"spectral数据加载成功: {self.data_shape}")
        print(f"数据类型: {self.data.dtype}")
        print(f"值范围: [{self.data.min():.3f}, {self.data.max():.3f}]")

        return self.data

    def find_band_index(self, band_name, tolerance=5):
        """Return the index of the loaded band closest to a named band's centre.

        Parameters:
            band_name: Band name (e.g. 'w550', 'nir', 'red').
            tolerance: Distance in nm above which a warning is printed.

        Returns:
            Zero-based index into ``self.wavelengths``.

        Raises:
            ValueError: If the band is unknown or its [min, max] range is not
                covered by the loaded wavelengths.
        """
        if band_name not in self.band_info:
            # Try the standard alias; an alias missing from the table is
            # reported as an unknown band instead of a raw KeyError.
            alias = self._STANDARD_BAND_ALIASES.get(band_name)
            if alias is None or alias not in self.band_info:
                raise ValueError(f"未定义的波段名称: {band_name}")
            band_name = alias

        target_info = self.band_info[band_name]
        target_center = target_info['center']
        target_min = target_info['min']
        target_max = target_info['max']

        # min/max rather than first/last so unsorted wavelengths also work.
        data_min = np.min(self.wavelengths)
        data_max = np.max(self.wavelengths)

        if target_min < data_min or target_max > data_max:
            raise ValueError(
                f"波段 {band_name} 超出数据波长范围: "
                f"目标波段 [{target_min:.1f}-{target_max:.1f}nm], "
                f"数据范围 [{data_min:.1f}-{data_max:.1f}nm]"
            )

        diffs = np.abs(self.wavelengths - target_center)
        min_idx = np.argmin(diffs)

        if diffs[min_idx] > tolerance:
            print(
                f"警告: 波段 {band_name} (目标: {target_center}nm) 匹配到 {self.wavelengths[min_idx]:.1f}nm (差异: {diffs[min_idx]:.1f}nm)")

        return min_idx

    def calculate_index(self, index_name, output_type='image'):
        """Compute one spectral index on the loaded data.

        A formula may be followed by ';'-separated ``name = expression``
        definitions; these temporaries are evaluated first, in order, and are
        then available to the main formula.

        Parameters:
            index_name: Key of the index (e.g. 'NDVI', 'EVI').
            output_type: 'values' returns a flattened 1-D array; any other
                value returns the result in its computed shape.

        Raises:
            ValueError: If the index is unknown or the formula cannot be evaluated.
        """
        if index_name not in self.formulas:
            available = list(self.formulas.keys())
            raise ValueError(f"未找到指数: {index_name}。可用指数: {available}")

        formula_info = self.formulas[index_name]
        equation = formula_info['equation']
        required_bands = formula_info['bands']

        print(f"计算指数: {formula_info['name']} ({index_name})")
        print(f"所需波段: {required_bands}")
        print(f"公式: {equation}")

        # First segment is the main formula, the rest define temporaries.
        main_eq, *definitions = [part.strip() for part in equation.split(';')]
        local_vars = {}
        for definition in definitions:
            if '=' not in definition:
                continue
            var_name, expr = definition.split('=', 1)
            local_vars[var_name.strip()] = self.evaluate_expression(
                expr.strip(), required_bands, local_vars)

        index_result = self.evaluate_expression(main_eq, required_bands, local_vars)

        if output_type == 'values':
            return np.asarray(index_result).flatten()
        return index_result

    def evaluate_expression(self, expression, required_bands, local_vars):
        """Evaluate a formula string against the loaded band data.

        Parameters:
            expression: Formula text; '^' is accepted as the power operator.
            required_bands: Declared band list (not used for the lookup; band
                names are taken from the expression itself).
            local_vars: Previously computed temporaries, by name.

        Returns:
            The evaluated result; for arrays NaN becomes 0, +inf becomes 1 and
            -inf becomes -1.

        Raises:
            ValueError: If a band cannot be resolved or evaluation fails.
        """
        band_data = {}
        for band_var in set(self._BAND_PATTERN.findall(expression)):
            # A temporary with the same name takes precedence over the band.
            if band_var in local_vars:
                continue

            band_idx = self.find_band_index(band_var)

            if len(self.data_shape) == 2:  # table data: (samples, bands)
                band = self.data[:, band_idx]
            else:  # image data: (rows, cols, bands)
                band = self.data[:, :, band_idx]

            # Integer data is promoted to float so that differences of
            # unsigned bands cannot wrap around.
            if np.issubdtype(band.dtype, np.integer):
                band = band.astype(np.float64)
            band_data[band_var] = band

        math_funcs = {
            'sqrt': np.sqrt,
            'log': np.log,
            'log10': np.log10,
            'exp': np.exp,
            'sin': np.sin,
            'cos': np.cos,
            'tan': np.tan,
            'abs': np.abs,
            'pow': np.power,
        }

        try:
            expression = expression.replace('^', '**')

            namespace = {**band_data, **local_vars, **math_funcs}
            namespace['np'] = np

            # NOTE(security): eval with empty builtins is not a real sandbox,
            # especially with `np` exposed.  Only evaluate formula files from
            # trusted sources.
            result = eval(expression, {"__builtins__": {}}, namespace)

            if isinstance(result, np.ndarray):
                result = np.nan_to_num(result, nan=0.0, posinf=1.0, neginf=-1.0)

            return result
        except Exception as e:
            raise ValueError(f"表达式评估失败: {expression}\n错误: {str(e)}") from e

    def calculate_multiple_indices(self, index_list, output_format='separate'):
        """Compute several indices, tolerating individual failures.

        Parameters:
            index_list: Index keys to compute.
            output_format: 'stack' returns the successful results stacked
                along a new last axis (when there is at least one); any other
                value returns a dict.

        Returns:
            dict of index key -> result (None for failed indices), or the
            stacked array.
        """
        results = {}

        for index_name in index_list:
            try:
                result = self.calculate_index(index_name, output_type='image')
                results[index_name] = result
                print(f"✓ {index_name}: 计算完成 (形状: {np.shape(result)})")
            except Exception as e:
                # Best effort: record the failure and carry on.
                print(f"✗ {index_name}: 计算失败 - {str(e)}")
                results[index_name] = None

        if output_format == 'stack' and results:
            valid_results = [r for r in results.values() if r is not None]
            if valid_results:
                return np.stack(valid_results, axis=-1)

        return results

    def calculate_all_indices(self):
        """Compute the configured indices, or every known index when none are configured.

        Returns:
            dict: Index key -> result (None for failed indices).
        """
        index_list = self.config.indices.indices_to_calculate or list(self.formulas.keys())

        print(f"开始计算 {len(index_list)} 个光谱指数...")
        return self.calculate_multiple_indices(index_list, output_format='separate')

    def create_indices_visualization(self, results=None, save_path=None):
        """Draw every valid index result in one grid figure and save it as PNG.

        Parameters:
            results: Dict of index results; computed automatically when None.
            save_path: Output path; defaults to the configured output
                directory and PNG file name.

        Returns:
            matplotlib.figure.Figure or None: The figure, or None when nothing
            could be drawn or saving failed.
        """
        import matplotlib.pyplot as plt
        import matplotlib.gridspec as gridspec

        if results is None:
            results = self.calculate_all_indices()

        valid_results = {k: v for k, v in results.items() if v is not None}

        if not valid_results:
            print("警告: 没有有效的指数计算结果可以可视化")
            return None

        n_indices = len(valid_results)
        print(f"为 {n_indices} 个有效指数创建可视化...")

        # Grid layout.
        if n_indices <= 4:
            nrows, ncols = 2, 2
        elif n_indices <= 9:
            nrows, ncols = 3, 3
        elif n_indices <= 16:
            nrows, ncols = 4, 4
        else:
            nrows = int(np.ceil(np.sqrt(n_indices)))
            ncols = int(np.ceil(n_indices / nrows))

        fig = plt.figure(figsize=(ncols * 4, nrows * 4))
        gs = gridspec.GridSpec(nrows, ncols, figure=fig,
                               hspace=0.3, wspace=0.3,
                               top=0.95, bottom=0.05, left=0.05, right=0.95)

        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed.
        cmap = plt.get_cmap('RdYlBu_r')

        plotted_count = 0
        for index_name, index_data in valid_results.items():
            if plotted_count >= nrows * ncols:
                print(f"警告: 只显示前 {nrows*ncols} 个指数,共有 {n_indices} 个有效指数")
                break

            row, col = divmod(plotted_count, ncols)
            ax = None
            try:
                ax = fig.add_subplot(gs[row, col])

                image = np.asarray(index_data)
                if image.ndim == 2:
                    aspect = 'equal'
                elif image.ndim == 3 and image.shape[2] == 1:
                    # Single-band 3-D data.
                    image = image[:, :, 0]
                    aspect = 'equal'
                else:
                    aspect = 'auto'
                    if image.ndim > 2:
                        image = image.reshape(image.shape[0], -1)
                    else:
                        # 1-D (table) results are drawn as a single-row strip;
                        # imshow rejects 1-D input.
                        image = image.reshape(1, -1)

                im = ax.imshow(image, cmap=cmap, aspect=aspect)
                plt.colorbar(im, ax=ax, shrink=0.8)

                formula_info = self.formulas.get(index_name, {})
                full_name = formula_info.get('name', index_name)
                ax.set_title(f'{full_name}\n({index_name})', fontsize=10, fontweight='bold')

                # Hide ticks to save space.
                ax.set_xticks([])
                ax.set_yticks([])

                plotted_count += 1
            except Exception as e:
                # Drop the half-built axes so the next index can reuse the cell.
                if ax is not None:
                    fig.delaxes(ax)
                print(f"创建指数 {index_name} 的子图失败: {e}")
                continue

        if plotted_count == 0:
            plt.close(fig)
            print("警告: 未能创建任何有效的可视化子图")
            return None

        fig.suptitle(f'光谱指数可视化总览\n成功显示 {plotted_count}/{n_indices} 个指数',
                     fontsize=14, fontweight='bold', y=0.98)

        plt.tight_layout()

        if save_path is None:
            save_path = os.path.join(self.config.output.output_dir, self.config.output.png_filename)

        try:
            # A bare file name has no directory part; makedirs('') would fail.
            target_dir = os.path.dirname(save_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"光谱指数可视化已保存到: {save_path}")
        except Exception as e:
            print(f"保存可视化图像失败: {e}")
            plt.close(fig)
            return None

        return fig

    def batch_calculate_and_visualize(self, save_png=True):
        """Compute all indices, optionally save a PNG overview, and save the results.

        Data is loaded from the configured path when none is loaded yet.  The
        PNG is only produced when both ``save_png`` and the configuration's
        ``save_png_visualization`` are true.

        Parameters:
            save_png: Whether to produce the PNG overview.

        Returns:
            tuple: (results_dict, figure_object or None)

        Raises:
            ValueError: If no data is loaded and no data path is configured.
        """
        print("=" * 60)
        print("开始批量光谱指数计算和可视化")
        print("=" * 60)

        if self.data is None:
            if hasattr(self.config, 'data') and self.config.data.data_file_path:
                print(f"加载数据文件: {self.config.data.data_file_path}")
                self.load_hyperspectral_data()
            else:
                raise ValueError("未指定数据文件路径,无法进行批量计算")

        results = self.calculate_all_indices()

        successful_results = {k: v for k, v in results.items() if v is not None}
        failed_results = {k: v for k, v in results.items() if v is None}

        print(f"计算完成: {len(successful_results)}/{len(results)} 个指数成功")

        fig = None
        if save_png and self.config.output.save_png_visualization:
            if successful_results:
                try:
                    fig = self.create_indices_visualization(successful_results)
                except Exception as e:
                    print(f"生成可视化失败: {e}")
                    fig = None
                else:
                    # Only report success when a figure was actually saved.
                    if fig is not None:
                        print(f"可视化已保存到: {os.path.join(self.config.output.output_dir, self.config.output.png_filename)}")
            else:
                print("警告: 没有成功的指数计算结果,跳过可视化生成")

        if self.config.output.save_combined_results:
            self.save_results(results, "batch_spectral_indices")

        print("=" * 60)
        print("批量计算和可视化完成")
        print(f"总计处理 {len(results)} 个光谱指数")
        print(f"成功: {len(successful_results)} 个")
        if failed_results:
            print(f"失败: {len(failed_results)} 个")
        print("=" * 60)

        return results, fig

    def save_results(self, results, output_prefix='output'):
        """Save computed results.

        A dict is written as one merged CSV when the loaded data is a 2-D
        table, otherwise as one multi-band ENVI file.  An array is written as
        a single ENVI file.

        Parameters:
            results: Dict of index results, or a stacked array.
            output_prefix: Prefix of the output file names.
        """
        if isinstance(results, dict):
            if self.data_shape is not None:
                is_table = len(self.data_shape) == 2
            else:
                # No data loaded through this instance: infer from the results.
                first = next((r for r in results.values() if r is not None), None)
                is_table = first is not None and np.ndim(first) == 1

            if is_table:
                self.save_merged_csv(results, output_prefix)
            else:
                self.save_multiband_dat(results, output_prefix)
        elif isinstance(results, np.ndarray):
            output_file = f"{output_prefix}_indices_stack"
            self.save_single_result(results, "multiple_indices", output_file)

    @staticmethod
    def _write_envi_header(hdr_file, samples, lines, bands, interleave, band_names=None):
        """Write a minimal ENVI header for little-endian float32 data."""
        entries = [
            "ENVI",
            f"samples = {samples}",
            f"lines = {lines}",
            f"bands = {bands}",
            "header offset = 0",
            "file type = ENVI Standard",
            "data type = 4",
            f"interleave = {interleave}",
            "byte order = 0",
            "wavelength units = Unknown",
            "data ignore value = 0",
        ]
        if band_names is not None:
            entries.append(f"band names = {{{', '.join(band_names)}}}")

        with open(hdr_file, 'w', encoding='utf-8') as f:
            f.write("\n".join(entries) + "\n")

    def save_single_result(self, data, index_name, output_file):
        """Save one array as CSV (when ``output_file`` ends in .csv) or as ENVI .dat/.hdr.

        Parameters:
            data: Array to save.
            index_name: Name of the index (not used in the file contents).
            output_file: CSV path, or the prefix of the .dat/.hdr pair.
        """
        if output_file.endswith('.csv'):
            if len(data.shape) == 2:
                df = pd.DataFrame(data)
            else:
                # 3-D data becomes one row per pixel; anything else is flattened.
                flattened = data.reshape(-1, data.shape[-1]) if len(data.shape) == 3 else data.flatten()
                df = pd.DataFrame(flattened)

            df.to_csv(output_file, index=False)
            print(f"已保存: {output_file} (CSV格式)")
            return

        data_file = output_file + '.dat'
        hdr_file = output_file + '.hdr'

        # '<f4' matches the header's "byte order = 0" on every platform.
        data.astype('<f4').tofile(data_file)

        if len(data.shape) == 1:
            lines, samples, bands = 1, data.shape[0], 1
        elif len(data.shape) == 2:
            lines, samples, bands = data.shape[0], data.shape[1], 1
        else:
            lines, samples, bands = data.shape[0], data.shape[1], data.shape[2]

        # A (lines, samples, bands) array written in C order is band-interleaved
        # by pixel, so the header must say bip, not bsq, for 3-D data.
        interleave = 'bip' if len(data.shape) >= 3 else 'bsq'
        self._write_envi_header(hdr_file, samples, lines, bands, interleave)

        print(f"已保存: {data_file} 和 {hdr_file} (ENVI格式)")
        print(f"数据形状: {data.shape}")

    def save_merged_csv(self, results_dict, output_prefix):
        """Save several 1-D index results as the columns of one CSV file.

        Parameters:
            results_dict: Index key -> result; None entries are skipped.
            output_prefix: Prefix of the output file name.
        """
        output_file = f"{output_prefix}_merged_indices.csv"

        valid_results = {name: data for name, data in results_dict.items() if data is not None}

        if not valid_results:
            print("没有有效的指数结果可保存")
            return

        # Only the first result is checked for being 1-D.
        first_result = next(iter(valid_results.values()))
        if len(first_result.shape) != 1:
            print("合并CSV仅支持1D数据")
            return

        df = pd.DataFrame(dict(valid_results))
        df.to_csv(output_file, index=False)
        print(f"已合并保存 {len(valid_results)} 个指数到: {output_file}")
        print(f"总样本数: {len(df)}")

    def save_multiband_dat(self, results_dict, output_prefix):
        """Save several index images as one multi-band ENVI .dat/.hdr pair.

        Parameters:
            results_dict: Index key -> result; None entries are skipped.
            output_prefix: Prefix of the output file names.
        """
        valid_results = {name: data for name, data in results_dict.items() if data is not None}

        if not valid_results:
            print("没有有效的指数结果可保存")
            return

        index_names = list(valid_results)
        layers = []
        for index_name, data in valid_results.items():
            data = np.asarray(data)
            # Drop a trailing single-band axis, e.g. (1066, 1148, 1).
            if data.ndim == 3 and data.shape[2] == 1:
                data = np.squeeze(data, axis=2)

            if data.ndim != 2:
                print(f"警告: 指数 {index_name} 的形状为 {data.shape},不是二维数组")
                if data.ndim > 2:
                    data = data.reshape(data.shape[0], data.shape[1])
                else:
                    data = data.reshape(1, -1)
            layers.append(data)

        # (rows, cols, bands)
        stacked_data = np.stack(layers, axis=-1)

        output_file = f"{output_prefix}_indices"
        data_file = output_file + '.dat'
        hdr_file = output_file + '.hdr'

        # '<f4' matches the header's "byte order = 0" on every platform.
        stacked_data.astype('<f4').tofile(data_file)

        lines, samples, bands = stacked_data.shape
        self._write_envi_header(hdr_file, samples, lines, bands, 'bip', band_names=index_names)

        print(f"已保存多波段数据: {data_file} 和 {hdr_file}")
        print(f"包含 {bands} 个指数: {', '.join(index_names)}")
        print(f"数据形状: {stacked_data.shape}")

    def list_available_indices(self, category=None):
        """Print the known indices grouped by category, followed by the total count.

        Parameters:
            category: When given, only this category (e.g. 'Vegetation',
                'Mineral') is listed; otherwise at most ten indices are shown
                per category.
        """
        print("\n=== 可用光谱指数 ===")

        indices_by_category = {}
        for idx, info in self.formulas.items():
            for cat in info['type']:
                indices_by_category.setdefault(cat, []).append((idx, info['name']))

        if category:
            if category in indices_by_category:
                print(f"\n{category} 指数:")
                for idx, name in sorted(indices_by_category[category]):
                    print(f" {idx:10} - {name}")
            else:
                print(f"未找到类别: {category}")
        else:
            for cat in sorted(indices_by_category.keys()):
                members = indices_by_category[cat]
                print(f"\n{cat} ({len(members)}个指数):")
                for idx, name in sorted(members)[:10]:  # first ten only
                    print(f" {idx:10} - {name}")
                if len(members) > 10:
                    print(f" ... 还有 {len(members) - 10} 个")

        # This summary used to sit, unreachable, after the return statement
        # of run_analysis_from_config.
        total = len(self.formulas)
        print(f"\n总计: {total} 个光谱指数")

    def run_analysis_from_config(self) -> Dict[str, Any]:
        """Run the whole analysis described by the configuration (recommended for GUI use).

        Loads the configured data file, computes the configured indices (or
        all of them), and saves the results when the configuration asks for
        combined or individual output.

        Returns:
            Dict[str, Any]: Index key -> result (None for failed indices).

        Raises:
            ValueError: If no data file path is configured.
        """
        print("Starting spectral index analysis from configuration...")

        # 1. Load the data.
        if not self.config.data.data_file_path:
            raise ValueError("Data file path must be specified in configuration")

        print(f"Loading data from: {self.config.data.data_file_path}")
        self.load_hyperspectral_data()

        # 2. Decide which indices to compute.
        if self.config.indices.indices_to_calculate is None:
            indices_to_calculate = list(self.formulas.keys())
            print(f"Calculating all {len(indices_to_calculate)} available indices")
        else:
            indices_to_calculate = self.config.indices.indices_to_calculate
            invalid_indices = [idx for idx in indices_to_calculate if idx not in self.formulas]
            if invalid_indices:
                print(f"Warning: The following indices do not exist: {invalid_indices}")
                indices_to_calculate = [idx for idx in indices_to_calculate if idx in self.formulas]
            print(f"Calculating {len(indices_to_calculate)} specified indices: {indices_to_calculate}")

        # 3. Compute.
        results = self.calculate_multiple_indices(indices_to_calculate, output_format='separate')

        # 4. Save.
        if self.config.output.save_combined_results or self.config.output.save_individual_indices:
            self.save_results(results, 'config_results')

        print("Analysis completed!")
        return results
def main():
    """Command-line entry point for the spectral-index calculator.

    Parses the arguments, loads the input file, selects the indices to
    compute (explicit list, all vegetation indices, one category, or a
    default set), computes and saves them, and optionally writes a PNG
    overview.  ``--list`` prints the available indices and needs no input
    file.
    """
    import argparse

    parser = argparse.ArgumentParser(description='高光谱光谱指数计算工具')
    # Optional so that --list works without an input file.
    parser.add_argument('input_file', nargs='?',
                        help='输入的高光谱文件路径 (CSV, BIL, BSQ, BIP, DAT)')
    parser.add_argument('-i', '--indices', nargs='+', help='要计算的指数列表 (如 NDVI EVI)')
    parser.add_argument('-a', '--all', action='store_true', help='计算所有植被指数')
    parser.add_argument('-A', '--all-indices', action='store_true', help='计算所有可用指数并生成PNG可视化')
    parser.add_argument('-c', '--category', help='计算特定类别的所有指数 (Vegetation, Mineral等)')
    parser.add_argument('-o', '--output', default='output', help='输出文件前缀')
    parser.add_argument('-f', '--format', choices=['separate', 'stack'], default='separate',
                        help='输出格式: separate(分开) 或 stack(堆叠)')
    parser.add_argument('-p', '--png', action='store_true', help='生成包含所有指数的PNG可视化')
    parser.add_argument('-l', '--list', action='store_true', help='列出所有可用指数')

    args = parser.parse_args()

    if not args.list and not args.input_file:
        parser.error('必须提供 input_file (仅 --list 可省略)')

    calculator = HyperspectralIndexCalculator()

    # Listing only.
    if args.list:
        calculator.list_available_indices()
        return

    print(f"加载数据: {args.input_file}")
    calculator.load_hyperspectral_data(args.input_file)

    # Special case: compute everything and write the PNG overview.
    if args.all_indices:
        print("批量计算所有可用光谱指数并生成PNG可视化...")
        calculator.config.output.save_png_visualization = True

        results, fig = calculator.batch_calculate_and_visualize(save_png=True)

        valid_results = {k: v for k, v in results.items() if v is not None}
        print(f"\n成功计算 {len(valid_results)}/{len(results)} 个指数")
        return

    # Decide which indices to compute.
    if args.indices:
        indices_to_calculate = args.indices
    elif args.all:
        # Every vegetation index.
        indices_to_calculate = [idx for idx, info in calculator.formulas.items()
                                if 'Vegetation' in info['type']]
    elif args.category:
        indices_to_calculate = [idx for idx, info in calculator.formulas.items()
                                if args.category in info['type']]
    else:
        # Default: common vegetation indices that are actually defined.
        default_indices = ['NDVI', 'EVI', 'NDWI', 'NDII', 'MSI', 'GNDVI']
        indices_to_calculate = [idx for idx in default_indices if idx in calculator.formulas]

    if not indices_to_calculate:
        print("错误: 未指定要计算的指数")
        calculator.list_available_indices()
        return

    print(f"\n将计算 {len(indices_to_calculate)} 个指数:")
    for idx in indices_to_calculate[:10]:
        if idx in calculator.formulas:
            print(f" - {idx}: {calculator.formulas[idx]['name']}")
    if len(indices_to_calculate) > 10:
        print(f" ... 还有 {len(indices_to_calculate) - 10} 个")

    # Always compute the per-index dict: the visualisation needs it even when
    # the saved output is stacked.
    results = calculator.calculate_multiple_indices(
        indices_to_calculate,
        output_format='separate'
    )

    to_save = results
    if args.format == 'stack':
        valid_arrays = [r for r in results.values() if r is not None]
        if valid_arrays:
            to_save = np.stack(valid_arrays, axis=-1)

    calculator.save_results(to_save, args.output)

    if args.png and len(indices_to_calculate) > 1:
        print("\n生成PNG可视化...")
        try:
            fig = calculator.create_indices_visualization(results)
        except Exception as e:
            print(f"生成PNG可视化失败: {e}")
        else:
            # Only report success when a figure was actually produced.
            if fig is not None:
                print("可视化已保存为PNG格式")

    print("\n计算完成!")
if __name__ == '__main__':
    # sys.exit is used instead of the bare exit() helper, which the site
    # module injects and which is not guaranteed to be available.
    sys.exit(main())