#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Visualization and report generation for water-quality parameter retrieval.

Features:
1. Scatter plots: model evaluation (true vs. predicted values)
2. Content distribution maps: spatial visualization as colour-filled maps
3. Spectral curve plots: mean spectra compared across parameter-value groups
4. Statistical charts: box plots, histograms, correlation heat map
5. Model training summary report: training_summary.csv
6. Parameter retrieval result report with prediction statistics
7. Batch processing summary: batch_inference_summary.json
8. Mask / glint thumbnails: previews of the 2_glint and 3_deglint folders
"""

import json
import warnings
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import joblib
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# GDAL is optional: it is only needed for the image preview thumbnails.
try:
    from osgeo import gdal
    GDAL_AVAILABLE = True
except ImportError:
    GDAL_AVAILABLE = False
    print("警告: GDAL未安装,影像预览图生成功能可能无法正常工作")

warnings.filterwarnings('ignore')

# Font fallbacks so that the Chinese labels used in the figures render.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.size'] = 12

# Global seaborn style.
sns.set_style("whitegrid")
sns.set_palette("husl")

# Glob patterns tried, in this order, when an image has to be auto-discovered.
_IMAGE_PATTERNS = ('*.dat', '*.bsq', '*.tif', '*.tiff')


def _safe_filename(name) -> str:
    """Return *name* reduced to alphanumerics and ``-``, ``_``, ``.``."""
    return "".join(c for c in str(name) if c.isalnum() or c in ('-', '_', '.'))


def _json_default(value):
    """Encode the non-JSON types that pipeline info may contain.

    Handles ``Path`` objects, numpy scalars and numpy arrays; anything else
    raises ``TypeError`` as ``json`` itself would.
    """
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


class WaterQualityVisualization:
    """Plots and image previews for water-quality parameter retrieval."""

    # Target wavelengths (passed to find_band_number) for the R, G, B channels.
    _RGB_TARGET_WAVELENGTHS = {'R': 650.0, 'G': 550.0, 'B': 460.0}
    # Fraction of the band count used as band index when the wavelength
    # lookup is unavailable or fails.
    _RGB_FALLBACK_FRACTIONS = {'R': 0.25, 'G': 0.15, 'B': 0.05}

    def __init__(self, output_dir: str = "./visualization_output"):
        """
        Create the visualizer and make sure the output directory exists.

        Args:
            output_dir: Directory that receives the generated figures.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def plot_scatter_true_vs_pred(self,
                                  y_true: np.ndarray,
                                  y_pred: np.ndarray,
                                  target_name: str = "参数",
                                  train_indices: Optional[np.ndarray] = None,
                                  test_indices: Optional[np.ndarray] = None,
                                  metrics: Optional[Dict] = None,
                                  output_path: Optional[str] = None) -> str:
        """
        Draw a scatter plot of true vs. predicted values with a 1:1 line.

        When both ``train_indices`` and ``test_indices`` are given, the two
        subsets are drawn in different colours and R2/RMSE are reported per
        subset; otherwise all points are drawn together and R2, RMSE and MAE
        of the whole data set are reported.

        Args:
            y_true: True values (any array-like).
            y_pred: Predicted values (any array-like, same length).
            target_name: Name of the target parameter (used in labels and in
                the default file name).
            train_indices: Indices of the training samples (optional).
            test_indices: Indices of the test samples (optional).
            metrics: Optional dict that may provide ``train_r2``, ``test_r2``,
                ``train_rmse`` and ``test_rmse``; missing keys are computed.
            output_path: Output file; generated inside ``self.output_dir``
                when None.

        Returns:
            Path of the saved figure.
        """
        # Imported here because scikit-learn is only needed by this method.
        from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

        # Accept lists / pandas Series as well; indexing below is positional.
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            if train_indices is not None and test_indices is not None:
                y_train_true, y_train_pred = y_true[train_indices], y_pred[train_indices]
                y_test_true, y_test_pred = y_true[test_indices], y_pred[test_indices]

                ax.scatter(y_train_true, y_train_pred, alpha=0.6, s=50,
                           label=f'训练集 (n={len(y_train_true)})',
                           color='blue', edgecolors='black', linewidths=0.5)
                ax.scatter(y_test_true, y_test_pred, alpha=0.6, s=50,
                           label=f'测试集 (n={len(y_test_true)})',
                           color='red', edgecolors='black', linewidths=0.5)

                # Computed values are the defaults; caller-supplied metrics win.
                scores = {
                    'train_r2': r2_score(y_train_true, y_train_pred),
                    'test_r2': r2_score(y_test_true, y_test_pred),
                    'train_rmse': np.sqrt(mean_squared_error(y_train_true, y_train_pred)),
                    'test_rmse': np.sqrt(mean_squared_error(y_test_true, y_test_pred)),
                }
                if metrics is not None:
                    scores.update({key: metrics[key] for key in scores if key in metrics})
                train_r2, test_r2 = scores['train_r2'], scores['test_r2']
                train_rmse, test_rmse = scores['train_rmse'], scores['test_rmse']

                metrics_text = f'训练集: R² = {train_r2:.4f}, RMSE = {train_rmse:.4f}\n'
                metrics_text += f'测试集: R² = {test_r2:.4f}, RMSE = {test_rmse:.4f}'
            else:
                ax.scatter(y_true, y_pred, alpha=0.6, s=50,
                           color='blue', edgecolors='black', linewidths=0.5)
                r2_all = r2_score(y_true, y_pred)
                rmse_all = np.sqrt(mean_squared_error(y_true, y_pred))
                mae_all = mean_absolute_error(y_true, y_pred)
                metrics_text = f'R² = {r2_all:.4f}, RMSE = {rmse_all:.4f}, MAE = {mae_all:.4f}'

            # 1:1 reference line across the joint value range.
            min_val = min(y_true.min(), y_pred.min())
            max_val = max(y_true.max(), y_pred.max())
            ax.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='1:1线')

            ax.set_xlabel(f'真实值 ({target_name})', fontsize=14, fontweight='bold')
            ax.set_ylabel(f'预测值 ({target_name})', fontsize=14, fontweight='bold')
            ax.set_title(f'{target_name} - 真实值 vs 预测值', fontsize=16, fontweight='bold')
            # The metrics box occupies the upper-left corner, so the legend is
            # placed in the opposite corner to keep the two from overlapping.
            ax.legend(loc='lower right', fontsize=11)
            ax.grid(True, alpha=0.3)

            ax.text(0.05, 0.95, metrics_text, transform=ax.transAxes,
                    verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
                    fontsize=10)

            fig.tight_layout()

            if output_path is None:
                output_path = self.output_dir / f"{_safe_filename(target_name)}_scatter_true_vs_pred.png"
            else:
                output_path = Path(output_path)
                output_path.parent.mkdir(parents=True, exist_ok=True)

            fig.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

        print(f"散点图已保存: {output_path}")
        return str(output_path)

    def plot_spectrum_by_parameter(self,
                                   csv_path: str,
                                   parameter_column: str,
                                   wavelength_start_column: Union[str, int] = "UTM_Y",
                                   output_dir: Optional[str] = None,
                                   wavelength_range: Optional[Tuple[float, float]] = None,
                                   n_groups: int = 5) -> str:
        """
        Plot the mean spectrum (with a +/- one standard deviation band) of
        samples grouped into equal-width bins of a parameter value.

        Args:
            csv_path: CSV file containing the parameter and spectral columns.
            parameter_column: Column holding the parameter values.
            wavelength_start_column: Name or positional index of the first
                spectral column; every column from there to the end is read
                as spectrum. A string that is not a column name is tried as
                an integer index.
            output_dir: Output directory (``self.output_dir`` when None).
            wavelength_range: Optional (min, max) wavelength filter, e.g.
                (374, 1011).
            n_groups: Number of equal-width parameter bins.

        Returns:
            Path of the saved figure.

        Raises:
            KeyError: ``wavelength_start_column`` is neither a column name
                nor an integer.
            ValueError: No sample has a valid parameter value together with
                a fully finite spectrum.
        """
        df = pd.read_csv(csv_path)

        # Resolve the first spectral column to a positional index.
        if isinstance(wavelength_start_column, str):
            try:
                wavelength_start_idx = df.columns.get_loc(wavelength_start_column)
            except KeyError:
                try:
                    wavelength_start_idx = int(wavelength_start_column)
                except ValueError:
                    raise KeyError(
                        f"未找到波长起始列: {wavelength_start_column!r}"
                    ) from None
        else:
            wavelength_start_idx = wavelength_start_column

        param_values = df[parameter_column].values
        spectrum_data = df.iloc[:, wavelength_start_idx:].values

        # Column names are used as wavelengths when they are all numeric;
        # otherwise the x axis falls back to the column position.
        wavelength_cols = df.columns[wavelength_start_idx:]
        try:
            wavelengths = wavelength_cols.astype(float).values
        except (ValueError, TypeError):
            wavelengths = np.arange(len(wavelength_cols))

        if wavelength_range:
            in_range = (wavelengths >= wavelength_range[0]) & (wavelengths <= wavelength_range[1])
            wavelengths = wavelengths[in_range]
            spectrum_data = spectrum_data[:, in_range]

        # Keep only rows with a parameter value and a fully finite spectrum.
        valid_mask = ~pd.isna(param_values) & np.all(np.isfinite(spectrum_data), axis=1)
        param_values = param_values[valid_mask]
        spectrum_data = spectrum_data[valid_mask]
        if param_values.size == 0:
            raise ValueError("没有可用于绘制光谱曲线的有效样本")

        # Equal-width bins over the observed parameter range.
        param_min, param_max = param_values.min(), param_values.max()
        group_edges = np.linspace(param_min, param_max, n_groups + 1)
        group_labels = [f"{group_edges[i]:.2f}-{group_edges[i+1]:.2f}" for i in range(n_groups)]
        # Only the interior edges are used for binning so that the result is
        # always in 0..n_groups-1; binning against the upper edge as well
        # would push the maximum value into a bin that is never plotted.
        group_indices = np.digitize(param_values, group_edges[1:-1])

        fig, ax = plt.subplots(figsize=(14, 8))
        try:
            colors = plt.cm.viridis(np.linspace(0, 1, n_groups))

            for i in range(n_groups):
                group_mask = group_indices == i
                if group_mask.sum() == 0:
                    continue

                group_spectra = spectrum_data[group_mask]
                group_mean_spectrum = np.nanmean(group_spectra, axis=0)
                group_std_spectrum = np.nanstd(group_spectra, axis=0)

                ax.plot(wavelengths, group_mean_spectrum,
                        color=colors[i], linewidth=2.5,
                        label=f'组 {i+1} ({group_labels[i]}, n={group_mask.sum()})')
                ax.fill_between(wavelengths,
                                group_mean_spectrum - group_std_spectrum,
                                group_mean_spectrum + group_std_spectrum,
                                color=colors[i], alpha=0.2)

            ax.set_xlabel('波长 (nm)', fontsize=14, fontweight='bold')
            ax.set_ylabel('光谱反射率', fontsize=14, fontweight='bold')
            ax.set_title(f'{parameter_column} - 不同参数值的光谱曲线对比', fontsize=16, fontweight='bold')
            ax.legend(loc='best', fontsize=10)
            ax.grid(True, alpha=0.3)

            if wavelength_range:
                ax.set_xlim(wavelength_range)

            fig.tight_layout()

            if output_dir is None:
                output_dir = self.output_dir
            else:
                output_dir = Path(output_dir)
                output_dir.mkdir(parents=True, exist_ok=True)

            output_path = output_dir / f"{_safe_filename(parameter_column)}_spectrum_comparison.png"
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

        print(f"光谱曲线图已保存: {output_path}")
        return str(output_path)

    def plot_statistical_charts(self,
                                csv_path: str,
                                parameter_columns: List[str],
                                output_dir: Optional[str] = None) -> Dict[str, str]:
        """
        Draw a box plot, per-column histograms and a correlation heat map for
        the given water-quality parameter columns.

        Only columns that exist in the CSV and have a numeric dtype are used.
        The caller is expected to pass parameter columns only (spectral
        wavelength columns are numeric too, but are not filtered out here).

        Args:
            csv_path: CSV file path.
            parameter_columns: Names of the water-quality parameter columns.
            output_dir: Output directory (``self.output_dir`` when None).

        Returns:
            Dict of saved files keyed by ``'boxplot'``, ``'histogram_<col>'``
            and ``'heatmap'``.
        """
        df = pd.read_csv(csv_path)

        if output_dir is None:
            output_dir = self.output_dir
        else:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

        output_paths = {}

        numeric_cols = [col for col in parameter_columns
                        if col in df.columns and pd.api.types.is_numeric_dtype(df[col])]

        # 1. Box plot of all parameters side by side.
        if numeric_cols:
            fig, ax = plt.subplots(figsize=(12, 6))
            try:
                ax.boxplot([df[col].dropna().to_numpy() for col in numeric_cols])
                # Tick labels are set separately instead of through the
                # boxplot() keyword, whose name differs between matplotlib
                # versions.
                ax.set_xticklabels(numeric_cols, rotation=45, ha='right')
                ax.set_ylabel('数值', fontsize=12, fontweight='bold')
                ax.set_title('水质参数箱线图', fontsize=14, fontweight='bold')
                ax.grid(True, alpha=0.3, axis='y')
                fig.tight_layout()

                boxplot_path = output_dir / "parameter_boxplot.png"
                fig.savefig(boxplot_path, dpi=300, bbox_inches='tight')
            finally:
                plt.close(fig)
            output_paths['boxplot'] = str(boxplot_path)

        # 2. One histogram per parameter column.
        for col in numeric_cols:
            data = df[col].dropna()
            # Checked before a figure is created so that skipped columns do
            # not leave an open figure behind.
            if len(data) <= 1:
                continue

            fig, ax = plt.subplots(figsize=(10, 6))
            try:
                ax.hist(data, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
                ax.set_xlabel(f'{col} 数值', fontsize=12, fontweight='bold')
                ax.set_ylabel('频数', fontsize=12, fontweight='bold')
                ax.set_title(f'{col} 分布直方图', fontsize=14, fontweight='bold')
                ax.grid(True, alpha=0.3, axis='y')

                # Mark the mean of the distribution.
                mean_val = data.mean()
                ax.axvline(mean_val, color='red', linestyle='--', linewidth=2,
                           label=f'均值: {mean_val:.4f}')
                ax.legend()
                fig.tight_layout()

                hist_path = output_dir / f"{_safe_filename(col)}_histogram.png"
                fig.savefig(hist_path, dpi=300, bbox_inches='tight')
            finally:
                plt.close(fig)
            output_paths[f'histogram_{col}'] = str(hist_path)

        # 3. Correlation heat map (needs at least two columns).
        if len(numeric_cols) >= 2:
            corr_matrix = df[numeric_cols].corr()

            fig, ax = plt.subplots(figsize=(10, 8))
            try:
                sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm',
                            center=0, square=True, linewidths=1,
                            cbar_kws={"shrink": 0.8}, ax=ax,
                            vmin=-1, vmax=1)
                ax.set_title('水质参数相关性热力图', fontsize=14, fontweight='bold')
                fig.tight_layout()

                heatmap_path = output_dir / "correlation_heatmap.png"
                fig.savefig(heatmap_path, dpi=300, bbox_inches='tight')
            finally:
                plt.close(fig)
            output_paths['heatmap'] = str(heatmap_path)

        if not output_paths:
            print("警告: 没有生成任何统计图表(可能无合适的水质参数列)")
        else:
            print(f"统计图表已保存到: {output_dir},共 {len(output_paths)} 个文件")

        return output_paths

    def plot_distribution_map_enhanced(self,
                                       prediction_csv_path: str,
                                       boundary_shp_path: str,
                                       parameter_column: str = 'prediction',
                                       output_path: Optional[str] = None,
                                       resolution: float = 30,
                                       input_crs: str = 'EPSG:32651',
                                       output_crs: str = 'EPSG:4326',
                                       colormap: str = 'viridis') -> str:
        """
        Generate a colour-filled content distribution map via ``ContentMapper``.

        NOTE(review): ``parameter_column`` and ``colormap`` are accepted but
        are not forwarded to ``ContentMapper.process_data`` - confirm whether
        the mapper can take them.

        Args:
            prediction_csv_path: CSV file with the prediction results.
            boundary_shp_path: Boundary shapefile.
            parameter_column: Column with the parameter values (currently
                unused).
            output_path: Output image path; derived from the CSV name inside
                ``self.output_dir`` when None.
            resolution: Interpolation grid resolution.
            input_crs: Coordinate reference system of the input.
            output_crs: Coordinate reference system of the output.
            colormap: Colour map name (currently unused).

        Returns:
            Path of the saved figure.
        """
        from map import ContentMapper

        mapper = ContentMapper(input_crs=input_crs, output_crs=output_crs)

        if output_path is None:
            csv_name = Path(prediction_csv_path).stem
            output_path = str(self.output_dir / f"{csv_name}_distribution_enhanced.png")

        mapper.process_data(
            csv_file=prediction_csv_path,
            shp_file=boundary_shp_path,
            output_file=output_path,
            resolution=resolution,
            show_sample_points=False
        )

        print(f"增强分布图已保存: {output_path}")
        return output_path

    def generate_glint_deglint_previews(self,
                                        work_dir: str,
                                        output_subdir: str = "glint_deglint_previews",
                                        generate_glint: bool = True,
                                        generate_deglint: bool = True) -> Dict[str, str]:
        """
        Create PNG previews for the images in ``<work_dir>/2_glint`` and
        ``<work_dir>/3_deglint``.

        Files from 2_glint are rendered as a black/white mask, files from
        3_deglint as a stretched RGB composite. Previews are written to
        ``self.output_dir / output_subdir``.

        Args:
            work_dir: Working directory containing the two folders.
            output_subdir: Name of the output sub-directory.
            generate_glint: Whether to process the 2_glint folder.
            generate_deglint: Whether to process the 3_deglint folder.

        Returns:
            Dict mapping the original file name to the PNG path. Files with
            the same name in both folders share one key (the later one wins).
            Empty when GDAL is missing or ``work_dir`` does not exist.
        """
        if not GDAL_AVAILABLE:
            print("警告: GDAL未安装,无法生成影像预览图")
            return {}

        work_dir_path = Path(work_dir)
        if not work_dir_path.exists():
            print(f"错误: 工作目录不存在: {work_dir}")
            return {}

        output_dir = self.output_dir / output_subdir
        output_dir.mkdir(parents=True, exist_ok=True)

        preview_paths = {}
        processed_count = 0

        print(f"\n{'='*60}")
        print("生成耀斑分析影像预览图")
        print(f"{'='*60}")
        print(f"输出目录: {output_dir}")

        folders = (
            (generate_glint, "2_glint", "glint"),
            (generate_deglint, "3_deglint", "deglint"),
        )
        for enabled, folder_name, folder_type in folders:
            if not enabled:
                continue
            folder = work_dir_path / folder_name
            if not folder.exists():
                print(f"警告: {folder_name}文件夹不存在: {folder}")
                continue
            print(f"正在处理{folder_name}文件夹: {folder}")
            processed = self._process_image_folder(folder, output_dir, folder_type, preview_paths)
            processed_count += len(processed)

        print(f"\n影像预览图生成完成,共处理 {processed_count} 个文件")
        print(f"预览图保存至: {output_dir}")

        return preview_paths

    def _process_image_folder(self,
                              input_dir: Path,
                              output_dir: Path,
                              folder_type: str,
                              preview_paths: Dict[str, str]) -> Dict[str, str]:
        """
        Generate a preview for every supported image file in ``input_dir``.

        Args:
            input_dir: Folder to scan (not recursive).
            output_dir: Folder receiving the PNG files.
            folder_type: ``'glint'`` or ``'deglint'``.
            preview_paths: Dict updated in place with ``file name -> PNG path``.

        Returns:
            Dict with the entries added by this call.
        """
        if not input_dir.exists():
            return {}

        supported_extensions = {'.dat', '.bsq', '.tif', '.tiff', '.bil', '.img'}
        processed = {}

        # Sorted so that the processing order does not depend on the file system.
        for file_path in sorted(input_dir.iterdir()):
            if not file_path.is_file() or file_path.suffix.lower() not in supported_extensions:
                continue
            try:
                png_path = self._generate_image_preview_for_visualization(
                    str(file_path), output_dir, folder_type
                )
            except Exception as e:
                # Best effort: one broken file must not stop the others.
                print(f" ✗ 处理文件 {file_path.name} 时出错: {e}")
                continue
            if png_path:
                preview_paths[file_path.name] = png_path
                processed[file_path.name] = png_path
                print(f" ✓ 已生成: {file_path.name} -> {Path(png_path).name}")

        return processed

    @staticmethod
    def _linear_stretch(data: np.ndarray, low_percent: float = 2, high_percent: float = 98) -> np.ndarray:
        """Percentile-stretch *data* to [0, 1], ignoring NaN when computing the limits.

        Returns zeros when there is no valid pixel or the two percentiles
        coincide. NaN pixels stay NaN.
        """
        valid_data = data[~np.isnan(data)]
        if len(valid_data) == 0:
            return np.zeros_like(data)

        low_val = np.percentile(valid_data, low_percent)
        high_val = np.percentile(valid_data, high_percent)
        if high_val - low_val < 1e-10:
            return np.zeros_like(data)

        return np.clip((data - low_val) / (high_val - low_val), 0, 1)

    def _generate_image_preview_for_visualization(self,
                                                  img_path: str,
                                                  output_dir: Path,
                                                  folder_type: str) -> Optional[str]:
        """
        Render one image to ``<output_dir>/<folder_type>_<stem>_preview.png``.

        - ``folder_type == 'glint'``: band 1 is read as a mask; pixels above
          0.5 are drawn white on a black background.
        - Other single-band images: band 1 scaled by its maximum, shown in grey.
        - Multi-band images: three bands chosen by ``_select_rgb_bands``,
          values <= 0 treated as invalid, 2%-98% linear stretch per band.

        An existing output file is returned as is without regenerating it.

        Args:
            img_path: Input image file.
            output_dir: Output directory.
            folder_type: ``'glint'`` or ``'deglint'``.

        Returns:
            Path of the PNG, or None on failure.
        """
        dataset = None
        fig = None
        try:
            output_path = output_dir / f"{folder_type}_{Path(img_path).stem}_preview.png"
            if output_path.exists():
                return str(output_path)

            dataset = gdal.Open(img_path)
            if dataset is None:
                print(f" 警告: 无法打开影像文件: {img_path}")
                return None

            width = dataset.RasterXSize
            height = dataset.RasterYSize
            band_count = dataset.RasterCount

            if folder_type == 'glint':
                mask_data = dataset.GetRasterBand(1).ReadAsArray().astype(np.float32)
                # Float RGB images must lie in [0, 1] for imshow, so white is 1.0.
                rgb_image = np.zeros((height, width, 3), dtype=np.float32)
                rgb_image[mask_data > 0.5] = 1.0
            elif band_count == 1:
                gray = dataset.GetRasterBand(1).ReadAsArray().astype(np.float32)
                if gray.max() > 0:
                    gray = gray / (gray.max() + 1e-10)
                rgb_image = np.stack([gray, gray, gray], axis=2)
            else:
                # With fewer than three bands the first band is used for all
                # channels, which yields a grey image.
                if band_count >= 3:
                    bands = self._select_rgb_bands(img_path, band_count)
                else:
                    bands = [0, 0, 0]

                channels = []
                for band_idx in bands:
                    # GDAL band numbers are 1-based.
                    channel = dataset.GetRasterBand(band_idx + 1).ReadAsArray().astype(np.float32)
                    channel[channel <= 0] = np.nan
                    channels.append(self._linear_stretch(channel))

                rgb_image = np.nan_to_num(np.stack(channels, axis=2), nan=0.0)

            # Plain image: no grid, no axes.
            fig, ax = plt.subplots(figsize=(12, 10))
            ax.grid(False)
            ax.imshow(rgb_image)
            ax.axis('off')

            fig.tight_layout()
            fig.savefig(str(output_path), dpi=150, bbox_inches='tight', pad_inches=0.05)
            return str(output_path)

        except Exception as e:
            print(f" 生成预览图时出错 {img_path}: {e}")
            return None
        finally:
            # Close only the figure created here and drop the GDAL dataset
            # reference on every exit path.
            if fig is not None:
                plt.close(fig)
            dataset = None

    @classmethod
    def _fallback_band_index(cls, color: str, band_count: int) -> int:
        """Index-based band guess for *color*, clamped to ``[0, band_count - 1]``."""
        guess = int(band_count * cls._RGB_FALLBACK_FRACTIONS[color])
        return max(0, min(band_count - 1, guess))

    def _select_rgb_bands(self, img_path: str, band_count: int) -> List[int]:
        """
        Choose the bands used for the R, G and B channels.

        ``src.utils.util.find_band_number`` is tried for each target
        wavelength; if it cannot be imported or fails for a channel, that
        channel falls back to a fixed fraction of the band count.

        Args:
            img_path: Image file path (passed to ``find_band_number``).
            band_count: Total number of bands.

        Returns:
            ``[R, G, B]`` zero-based band indices, each within
            ``[0, band_count - 1]``.
        """
        try:
            from src.utils.util import find_band_number
        except Exception:
            find_band_number = None

        bands = []
        for color, target_wl in self._RGB_TARGET_WAVELENGTHS.items():
            band_idx = None
            if find_band_number is not None:
                try:
                    band_idx = int(find_band_number(target_wl, img_path))
                    band_idx = max(0, min(band_idx, band_count - 1))
                except Exception:
                    band_idx = None
            if band_idx is None:
                band_idx = self._fallback_band_index(color, band_count)
            bands.append(band_idx)
        return bands

    @staticmethod
    def _first_image(directory: Path, recursive: bool = False) -> Optional[str]:
        """Return the first image in *directory*, or None.

        Patterns are tried in the order of ``_IMAGE_PATTERNS``; within one
        pattern the alphabetically first match is taken so that the choice is
        reproducible.
        """
        if not directory.exists():
            return None
        for pattern in _IMAGE_PATTERNS:
            matches = sorted(directory.glob(f"**/{pattern}" if recursive else pattern))
            if matches:
                return str(matches[0])
        return None

    def _find_hyperspectral_image(self, work_path: Path) -> Optional[str]:
        """Locate an image below *work_path*, preferring the 3_deglint folder."""
        return (self._first_image(work_path / "3_deglint")
                or self._first_image(work_path, recursive=True))

    def generate_sampling_point_map(self,
                                    hyperspectral_path: Optional[str] = None,
                                    csv_path: Optional[str] = None,
                                    output_subdir: str = "sampling_maps",
                                    work_dir: Optional[str] = None) -> str:
        """
        Draw the sampling points on top of the hyperspectral image.

        Args:
            hyperspectral_path: Image path; auto-discovered in the working
                directory when None (3_deglint first, then anywhere below).
            csv_path: Sampling point CSV; when None the first CSV in
                ``<work_dir>/4_processed_data`` is used.
            output_subdir: Name of the output sub-directory.
            work_dir: Working directory used for auto-discovery; defaults to
                the parent of ``self.output_dir``.

        Returns:
            Path of the generated map, or ``""`` on failure.
        """
        try:
            from src.postprocessing.point_map import SamplingPointMap

            work_path = Path(work_dir) if work_dir is not None else self.output_dir.parent

            if hyperspectral_path is None:
                hyperspectral_path = self._find_hyperspectral_image(work_path)
                if hyperspectral_path is None:
                    print("警告: 未找到高光谱影像文件")
                    return ""

            if csv_path is None:
                processed_dir = work_path / "4_processed_data"
                if not processed_dir.exists():
                    print("警告: 4_processed_data目录不存在")
                    return ""
                csv_files = sorted(processed_dir.glob("*.csv"))
                if not csv_files:
                    print(f"警告: 在 {processed_dir} 中未找到CSV文件")
                    return ""
                csv_path = str(csv_files[0])

            print(f"生成采样点地图 - 高光谱: {Path(hyperspectral_path).name}, CSV: {Path(csv_path).name}")

            map_generator = SamplingPointMap(output_dir=str(self.output_dir / output_subdir))
            map_path = map_generator.create_sampling_point_map(
                hyperspectral_path=hyperspectral_path,
                csv_path=csv_path,
                point_color='red',
                point_size=100,
                point_alpha=0.9,
                show_north_arrow=True,
                show_scale_bar=True,
                show_legend=True
            )

            print(f"采样点地图已生成: {map_path}")
            return map_path

        except Exception as e:
            print(f"生成采样点地图时出错: {e}")
            return ""

    def generate_all_visualizations(self, work_dir: Optional[str] = None) -> Dict[str, str]:
        """
        Run the preview, sampling-map and flight-path generators.

        Each step is best effort: a failure is printed and the remaining
        steps still run.

        Args:
            work_dir: Working directory (parent of ``self.output_dir`` when
                None).

        Returns:
            Dict with the keys ``'glint_deglint_previews'`` (a dict of
            previews), ``'sampling_map'`` and ``'flight_path'`` for the steps
            that produced output.
        """
        if work_dir is None:
            work_dir = str(self.output_dir.parent)

        results = {}

        try:
            results['glint_deglint_previews'] = self.generate_glint_deglint_previews(work_dir=work_dir)
        except Exception as e:
            print(f"生成掩膜缩略图时出错: {e}")

        try:
            # work_dir is forwarded so that a custom working directory is
            # honoured by the auto-discovery of image and CSV.
            map_path = self.generate_sampling_point_map(work_dir=work_dir)
            if map_path:
                results['sampling_map'] = map_path
        except Exception as e:
            print(f"生成采样点地图时出错: {e}")

        try:
            flight_path = self.generate_flight_path_map(work_dir=work_dir)
            if flight_path:
                results['flight_path'] = flight_path
        except Exception as e:
            print(f"生成航线图时出错: {e}")

        return results

    def generate_flight_path_map(self,
                                 work_dir: Optional[str] = None,
                                 gps_folder: Optional[str] = None,
                                 hyperspectral_path: Optional[str] = None,
                                 output_subdir: str = "flight_paths") -> str:
        """
        Draw the flight tracks on top of the hyperspectral image.

        Args:
            work_dir: Working directory (parent of ``self.output_dir`` when
                None).
            gps_folder: Folder with the ``.gps`` files; auto-discovered when
                None (well-known folder names first, then any folder below
                the working directory containing a ``.gps`` file).
            hyperspectral_path: Image path; auto-discovered when None
                (3_deglint first, then anywhere below the working directory).
            output_subdir: Name of the output sub-directory.

        Returns:
            Path of the generated map, or ``""`` on failure.
        """
        try:
            from src.postprocessing.flight_path import FlightPathVisualizer

            if work_dir is None:
                work_dir = str(self.output_dir.parent)
            work_path = Path(work_dir)

            if gps_folder is None:
                for gps_dir_name in ('gps', 'GPS', 'flight', '轨迹', '航线'):
                    gps_dir = work_path / gps_dir_name
                    if gps_dir.exists() and any(gps_dir.glob("**/*.gps")):
                        gps_folder = str(gps_dir)
                        print(f"找到GPS文件夹: {gps_folder}")
                        break

                if gps_folder is None:
                    gps_files = sorted(work_path.glob("**/*.gps"))
                    if gps_files:
                        gps_folder = str(gps_files[0].parent)
                        print(f"使用包含GPS文件的文件夹: {gps_folder}")

            if gps_folder is None or not Path(gps_folder).exists():
                print("警告: 未找到GPS数据文件夹")
                return ""

            if hyperspectral_path is None:
                hyperspectral_path = self._first_image(work_path / "3_deglint")
                if hyperspectral_path is not None:
                    print(f"使用3_deglint中的高光谱影像: {hyperspectral_path}")
                else:
                    hyperspectral_path = self._first_image(work_path, recursive=True)
                    if hyperspectral_path is not None:
                        print(f"使用找到的高光谱影像: {hyperspectral_path}")

            if hyperspectral_path is None or not Path(hyperspectral_path).exists():
                print("警告: 未找到高光谱影像文件")
                return ""

            print(f"生成航线图 - GPS: {Path(gps_folder).name}, 影像: {Path(hyperspectral_path).name}")

            flight_visualizer = FlightPathVisualizer(
                output_dir=str(self.output_dir / output_subdir)
            )
            map_path = flight_visualizer.create_flight_path_map(
                gps_folder=gps_folder,
                hyperspectral_path=hyperspectral_path,
                line_width=2,
                show_north_arrow=True,
                show_scale_bar=True,
                dpi=300
            )

            print(f"航线图已生成: {map_path}")
            return map_path

        except ImportError as e:
            print(f"无法导入flight_path模块: {e}")
            return ""
        except Exception as e:
            print(f"生成航线图时出错: {e}")
            return ""

    def batch_generate_flight_paths(self,
                                    work_dir: Optional[str] = None,
                                    gps_parent_folder: Optional[str] = None) -> Dict[str, str]:
        """
        Generate flight-path maps for several flight missions at once.

        Args:
            work_dir: Working directory (parent of ``self.output_dir`` when
                None).
            gps_parent_folder: Folder containing one GPS sub-folder per
                mission; the first existing well-known folder name below the
                working directory is used when None.

        Returns:
            Dict of generated map paths as returned by
            ``FlightPathVisualizer.batch_create_maps``; empty on failure.
        """
        try:
            from src.postprocessing.flight_path import FlightPathVisualizer

            if work_dir is None:
                work_dir = str(self.output_dir.parent)
            work_path = Path(work_dir)

            if gps_parent_folder is None:
                for gps_dir_name in ('gps', 'GPS', 'flight', 'flights', '轨迹', '航线'):
                    gps_dir = work_path / gps_dir_name
                    if gps_dir.exists():
                        gps_parent_folder = str(gps_dir)
                        break

            if gps_parent_folder is None or not Path(gps_parent_folder).exists():
                print(f"警告: 未找到GPS数据文件夹: {gps_parent_folder}")
                return {}

            hyperspectral_path = self._find_hyperspectral_image(work_path)
            if hyperspectral_path is None:
                print("警告: 未找到高光谱影像文件")
                return {}

            print(f"批量生成航线图 - GPS父文件夹: {gps_parent_folder}")

            flight_visualizer = FlightPathVisualizer(
                output_dir=str(self.output_dir / "flight_paths")
            )
            # Only the folder of the discovered image is handed over; the
            # visualizer picks the images itself.
            map_paths = flight_visualizer.batch_create_maps(
                gps_folder=gps_parent_folder,
                hyperspectral_folder=str(Path(hyperspectral_path).parent),
                output_subdir="batch_flight_paths"
            )

            print(f"批量航线图生成完成,共生成 {len(map_paths)} 个")
            return map_paths

        except Exception as e:
            print(f"批量生成航线图时出错: {e}")
            return {}


