import numpy as np import pandas as pd import os import re import sys from pathlib import Path from typing import Optional, Dict, Any, Union from dataclasses import dataclass, field import warnings try: import spectral SPECTRAL_AVAILABLE = True except ImportError: SPECTRAL_AVAILABLE = False print("警告: spectral库不可用,将使用内置方法读取数据") warnings.filterwarnings('ignore') @dataclass class DataConfig: """数据配置类""" data_file_path: str = "" data_format: str = "auto" # 'csv', 'envi', 'auto' wavelength_column: Optional[int] = None # CSV格式中波长列的索引 label_column: Optional[str] = None # CSV格式中的标签列名 spectral_start: Optional[str] = None # 光谱数据起始列名或索引 spectral_end: Optional[str] = None # 光谱数据结束列名或索引 @dataclass class IndexConfig: """光谱指数配置类""" spectral_index_csv: str = r"E:\code\spectronon\spectral_index.csv" formula_csv: str = r"E:\code\spectronon\famula.csv" indices_to_calculate: Optional[list] = None # 要计算的指数列表,None表示全部 @dataclass class OutputConfig: """输出配置类""" output_dir: str = "results" save_individual_indices: bool = True save_combined_results: bool = True output_format: str = "csv" # 'csv', 'excel', 'both' save_png_visualization: bool = False # 是否保存PNG可视化 png_filename: str = "spectral_indices_overview.png" @dataclass class SpectralIndexConfig: """光谱指数计算完整配置类 - 为GUI对接设计的标准化接口""" data: DataConfig = field(default_factory=DataConfig) indices: IndexConfig = field(default_factory=IndexConfig) output: OutputConfig = field(default_factory=OutputConfig) def __post_init__(self): """参数校验和默认值设置""" self._validate_parameters() def _validate_parameters(self): """参数校验""" # 数据参数校验 if self.data.data_format not in ['csv', 'envi', 'auto']: raise ValueError("Data format must be 'csv', 'envi', or 'auto'") # 输出参数校验 if self.output.output_format not in ['csv', 'excel', 'both']: raise ValueError("Output format must be 'csv', 'excel', or 'both'") # 文件路径校验 if not os.path.exists(self.indices.spectral_index_csv): print(f"Warning: Spectral index CSV file not found: {self.indices.spectral_index_csv}") if not os.path.exists(self.indices.formula_csv): print(f"Warning: Formula CSV file not found: {self.indices.formula_csv}") @classmethod def create_default(cls) -> 'SpectralIndexConfig': """创建默认配置""" return cls() @classmethod def create_quick_analysis(cls, data_file_path: str, indices_to_calculate: Optional[list] = None) -> 'SpectralIndexConfig': """创建快速分析配置""" config = cls() config.data.data_file_path = data_file_path if indices_to_calculate: config.indices.indices_to_calculate = indices_to_calculate config.output.save_individual_indices = False # 快速分析不保存单个指数 return config class HyperspectralIndexCalculator: """ 高光谱指数计算器 - 支持GUI对接的标准化接口 支持多种光谱指数的自动计算 """ def __init__(self, config: Optional[SpectralIndexConfig] = None, spectral_index_csv=None, formula_csv=None): """ 初始化光谱指数计算器 Parameters: config (SpectralIndexConfig, optional): 配置对象,如果为None则使用默认配置 spectral_index_csv (str, optional): 波段定义文件路径(向后兼容) formula_csv (str, optional): 光谱指数公式文件路径(向后兼容) """ # 处理向后兼容性 if config is None and (spectral_index_csv is not None or formula_csv is not None): # 使用传统参数方式 config = SpectralIndexConfig() if spectral_index_csv is not None: config.indices.spectral_index_csv = spectral_index_csv if formula_csv is not None: config.indices.formula_csv = formula_csv self.config = config or SpectralIndexConfig() self._validate_config() # 加载波段定义 self.band_info = self.load_band_info(self.config.indices.spectral_index_csv) # 加载光谱指数公式 self.formulas = self.load_formulas(self.config.indices.formula_csv) # 存储当前加载的数据 self.data = None self.wavelengths = None self.data_shape = None self.band_mapping = None def update_config(self, config: SpectralIndexConfig): """ 更新配置 - 为GUI动态配置预留接口 Parameters: config (SpectralIndexConfig): 新的配置对象 """ self.config = config self._validate_config() def _validate_config(self): """配置校验""" try: self.config._validate_parameters() except ValueError as e: raise ValueError(f"Configuration validation failed: {e}") def load_band_info(self, csv_path): """加载波段定义信息""" df = pd.read_csv(csv_path) band_info = {} for _, row in df.iterrows(): name = row['name'] band_info[name] = { 'min': float(row['min']), 'center': float(row['center']), 'max': float(row['max']) } # 为常用波段名称创建别名映射 if name in ['nir', 'red', 'green', 'blue', 'swir1', 'swir2']: band_info[name.upper()] = band_info[name] return band_info def load_formulas(self, csv_path): """加载光谱指数公式""" df = pd.read_csv(csv_path) formulas = {} for _, row in df.iterrows(): index = row['Index'] formulas[index] = { 'name': row['Name'], 'type': eval(row['Type']), # 将字符串列表转换为列表 'equation': row['Equation'], 'bands': eval(row['Bands']) # 将字符串列表转换为列表 } return formulas def load_hyperspectral_data(self, file_path=None): """ 加载高光谱数据文件 支持格式: - CSV: 第一行为波长,后续为数据 - ENVI格式: .bil, .bsq, .bip, .dat + .hdr文件 (使用spectral库) Args: file_path: 文件路径,如果为None则从配置中获取 """ if file_path is None: file_path = self.config.data.data_file_path file_path = Path(file_path) suffix = file_path.suffix.lower() if suffix == '.csv': return self.load_csv_data(file_path) else: # 使用spectral库 if SPECTRAL_AVAILABLE and suffix in ['.bil', '.bsq', '.bip', '.dat', '.hdr']: return self.load_with_spectral(file_path) else: raise ValueError(f"不支持的文件格式: {suffix}。请使用CSV或ENVI格式文件,并确保已安装spectral库。") def load_csv_data(self, file_path): """加载CSV格式的高光谱数据""" df = pd.read_csv(file_path) # 处理配置中的参数 label_column = self.config.data.label_column spectral_start = self.config.data.spectral_start spectral_end = self.config.data.spectral_end # 确定光谱数据列 if spectral_start is not None and spectral_end is not None: # 使用指定的列范围 try: start_idx = int(spectral_start) if spectral_start.isdigit() else df.columns.get_loc(spectral_start) end_idx = int(spectral_end) if spectral_end.isdigit() else df.columns.get_loc(spectral_end) + 1 spectral_columns = df.columns[start_idx:end_idx] except (ValueError, KeyError): raise ValueError(f"无法找到指定的光谱列范围: {spectral_start} 到 {spectral_end}") else: # 自动检测:排除标签列的所有列 if label_column and label_column in df.columns: spectral_columns = df.columns.drop(label_column) else: # 如果没有指定标签列,假设所有列都是光谱数据 spectral_columns = df.columns # 提取光谱数据 spectral_data = df[spectral_columns] # 检查是否有标签列 if label_column and label_column in df.columns: self.labels = df[label_column].values print(f"检测到标签列 '{label_column}',包含 {len(self.labels)} 个样本") else: self.labels = None # 尝试将列名转换为波长 try: self.wavelengths = spectral_columns.astype(float).values print(f"成功将列名转换为波长") except (ValueError, TypeError): # 如果列名不是数字,创建默认波长 self.wavelengths = np.arange(len(spectral_columns), dtype=float) print(f"列名不是数字波长,使用默认波长索引 0-{len(spectral_columns)-1}") self.data = spectral_data.values self.data_shape = self.data.shape print(f"CSV数据加载成功: {self.data_shape}") print(f"波段数量: {len(self.wavelengths)}") if self.wavelengths is not None: print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm") return self.data def load_with_spectral(self, file_path): """ 使用spectral库加载高光谱数据 支持ENVI格式: .bil, .bsq, .bip, .dat + .hdr文件 """ original_path = Path(file_path) current_path = original_path # 如果是.hdr文件,找到对应的数据文件 if current_path.suffix.lower() == '.hdr': data_file = current_path.with_suffix('') if not data_file.exists(): # 尝试其他常见扩展名 for ext in ['.dat', '.bil', '.bsq', '.bip']: candidate = current_path.with_suffix(ext) if candidate.exists(): data_file = candidate break current_path = data_file print(f"使用spectral库加载: {current_path}") # 使用spectral打开图像 img = spectral.open_image(file_path) # 读取所有波段的数据 # spectral默认返回(行, 列, 波段)的numpy数组 self.data = img.load() # 获取波长信息 if hasattr(img, 'metadata') and 'wavelength' in img.metadata: wavelength_data = img.metadata['wavelength'] try: # 处理波长数据,可能需要转换为float数组 if isinstance(wavelength_data, str): # 如果是字符串,尝试解析 wavelength_data = wavelength_data.strip('{}[]') if ',' in wavelength_data: self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split(',') if w.strip()]) else: self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split() if w.strip()]) elif isinstance(wavelength_data, list): self.wavelengths = np.array([float(w) for w in wavelength_data]) else: self.wavelengths = np.array(wavelength_data, dtype=float) print(f"spectral解析波长: {len(self.wavelengths)} 个波段") print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm") except Exception as e: print(f"波长解析失败,使用默认值: {e}") self.wavelengths = np.arange(self.data.shape[2]) else: # 如果没有波长信息,创建默认值 self.wavelengths = np.arange(self.data.shape[2]) print(f"警告: spectral未找到波长信息,使用默认值: 0-{self.data.shape[2]-1}") self.data_shape = self.data.shape print(f"spectral数据加载成功: {self.data_shape}") print(f"数据类型: {self.data.dtype}") print(f"值范围: [{self.data.min():.3f}, {self.data.max():.3f}]") return self.data def find_band_index(self, band_name, tolerance=5): """ 根据波段名称找到对应的波段索引 参数: band_name: 波段名称 (如 'w550', 'nir', 'red') tolerance: 容差范围 (nm) 返回: 波段索引 (从0开始) """ if band_name not in self.band_info: # 尝试匹配标准波段名称 std_names = { 'nir': 'w860', 'red': 'w680', 'green': 'w550', 'blue': 'w470', 'swir1': 'w1650', 'swir2': 'w2200', 'tm1': 'w485', 'tm2': 'w569', 'tm3': 'w660', 'tm4': 'w833', 'tm5': 'w1676', 'tm7': 'w2223' } if band_name in std_names: band_name = std_names[band_name] else: raise ValueError(f"未定义的波段名称: {band_name}") target_info = self.band_info[band_name] target_center = target_info['center'] target_min = target_info['min'] target_max = target_info['max'] # 检查当前数据的波长范围是否包含目标波段 data_min = self.wavelengths[0] data_max = self.wavelengths[-1] if target_min < data_min or target_max > data_max: raise ValueError( f"波段 {band_name} 超出数据波长范围: " f"目标波段 [{target_min:.1f}-{target_max:.1f}nm], " f"数据范围 [{data_min:.1f}-{data_max:.1f}nm]" ) # 找到最接近的波长 diffs = np.abs(self.wavelengths - target_center) min_idx = np.argmin(diffs) if diffs[min_idx] > tolerance: print( f"警告: 波段 {band_name} (目标: {target_center}nm) 匹配到 {self.wavelengths[min_idx]:.1f}nm (差异: {diffs[min_idx]:.1f}nm)") return min_idx def calculate_index(self, index_name, output_type='image'): """ 计算指定光谱指数 参数: index_name: 光谱指数简称 (如 'NDVI', 'EVI') output_type: 输出类型 ('image'或'values') 返回: 光谱指数图像或数值 """ if index_name not in self.formulas: available = list(self.formulas.keys()) raise ValueError(f"未找到指数: {index_name}。可用指数: {available}") formula_info = self.formulas[index_name] equation = formula_info['equation'] required_bands = formula_info['bands'] print(f"计算指数: {formula_info['name']} ({index_name})") print(f"所需波段: {required_bands}") print(f"公式: {equation}") # 检查公式中是否包含临时变量 if ';' in equation: # 分离主公式和临时变量定义 parts = equation.split(';') main_eq = parts[0].strip() temp_vars = parts[1:] # 创建局部变量字典 local_vars = {} # 先计算临时变量 for temp_eq in temp_vars: if '=' in temp_eq: var_name, expr = temp_eq.split('=', 1) var_name = var_name.strip() expr = expr.strip() # 计算临时变量 temp_value = self.evaluate_expression(expr, required_bands, local_vars) local_vars[var_name] = temp_value else: main_eq = equation.strip() local_vars = {} # 计算主公式 index_result = self.evaluate_expression(main_eq, required_bands, local_vars) # 根据输出类型返回结果 if output_type == 'image': return index_result elif output_type == 'values': # 展平为一维数组 return index_result.flatten() else: return index_result def evaluate_expression(self, expression, required_bands, local_vars): """ 评估表达式 参数: expression: 表达式字符串 required_bands: 所需波段列表 local_vars: 局部变量字典 """ # 提取所有波段变量 pattern = r'\b(w\d+|tm\d+|nir|red|green|blue|swir1|swir2|thermal)\b' band_vars = re.findall(pattern, expression) # 创建波段数据字典 band_data = {} for band_var in set(band_vars): # 检查是否已在局部变量中 if band_var in local_vars: continue # 查找波段索引 band_idx = self.find_band_index(band_var) # 获取波段数据 if len(self.data_shape) == 2: # CSV格式 band_data[band_var] = self.data[:, band_idx] else: # 图像格式 band_data[band_var] = self.data[:, :, band_idx] # 合并局部变量 all_vars = {**band_data, **local_vars} # 添加数学函数 math_funcs = { 'sqrt': np.sqrt, 'log': np.log, 'log10': np.log10, 'exp': np.exp, 'sin': np.sin, 'cos': np.cos, 'tan': np.tan, 'abs': np.abs, 'pow': np.power } # 安全评估表达式 try: # 替换表达式中的^为** expression = expression.replace('^', '**') # 创建安全环境 safe_env = {**all_vars, **math_funcs} # 添加numpy函数 safe_env['np'] = np # 执行表达式 result = eval(expression, {"__builtins__": {}}, safe_env) # 处理可能的异常值 if isinstance(result, np.ndarray): result = np.nan_to_num(result, nan=0.0, posinf=1.0, neginf=-1.0) return result except Exception as e: raise ValueError(f"表达式评估失败: {expression}\n错误: {str(e)}") def calculate_multiple_indices(self, index_list, output_format='separate'): """ 计算多个光谱指数 参数: index_list: 指数名称列表 output_format: 输出格式 ('separate'或'stack') 返回: 光谱指数结果 """ results = {} for index_name in index_list: try: result = self.calculate_index(index_name, output_type='image') results[index_name] = result print(f"✓ {index_name}: 计算完成 (形状: {result.shape})") except Exception as e: print(f"✗ {index_name}: 计算失败 - {str(e)}") results[index_name] = None if output_format == 'stack' and results: # 将所有指数堆叠成一个多波段图像 valid_results = [r for r in results.values() if r is not None] if valid_results: return np.stack(valid_results, axis=-1) return results def calculate_all_indices(self): """ 计算所有可用的光谱指数 返回: dict: 包含所有指数结果的字典 """ if self.config.indices.indices_to_calculate: index_list = self.config.indices.indices_to_calculate else: index_list = list(self.formulas.keys()) print(f"开始计算 {len(index_list)} 个光谱指数...") return self.calculate_multiple_indices(index_list, output_format='separate') def create_indices_visualization(self, results=None, save_path=None): """ 创建包含所有光谱指数的PNG可视化 参数: results: 指数计算结果字典,如果为None则自动计算 save_path: 保存路径,如果为None则使用配置中的路径 返回: matplotlib.figure.Figure or None: 可视化图形对象,如果没有有效结果则返回None """ import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec from matplotlib import cm # 如果没有提供结果,计算所有指数 if results is None: results = self.calculate_all_indices() # 过滤掉计算失败的指数 valid_results = {k: v for k, v in results.items() if v is not None} if not valid_results: print("警告: 没有有效的指数计算结果可以可视化") return None n_indices = len(valid_results) print(f"为 {n_indices} 个有效指数创建可视化...") # 计算网格布局 if n_indices <= 4: nrows, ncols = 2, 2 elif n_indices <= 9: nrows, ncols = 3, 3 elif n_indices <= 16: nrows, ncols = 4, 4 else: # 对于更多指数,使用更大的网格 nrows = int(np.ceil(np.sqrt(n_indices))) ncols = int(np.ceil(n_indices / nrows)) # 创建图形 fig = plt.figure(figsize=(ncols * 4, nrows * 4)) gs = gridspec.GridSpec(nrows, ncols, figure=fig, hspace=0.3, wspace=0.3, top=0.95, bottom=0.05, left=0.05, right=0.95) # 选择颜色映射 cmap = cm.get_cmap('RdYlBu_r') # 为每个指数创建子图 plotted_count = 0 for i, (index_name, index_data) in enumerate(valid_results.items()): if plotted_count >= nrows * ncols: print(f"警告: 只显示前 {nrows*ncols} 个指数,共有 {n_indices} 个有效指数") break row = plotted_count // ncols col = plotted_count % ncols try: ax = fig.add_subplot(gs[row, col]) # 处理不同维度的数据 if len(index_data.shape) == 2: # 2D图像数据 im = ax.imshow(index_data, cmap=cmap, aspect='equal') plt.colorbar(im, ax=ax, shrink=0.8) elif len(index_data.shape) == 3 and index_data.shape[2] == 1: # 单波段3D数据 im = ax.imshow(index_data[:, :, 0], cmap=cmap, aspect='equal') plt.colorbar(im, ax=ax, shrink=0.8) else: # 其他格式的数据,尝试展平显示 if len(index_data.shape) > 2: index_data = index_data.reshape(index_data.shape[0], -1) # 显示为热力图 im = ax.imshow(index_data, cmap=cmap, aspect='auto') plt.colorbar(im, ax=ax, shrink=0.8) # 设置标题 formula_info = self.formulas.get(index_name, {}) full_name = formula_info.get('name', index_name) ax.set_title(f'{full_name}\n({index_name})', fontsize=10, fontweight='bold') # 移除坐标轴标签以节省空间 ax.set_xticks([]) ax.set_yticks([]) plotted_count += 1 except Exception as e: print(f"创建指数 {index_name} 的子图失败: {e}") continue # 如果没有成功创建任何子图,返回None if plotted_count == 0: plt.close(fig) print("警告: 未能创建任何有效的可视化子图") return None # 设置总标题 fig.suptitle(f'光谱指数可视化总览\n成功显示 {plotted_count}/{n_indices} 个指数', fontsize=14, fontweight='bold', y=0.98) # 调整布局 plt.tight_layout() # 保存图像 if save_path is None: save_path = os.path.join(self.config.output.output_dir, self.config.output.png_filename) try: # 确保输出目录存在 os.makedirs(os.path.dirname(save_path), exist_ok=True) # 保存为高分辨率PNG fig.savefig(save_path, dpi=300, bbox_inches='tight') print(f"光谱指数可视化已保存到: {save_path}") except Exception as e: print(f"保存可视化图像失败: {e}") plt.close(fig) return None return fig def batch_calculate_and_visualize(self, save_png=True): """ 批量计算所有光谱指数并生成PNG可视化 参数: save_png: 是否保存PNG可视化 返回: tuple: (results_dict, figure_object) """ print("=" * 60) print("开始批量光谱指数计算和可视化") print("=" * 60) # 确保数据已经加载 if self.data is None: if hasattr(self.config, 'data') and self.config.data.data_file_path: print(f"加载数据文件: {self.config.data.data_file_path}") self.load_hyperspectral_data() else: raise ValueError("未指定数据文件路径,无法进行批量计算") # 计算所有指数 results = self.calculate_all_indices() # 统计计算结果 successful_results = {k: v for k, v in results.items() if v is not None} failed_results = {k: v for k, v in results.items() if v is None} print(f"计算完成: {len(successful_results)}/{len(results)} 个指数成功") # 生成可视化(只使用成功的计算结果) fig = None if save_png and self.config.output.save_png_visualization: if len(successful_results) > 0: try: fig = self.create_indices_visualization(successful_results) print(f"可视化已保存到: {os.path.join(self.config.output.output_dir, self.config.output.png_filename)}") except Exception as e: print(f"生成可视化失败: {e}") fig = None else: print("警告: 没有成功的指数计算结果,跳过可视化生成") # 保存传统格式的结果 if self.config.output.save_combined_results: self.save_results(results, "batch_spectral_indices") print("=" * 60) print("批量计算和可视化完成") print(f"总计处理 {len(results)} 个光谱指数") print(f"成功: {len(successful_results)} 个") if failed_results: print(f"失败: {len(failed_results)} 个") print("=" * 60) return results, fig def save_results(self, results, output_prefix='output'): """ 保存计算结果 参数: results: 计算结果字典或数组 output_prefix: 输出文件前缀 """ # 根据输入数据形状判断文件格式 is_csv_format = len(self.data_shape) == 2 if isinstance(results, dict): if is_csv_format: # CSV输入: 合并所有指数为单个CSV文件 self.save_merged_csv(results, output_prefix) else: # 图像输入: 将多个指数堆叠为多波段dat文件 self.save_multiband_dat(results, output_prefix) elif isinstance(results, np.ndarray): # 保存堆叠的结果 output_file = f"{output_prefix}_indices_stack" self.save_single_result(results, "multiple_indices", output_file) def save_single_result(self, data, index_name, output_file): """保存单个结果""" # 确定文件格式 if output_file.endswith('.csv'): # 保存为CSV if len(data.shape) == 2: df = pd.DataFrame(data) else: # 3D数据展平 flattened = data.reshape(-1, data.shape[-1]) if len(data.shape) == 3 else data.flatten() df = pd.DataFrame(flattened) df.to_csv(output_file, index=False) print(f"已保存: {output_file} (CSV格式)") else: # 保存为二进制文件 + 头文件 data_file = output_file + '.dat' hdr_file = output_file + '.hdr' # 保存二进制数据 data.astype('float32').tofile(data_file) # 创建ENVI头文件 - 符合ENVI标准格式 if len(data.shape) == 1: # 1D数据 (CSV格式) lines = 1 samples = data.shape[0] bands = 1 elif len(data.shape) == 2: # 2D数据 (图像格式) lines = data.shape[0] samples = data.shape[1] bands = 1 else: # 3D数据 (多波段图像) lines = data.shape[0] samples = data.shape[1] bands = data.shape[2] header_content = f"""ENVI samples = {samples} lines = {lines} bands = {bands} header offset = 0 file type = ENVI Standard data type = 4 interleave = bsq byte order = 0 wavelength units = Unknown data ignore value = 0 """ with open(hdr_file, 'w', encoding='utf-8') as f: f.write(header_content) print(f"已保存: {data_file} 和 {hdr_file} (ENVI格式)") print(f"数据形状: {data.shape}") def save_merged_csv(self, results_dict, output_prefix): """ 将多个光谱指数结果合并保存为单个CSV文件 参数: results_dict: 包含多个指数结果的字典 output_prefix: 输出文件前缀 """ # 创建输出文件名 output_file = f"{output_prefix}_merged_indices.csv" # 过滤掉None值的结果 valid_results = {name: data for name, data in results_dict.items() if data is not None} if not valid_results: print("没有有效的指数结果可保存") return # 获取第一个有效结果的形状作为参考 first_result = next(iter(valid_results.values())) if len(first_result.shape) != 1: print("合并CSV仅支持1D数据") return # 创建合并的数据框 merged_data = {} # 添加每个指数的值 for index_name, data in valid_results.items(): merged_data[index_name] = data # 创建DataFrame并保存 df = pd.DataFrame(merged_data) df.to_csv(output_file, index=False) print(f"已合并保存 {len(valid_results)} 个指数到: {output_file}") print(f"总样本数: {len(df)}") def save_multiband_dat(self, results_dict, output_prefix): """ 将多个光谱指数结果保存为多波段的dat文件 参数: results_dict: 包含多个指数结果的字典 output_prefix: 输出文件前缀 """ # 过滤掉None值的结果 valid_results = {name: data for name, data in results_dict.items() if data is not None} if not valid_results: print("没有有效的指数结果可保存") return # 移除单维度(如果存在) index_names = [] index_data_list = [] for index_name, data in valid_results.items(): index_names.append(index_name) # 如果数据有单维度(例如形状为(1066, 1148, 1)),则压缩它 if data.ndim == 3 and data.shape[2] == 1: data = np.squeeze(data, axis=2) index_data_list.append(data) # 检查所有数据是否都是二维的 for i, data in enumerate(index_data_list): if data.ndim != 2: print(f"警告: 指数 {index_names[i]} 的形状为 {data.shape},不是二维数组") # 尝试展平为二维 if data.ndim > 2: data = data.reshape(data.shape[0], data.shape[1]) else: data = data.reshape(1, -1) index_data_list[i] = data # 堆叠为多波段数据 (行, 列, 波段) stacked_data = np.stack(index_data_list, axis=-1) # 保存为dat文件 output_file = f"{output_prefix}_indices" data_file = output_file + '.dat' hdr_file = output_file + '.hdr' # 保存二进制数据 stacked_data.astype('float32').tofile(data_file) # 创建ENVI头文件 - 符合ENVI标准格式 lines = stacked_data.shape[0] samples = stacked_data.shape[1] bands = stacked_data.shape[2] header_content = f"""ENVI samples = {samples} lines = {lines} bands = {bands} header offset = 0 file type = ENVI Standard data type = 4 interleave = bip byte order = 0 wavelength units = Unknown data ignore value = 0 band names = {{{', '.join(index_names)}}} """ with open(hdr_file, 'w', encoding='utf-8') as f: f.write(header_content) print(f"已保存多波段数据: {data_file} 和 {hdr_file}") print(f"包含 {bands} 个指数: {', '.join(index_names)}") print(f"数据形状: {stacked_data.shape}") def list_available_indices(self, category=None): """ 列出所有可用的光谱指数 参数: category: 按类别过滤 (如 'Vegetation', 'Mineral') """ print("\n=== 可用光谱指数 ===") indices_by_category = {} for idx, info in self.formulas.items(): categories = info['type'] for cat in categories: if cat not in indices_by_category: indices_by_category[cat] = [] indices_by_category[cat].append((idx, info['name'])) # 显示所有类别或特定类别 if category: if category in indices_by_category: print(f"\n{category} 指数:") for idx, name in sorted(indices_by_category[category]): print(f" {idx:10} - {name}") else: print(f"未找到类别: {category}") else: for cat in sorted(indices_by_category.keys()): print(f"\n{cat} ({len(indices_by_category[cat])}个指数):") for idx, name in sorted(indices_by_category[cat])[:10]: # 只显示前10个 print(f" {idx:10} - {name}") if len(indices_by_category[cat]) > 10: print(f" ... 还有 {len(indices_by_category[cat]) - 10} 个") def run_analysis_from_config(self) -> Dict[str, Any]: """ 基于配置对象运行完整分析流程 - 推荐用于GUI对接 Returns: Dict[str, Any]: 分析结果字典 """ print("Starting spectral index analysis from configuration...") # 1. 加载数据 if not self.config.data.data_file_path: raise ValueError("Data file path must be specified in configuration") print(f"Loading data from: {self.config.data.data_file_path}") self.load_hyperspectral_data() # 2. 确定要计算的指数 if self.config.indices.indices_to_calculate is None: # 计算所有可用指数 indices_to_calculate = list(self.formulas.keys()) print(f"Calculating all {len(indices_to_calculate)} available indices") else: # 计算指定的指数 indices_to_calculate = self.config.indices.indices_to_calculate invalid_indices = [idx for idx in indices_to_calculate if idx not in self.formulas] if invalid_indices: print(f"Warning: The following indices do not exist: {invalid_indices}") indices_to_calculate = [idx for idx in indices_to_calculate if idx in self.formulas] print(f"Calculating {len(indices_to_calculate)} specified indices: {indices_to_calculate}") # 3. 计算指数 results = self.calculate_multiple_indices(indices_to_calculate, output_format='separate') # 4. 保存结果 if self.config.output.save_combined_results or self.config.output.save_individual_indices: self.save_results(results, 'config_results') print("Analysis completed!") return results total = len(self.formulas) print(f"\n总计: {total} 个光谱指数") def main(): """主函数 - 命令行接口""" import argparse parser = argparse.ArgumentParser(description='高光谱光谱指数计算工具') parser.add_argument('input_file', help='输入的高光谱文件路径 (CSV, BIL, BSQ, BIP, DAT)') parser.add_argument('-i', '--indices', nargs='+', help='要计算的指数列表 (如 NDVI EVI)') parser.add_argument('-a', '--all', action='store_true', help='计算所有植被指数') parser.add_argument('-A', '--all-indices', action='store_true', help='计算所有可用指数并生成PNG可视化') parser.add_argument('-c', '--category', help='计算特定类别的所有指数 (Vegetation, Mineral等)') parser.add_argument('-o', '--output', default='output', help='输出文件前缀') parser.add_argument('-f', '--format', choices=['separate', 'stack'], default='separate', help='输出格式: separate(分开) 或 stack(堆叠)') parser.add_argument('-p', '--png', action='store_true', help='生成包含所有指数的PNG可视化') parser.add_argument('-l', '--list', action='store_true', help='列出所有可用指数') args = parser.parse_args() # 初始化计算器 calculator = HyperspectralIndexCalculator() # 如果只是列出指数 if args.list: calculator.list_available_indices() return # 加载数据 print(f"加载数据: {args.input_file}") calculator.load_hyperspectral_data(args.input_file) # 特殊处理:批量计算所有指数并生成PNG if args.all_indices: print("批量计算所有可用光谱指数并生成PNG可视化...") calculator.config.output.save_png_visualization = True results, fig = calculator.batch_calculate_and_visualize(save_png=True) # 打印结果摘要 valid_results = {k: v for k, v in results.items() if v is not None} print(f"\n成功计算 {len(valid_results)}/{len(results)} 个指数") return # 确定要计算的指数 indices_to_calculate = [] if args.indices: indices_to_calculate = args.indices elif args.all: # 计算所有植被指数 for idx, info in calculator.formulas.items(): if 'Vegetation' in info['type']: indices_to_calculate.append(idx) elif args.category: # 计算特定类别的所有指数 for idx, info in calculator.formulas.items(): if args.category in info['type']: indices_to_calculate.append(idx) else: # 默认计算常用植被指数 default_indices = ['NDVI', 'EVI', 'NDWI', 'NDII', 'MSI', 'GNDVI'] indices_to_calculate = [idx for idx in default_indices if idx in calculator.formulas] if not indices_to_calculate: print("错误: 未指定要计算的指数") calculator.list_available_indices() return print(f"\n将计算 {len(indices_to_calculate)} 个指数:") for idx in indices_to_calculate[:10]: if idx in calculator.formulas: print(f" - {idx}: {calculator.formulas[idx]['name']}") if len(indices_to_calculate) > 10: print(f" ... 还有 {len(indices_to_calculate) - 10} 个") # 计算指数 results = calculator.calculate_multiple_indices( indices_to_calculate, output_format=args.format ) # 保存结果 calculator.save_results(results, args.output) # 如果指定了PNG可视化,生成可视化图像 if args.png and len(indices_to_calculate) > 1: print("\n生成PNG可视化...") try: fig = calculator.create_indices_visualization(results) print(f"可视化已保存为PNG格式") except Exception as e: print(f"生成PNG可视化失败: {e}") print("\n计算完成!") if __name__ == '__main__': exit(main())