class ReportGenerator:
    """Writes the CSV / JSON reports of the retrieval pipeline."""

    # (report column, metadata key) pairs, in report column order.
    _METRIC_COLUMNS = (
        ('训练集R²', 'train_r2'),
        ('测试集R²', 'test_r2'),
        ('测试集RMSE', 'test_rmse'),
        ('训练集RMSE', 'train_rmse'),
        ('训练集MAE', 'train_mae'),
        ('测试集MAE', 'test_mae'),
        ('CV均值', 'cv_mean'),
    )

    def __init__(self, output_dir: str = "./reports"):
        """
        Create the generator and make sure the output directory exists.

        Args:
            output_dir: Directory that receives the reports.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def _resolve_output(self, output_path: Optional[str], default_name: str) -> Path:
        """Return the output file path, creating its parent directory if needed."""
        if output_path is None:
            return self.output_dir / default_name
        resolved = Path(output_path)
        resolved.parent.mkdir(parents=True, exist_ok=True)
        return resolved

    @staticmethod
    def _split_model_stem(stem: str, target_name: str) -> Tuple[str, str]:
        """Split ``{target}_{preprocess}_{model}`` into (preprocess, model).

        When the stem starts with the directory's target name, that prefix is
        removed first so that targets containing ``_`` are parsed correctly;
        otherwise the first ``_``-separated field is assumed to be the target.
        If no model part can be isolated, the whole stem is used as model name.
        """
        prefix = f"{target_name}_"
        if target_name and stem.startswith(prefix) and len(stem) > len(prefix):
            preprocess, sep, model_name = stem[len(prefix):].partition('_')
            return preprocess, (model_name if sep else stem)

        parts = stem.split('_', 2)
        preprocess = parts[1] if len(parts) > 1 else 'Unknown'
        model_name = parts[2] if len(parts) > 2 else stem
        return preprocess, model_name

    @staticmethod
    def _load_model_metadata(model_file: Path) -> Dict:
        """Return the ``'metadata'`` dict stored in a model file, or ``{}``.

        SECURITY: joblib files are pickle based, so loading one can execute
        arbitrary code. Only point this at model directories you trust.
        """
        try:
            payload = joblib.load(model_file)
        except Exception as e:
            print(f"警告: 无法读取模型元数据 {model_file}: {e}")
            return {}
        metadata = payload.get('metadata') if isinstance(payload, dict) else None
        return metadata if isinstance(metadata, dict) else {}

    def generate_training_summary(self,
                                  models_dir: str,
                                  output_path: Optional[str] = None) -> str:
        """
        Write the model training summary (training_summary.csv).

        Every ``*.joblib`` / ``*.pkl`` file below ``models_dir`` yields one
        row. The target name is the file's parent directory name; the
        preprocessing method and model name are parsed from the file name;
        metrics are read from the ``'metadata'`` dict stored in the file and
        default to ``'N/A'``.

        Args:
            models_dir: Directory the models were saved to (searched
                recursively).
            output_path: Output file; ``training_summary.csv`` in
                ``self.output_dir`` when None.

        Returns:
            Path of the saved report.
        """
        models_path = Path(models_dir)
        # Sorted so that the row order is reproducible.
        model_files = sorted(models_path.rglob("*.joblib")) + sorted(models_path.rglob("*.pkl"))

        summary_data = []
        for model_file in model_files:
            target_name = model_file.parent.name
            preprocess, model_name_str = self._split_model_stem(model_file.stem, target_name)
            metadata = self._load_model_metadata(model_file)

            row = {
                '目标参数': target_name,
                '预处理方法': preprocess,
                '模型名称': model_name_str,
                '模型文件': str(model_file),
            }
            for column, key in self._METRIC_COLUMNS:
                row[column] = metadata.get(key, 'N/A')
            summary_data.append(row)

        if not summary_data:
            print("警告:未找到模型文件,生成空摘要")
            placeholder = {
                '目标参数': 'No Data',
                '预处理方法': 'N/A',
                '模型名称': 'N/A',
                '模型文件': 'N/A',
            }
            placeholder.update({column: 'N/A' for column, _ in self._METRIC_COLUMNS})
            summary_data = [placeholder]

        df_summary = pd.DataFrame(summary_data)

        output_path = self._resolve_output(output_path, "training_summary.csv")
        df_summary.to_csv(output_path, index=False, encoding='utf-8-sig')
        print(f"训练摘要报告已保存: {output_path}")

        return str(output_path)

    def generate_prediction_report(self,
                                   prediction_csv_paths: Dict[str, str],
                                   output_path: Optional[str] = None) -> str:
        """
        Write descriptive statistics of the predictions of each target.

        The prediction column is ``'prediction'`` when present, otherwise the
        last column of the CSV. A file that cannot be processed yields a row
        with a sample count of 0 and the error message.

        Args:
            prediction_csv_paths: Dict ``target name -> prediction CSV path``.
            output_path: Output file; ``prediction_report.csv`` in
                ``self.output_dir`` when None.

        Returns:
            Path of the saved report.
        """
        report_data = []

        for target_name, csv_path in prediction_csv_paths.items():
            try:
                df = pd.read_csv(csv_path)
                pred_col = 'prediction' if 'prediction' in df.columns else df.columns[-1]
                predictions = df[pred_col].dropna()

                report_data.append({
                    '目标参数': target_name,
                    '样本数量': len(predictions),
                    '均值': predictions.mean(),
                    '标准差': predictions.std(),
                    '最小值': predictions.min(),
                    '最大值': predictions.max(),
                    '中位数': predictions.median(),
                    '25%分位数': predictions.quantile(0.25),
                    '75%分位数': predictions.quantile(0.75),
                    '文件路径': csv_path
                })
            except Exception as e:
                # Best effort: record the failure and continue with the rest.
                print(f"处理文件 {csv_path} 时出错: {e}")
                report_data.append({
                    '目标参数': target_name,
                    '样本数量': 0,
                    '错误': str(e)
                })

        df_report = pd.DataFrame(report_data)

        output_path = self._resolve_output(output_path, "prediction_report.csv")
        df_report.to_csv(output_path, index=False, encoding='utf-8-sig', float_format='%.6f')
        print(f"预测结果报告已保存: {output_path}")

        return str(output_path)

    def generate_batch_inference_summary(self,
                                         pipeline_info: Dict,
                                         output_path: Optional[str] = None) -> str:
        """
        Write the batch processing summary (batch_inference_summary.json).

        Args:
            pipeline_info: Pipeline information. Keys read here: ``work_dir``,
                ``step1`` .. ``step9`` (dicts with optional ``status`` and
                ``output_file``), ``models_dir``, ``training_params``,
                ``prediction_files`` and ``output_files``.
            output_path: Output file; ``batch_inference_summary.json`` in
                ``self.output_dir`` when None.

        Returns:
            Path of the saved summary.
        """
        summary = {
            '执行时间': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            '工作目录': str(pipeline_info.get('work_dir', 'Unknown')),
            '步骤执行情况': {},
            '模型训练': {},
            '预测结果': {},
            '输出文件': {}
        }

        for step in (f'step{i}' for i in range(1, 10)):
            if step in pipeline_info:
                step_info = pipeline_info[step]
                summary['步骤执行情况'][step] = {
                    '状态': step_info.get('status', 'completed'),
                    '输出文件': step_info.get('output_file', 'N/A')
                }

        if 'models_dir' in pipeline_info:
            summary['模型训练']['模型目录'] = pipeline_info['models_dir']
            summary['模型训练']['训练参数'] = pipeline_info.get('training_params', {})

        if 'prediction_files' in pipeline_info:
            summary['预测结果'] = {
                '目标参数数量': len(pipeline_info['prediction_files']),
                '预测文件': pipeline_info['prediction_files']
            }

        summary['输出文件'] = pipeline_info.get('output_files', {})

        output_path = self._resolve_output(output_path, "batch_inference_summary.json")

        # _json_default lets Path objects and numpy values coming from the
        # pipeline be written instead of aborting the dump.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2, default=_json_default)

        print(f"批量处理摘要已保存: {output_path}")
        return str(output_path)