1209 lines
49 KiB
Python
1209 lines
49 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
水质参数反演可视化与报告生成模块
|
||
|
||
功能包括:
|
||
1. 散点图:模型评估(真实值vs预测值)
|
||
2. 含量分布图:空间可视化,彩色填充图
|
||
3. 光谱曲线图:不同参数值的光谱曲线对比
|
||
4. 统计图表:箱线图、直方图、相关性热力图等
|
||
5. 模型训练摘要报告:training_summary.csv
|
||
6. 参数反演结果报告:包含预测统计信息
|
||
7. 批量处理摘要:batch_inference_summary.json
|
||
8. 掩膜和耀斑缩略图:2_Glint_Detection和3_deglint文件夹的影像预览图
|
||
"""
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.patches as patches
|
||
import seaborn as sns
|
||
from pathlib import Path
|
||
from typing import Dict, List, Optional, Tuple, Union
|
||
import json
|
||
import warnings
|
||
from datetime import datetime
|
||
import joblib
|
||
|
||
# GDAL is optional: it is only needed for the raster preview helpers, which
# check GDAL_AVAILABLE before touching the `gdal` name.
try:
    from osgeo import gdal
    GDAL_AVAILABLE = True
except ImportError:
    GDAL_AVAILABLE = False
    print("警告: GDAL未安装,影像预览图生成功能可能无法正常工作")

# NOTE(review): this silences every warning process-wide for any program that
# imports this module; kept as-is because callers may rely on the quiet output.
warnings.filterwarnings('ignore')

# Apply the seaborn theme FIRST. sns.set_style() writes its own
# 'font.family' / 'font.sans-serif' values into matplotlib's rcParams, so any
# font configuration done before it is overwritten and CJK labels lose their
# font.
sns.set_style("whitegrid")
sns.set_palette("husl")

# Fonts able to render the Chinese labels used throughout this module; set
# after the seaborn calls so the list survives.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
# Draw the minus sign with an ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.size'] = 12
|
||
|
||
|
||
class WaterQualityVisualization:
|
||
"""水质参数反演可视化类"""
|
||
|
||
def __init__(self, output_dir: str = "./visualization_output"):
|
||
"""
|
||
初始化可视化类
|
||
|
||
Args:
|
||
output_dir: 输出目录
|
||
"""
|
||
self.output_dir = Path(output_dir)
|
||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
def plot_scatter_true_vs_pred(self, y_true: np.ndarray, y_pred: np.ndarray,
                              target_name: str = "参数",
                              train_indices: Optional[np.ndarray] = None,
                              test_indices: Optional[np.ndarray] = None,
                              metrics: Optional[Dict] = None,
                              output_path: Optional[str] = None) -> str:
    """
    Draw a measured-vs-predicted scatter plot and save it as a PNG.

    When both ``train_indices`` and ``test_indices`` are given, the two
    subsets are drawn in different colours and annotated with their own
    R²/RMSE; otherwise all samples are drawn together and annotated with
    R², RMSE and MAE.

    Args:
        y_true: Measured values (any array-like; converted with np.asarray).
        y_pred: Predicted values, aligned with ``y_true``.
        target_name: Parameter name used in labels and the default file name.
        train_indices: Positional indices of the training samples (optional).
        test_indices: Positional indices of the test samples (optional).
        metrics: Optional precomputed values under the keys ``train_r2``,
            ``test_r2``, ``train_rmse``, ``test_rmse``; any missing key is
            computed from the data. Only used for the train/test layout.
        output_path: Target file; defaults to
            ``<output_dir>/<safe target name>_scatter_true_vs_pred.png``.

    Returns:
        The path of the saved image as a string.
    """
    from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

    # Accept lists / Series as well; positional indexing below needs arrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    supplied = metrics if metrics is not None else {}

    def _rmse(actual, predicted):
        return np.sqrt(mean_squared_error(actual, predicted))

    def _metric(key, compute):
        # Prefer the caller-provided value; compute only when it is missing.
        return supplied[key] if key in supplied else compute()

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        if train_indices is not None and test_indices is not None:
            y_train_true = y_true[train_indices]
            y_train_pred = y_pred[train_indices]
            y_test_true = y_true[test_indices]
            y_test_pred = y_pred[test_indices]

            ax.scatter(y_train_true, y_train_pred, alpha=0.6, s=50,
                       label=f'训练集 (n={len(y_train_true)})', color='blue',
                       edgecolors='black', linewidths=0.5)
            ax.scatter(y_test_true, y_test_pred, alpha=0.6, s=50,
                       label=f'测试集 (n={len(y_test_true)})', color='red',
                       edgecolors='black', linewidths=0.5)

            train_r2 = _metric('train_r2', lambda: r2_score(y_train_true, y_train_pred))
            test_r2 = _metric('test_r2', lambda: r2_score(y_test_true, y_test_pred))
            train_rmse = _metric('train_rmse', lambda: _rmse(y_train_true, y_train_pred))
            test_rmse = _metric('test_rmse', lambda: _rmse(y_test_true, y_test_pred))

            metrics_text = f'训练集: R² = {train_r2:.4f}, RMSE = {train_rmse:.4f}\n'
            metrics_text += f'测试集: R² = {test_r2:.4f}, RMSE = {test_rmse:.4f}'
        else:
            r2_all = r2_score(y_true, y_pred)
            rmse_all = _rmse(y_true, y_pred)
            mae_all = mean_absolute_error(y_true, y_pred)

            ax.scatter(y_true, y_pred, alpha=0.6, s=50, color='blue',
                       edgecolors='black', linewidths=0.5)
            metrics_text = f'R² = {r2_all:.4f}, RMSE = {rmse_all:.4f}, MAE = {mae_all:.4f}'

        # 1:1 reference line spanning the joint data range.
        min_val = min(y_true.min(), y_pred.min())
        max_val = max(y_true.max(), y_pred.max())
        ax.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='1:1线')

        ax.set_xlabel(f'真实值 ({target_name})', fontsize=14, fontweight='bold')
        ax.set_ylabel(f'预测值 ({target_name})', fontsize=14, fontweight='bold')
        ax.set_title(f'{target_name} - 真实值 vs 预测值', fontsize=16, fontweight='bold')
        # The metrics box occupies the upper-left corner, so the legend goes
        # to the opposite corner to keep both readable.
        ax.legend(loc='lower right', fontsize=11)
        ax.grid(True, alpha=0.3)

        ax.text(0.05, 0.95, metrics_text, transform=ax.transAxes,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
                fontsize=10)

        fig.tight_layout()

        if output_path is None:
            safe_name = "".join(c for c in target_name if c.isalnum() or c in ('-', '_', '.'))
            output_path = self.output_dir / f"{safe_name}_scatter_true_vs_pred.png"
        else:
            output_path = Path(output_path)
            # A caller-chosen location may not exist yet.
            output_path.parent.mkdir(parents=True, exist_ok=True)

        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Close exactly this figure, also when plotting or saving failed.
        plt.close(fig)

    print(f"散点图已保存: {output_path}")
    return str(output_path)
|
||
|
||
def plot_spectrum_by_parameter(self, csv_path: str, parameter_column: str,
|
||
wavelength_start_column: Union[str, int] = "UTM_Y",
|
||
output_dir: Optional[str] = None,
|
||
wavelength_range: Optional[Tuple[float, float]] = None,
|
||
n_groups: int = 5) -> str:
|
||
"""
|
||
绘制光谱曲线图:不同参数值的光谱曲线对比
|
||
|
||
Args:
|
||
csv_path: 包含光谱和参数值的CSV文件路径
|
||
parameter_column: 参数值列名或索引
|
||
wavelength_start_column: 波长开始列名或索引
|
||
output_dir: 输出目录(如果为None,使用self.output_dir)
|
||
wavelength_range: 波长范围(可选,如(374, 1011))
|
||
n_groups: 将参数值分成几组进行对比
|
||
|
||
Returns:
|
||
保存的文件路径
|
||
"""
|
||
# 读取数据
|
||
df = pd.read_csv(csv_path)
|
||
|
||
# 确定波长开始列
|
||
if isinstance(wavelength_start_column, str):
|
||
try:
|
||
wavelength_start_idx = df.columns.get_loc(wavelength_start_column)
|
||
except KeyError:
|
||
try:
|
||
wavelength_start_idx = int(wavelength_start_column)
|
||
except ValueError:
|
||
raise KeyError(
|
||
f"未找到波长起始列: {wavelength_start_column!r}"
|
||
) from None
|
||
else:
|
||
wavelength_start_idx = wavelength_start_column
|
||
|
||
# 获取参数值和光谱数据
|
||
param_values = df[parameter_column].values
|
||
spectrum_data = df.iloc[:, wavelength_start_idx:].values
|
||
|
||
# 获取波长
|
||
wavelength_cols = df.columns[wavelength_start_idx:]
|
||
try:
|
||
wavelengths = wavelength_cols.astype(float).values
|
||
except:
|
||
# 如果列名不是数字,使用索引
|
||
wavelengths = np.arange(len(wavelength_cols))
|
||
|
||
# 过滤波长范围
|
||
if wavelength_range:
|
||
mask = (wavelengths >= wavelength_range[0]) & (wavelengths <= wavelength_range[1])
|
||
wavelengths = wavelengths[mask]
|
||
spectrum_data = spectrum_data[:, mask]
|
||
|
||
# 过滤无效值
|
||
valid_mask = ~pd.isna(param_values) & np.all(np.isfinite(spectrum_data), axis=1)
|
||
param_values = param_values[valid_mask]
|
||
spectrum_data = spectrum_data[valid_mask]
|
||
|
||
# 将参数值分成n_groups组
|
||
param_min, param_max = param_values.min(), param_values.max()
|
||
group_edges = np.linspace(param_min, param_max, n_groups + 1)
|
||
group_labels = [f"{group_edges[i]:.2f}-{group_edges[i+1]:.2f}"
|
||
for i in range(n_groups)]
|
||
group_indices = np.digitize(param_values, group_edges[1:])
|
||
|
||
# 创建图形
|
||
fig, ax = plt.subplots(figsize=(14, 8))
|
||
|
||
# 为每组选择颜色
|
||
colors = plt.cm.viridis(np.linspace(0, 1, n_groups))
|
||
|
||
# 绘制每组的光谱曲线
|
||
for i in range(n_groups):
|
||
group_mask = group_indices == i
|
||
if group_mask.sum() == 0:
|
||
continue
|
||
|
||
group_spectra = spectrum_data[group_mask]
|
||
group_mean_spectrum = np.nanmean(group_spectra, axis=0)
|
||
group_std_spectrum = np.nanstd(group_spectra, axis=0)
|
||
|
||
# 绘制平均光谱
|
||
ax.plot(wavelengths, group_mean_spectrum,
|
||
color=colors[i], linewidth=2.5,
|
||
label=f'组 {i+1} ({group_labels[i]}, n={group_mask.sum()})')
|
||
|
||
# 绘制标准差阴影
|
||
ax.fill_between(wavelengths,
|
||
group_mean_spectrum - group_std_spectrum,
|
||
group_mean_spectrum + group_std_spectrum,
|
||
color=colors[i], alpha=0.2)
|
||
|
||
# 设置图形属性
|
||
ax.set_xlabel('波长 (nm)', fontsize=14, fontweight='bold')
|
||
ax.set_ylabel('光谱反射率', fontsize=14, fontweight='bold')
|
||
ax.set_title(f'{parameter_column} - 不同参数值的光谱曲线对比',
|
||
fontsize=16, fontweight='bold')
|
||
ax.legend(loc='best', fontsize=10)
|
||
ax.grid(True, alpha=0.3)
|
||
|
||
if wavelength_range:
|
||
ax.set_xlim(wavelength_range)
|
||
|
||
plt.tight_layout()
|
||
|
||
# 保存图片
|
||
if output_dir is None:
|
||
output_dir = self.output_dir
|
||
else:
|
||
output_dir = Path(output_dir)
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
safe_name = "".join(c for c in parameter_column if c.isalnum() or c in ('-', '_', '.'))
|
||
output_path = output_dir / f"{safe_name}_spectrum_comparison.png"
|
||
|
||
plt.savefig(output_path, dpi=300, bbox_inches='tight')
|
||
plt.close()
|
||
|
||
print(f"光谱曲线图已保存: {output_path}")
|
||
return str(output_path)
|
||
|
||
def plot_statistical_charts(self, csv_path: str, parameter_columns: List[str],
|
||
output_dir: Optional[str] = None) -> Dict[str, str]:
|
||
"""
|
||
绘制统计图表:**只针对水质参数列**(数值型,排除波长列)
|
||
- 水质参数列(如浓度、含量等数值型参数)使用箱线图/直方图/相关性热力图
|
||
- 排除光谱波长列(虽然也是数值型,但不是水质参数)
|
||
|
||
Args:
|
||
csv_path: CSV文件路径
|
||
parameter_columns: **水质参数**列名列表(数值型,已排除波长列)
|
||
output_dir: 输出目录
|
||
|
||
Returns:
|
||
保存的文件路径字典
|
||
"""
|
||
df = pd.read_csv(csv_path)
|
||
|
||
if output_dir is None:
|
||
output_dir = self.output_dir
|
||
else:
|
||
output_dir = Path(output_dir)
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
output_paths = {}
|
||
|
||
# 水质参数统计图表(针对数值型参数,排除波长列)
|
||
# 假设传入的 parameter_columns 已经是过滤后的水质参数列
|
||
numeric_cols = [col for col in parameter_columns if col in df.columns and pd.api.types.is_numeric_dtype(df[col])]
|
||
|
||
# 1. 箱线图
|
||
if len(numeric_cols) > 0:
|
||
fig, ax = plt.subplots(figsize=(12, 6))
|
||
data_for_boxplot = [df[col].dropna() for col in numeric_cols]
|
||
if data_for_boxplot:
|
||
ax.boxplot(data_for_boxplot, labels=numeric_cols)
|
||
ax.set_ylabel('数值', fontsize=12, fontweight='bold')
|
||
ax.set_title('水质参数箱线图', fontsize=14, fontweight='bold')
|
||
ax.grid(True, alpha=0.3, axis='y')
|
||
plt.xticks(rotation=45, ha='right')
|
||
plt.tight_layout()
|
||
|
||
boxplot_path = output_dir / "parameter_boxplot.png"
|
||
plt.savefig(boxplot_path, dpi=300, bbox_inches='tight')
|
||
plt.close()
|
||
output_paths['boxplot'] = str(boxplot_path)
|
||
|
||
# 2. 直方图 (每个水质参数列)
|
||
for col in numeric_cols:
|
||
fig, ax = plt.subplots(figsize=(10, 6))
|
||
data = df[col].dropna()
|
||
if len(data) > 1:
|
||
ax.hist(data, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
|
||
ax.set_xlabel(f'{col} 数值', fontsize=12, fontweight='bold')
|
||
ax.set_ylabel('频数', fontsize=12, fontweight='bold')
|
||
ax.set_title(f'{col} 分布直方图', fontsize=14, fontweight='bold')
|
||
ax.grid(True, alpha=0.3, axis='y')
|
||
|
||
# 添加统计信息
|
||
mean_val = data.mean()
|
||
std_val = data.std() if len(data) > 1 else 0
|
||
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'均值: {mean_val:.4f}')
|
||
ax.legend()
|
||
|
||
plt.tight_layout()
|
||
|
||
safe_name = "".join(c for c in col if c.isalnum() or c in ('-', '_', '.'))
|
||
hist_path = output_dir / f"{safe_name}_histogram.png"
|
||
plt.savefig(hist_path, dpi=300, bbox_inches='tight')
|
||
plt.close()
|
||
output_paths[f'histogram_{col}'] = str(hist_path)
|
||
|
||
# 3. 相关性热力图
|
||
if len(numeric_cols) >= 2:
|
||
corr_matrix = df[numeric_cols].corr()
|
||
|
||
fig, ax = plt.subplots(figsize=(10, 8))
|
||
sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm',
|
||
center=0, square=True, linewidths=1, cbar_kws={"shrink": 0.8},
|
||
ax=ax, vmin=-1, vmax=1)
|
||
ax.set_title('水质参数相关性热力图', fontsize=14, fontweight='bold')
|
||
plt.tight_layout()
|
||
|
||
heatmap_path = output_dir / "correlation_heatmap.png"
|
||
plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
|
||
plt.close()
|
||
output_paths['heatmap'] = str(heatmap_path)
|
||
|
||
if not output_paths:
|
||
print("警告: 没有生成任何统计图表(可能无合适的水质参数列)")
|
||
else:
|
||
print(f"统计图表已保存到: {output_dir},共 {len(output_paths)} 个文件")
|
||
return output_paths
|
||
|
||
def plot_distribution_map_enhanced(self, prediction_csv_path: str,
                                   boundary_shp_path: str,
                                   parameter_column: str = 'prediction',
                                   output_path: Optional[str] = None,
                                   resolution: float = 30,
                                   input_crs: str = 'EPSG:32651',
                                   output_crs: str = 'EPSG:4326',
                                   colormap: str = 'viridis') -> str:
    """
    Render a colour-filled distribution map by delegating to ContentMapper.

    Args:
        prediction_csv_path: CSV file with the prediction results.
        boundary_shp_path: Boundary shapefile.
        parameter_column: Name of the value column.
            NOTE(review): accepted but not forwarded to ContentMapper here.
        output_path: Target image; defaults to
            ``<output_dir>/<csv stem>_distribution_enhanced.png``.
        resolution: Interpolation grid resolution passed to the mapper.
        input_crs: Coordinate reference system of the input data.
        output_crs: Coordinate reference system of the output map.
        colormap: Colour map name.
            NOTE(review): accepted but not forwarded to ContentMapper here.

    Returns:
        The output image path.
    """
    from map import ContentMapper

    content_mapper = ContentMapper(input_crs=input_crs, output_crs=output_crs)

    # Derive the default file name from the CSV name when none was given.
    target_file = output_path
    if target_file is None:
        default_name = f"{Path(prediction_csv_path).stem}_distribution_enhanced.png"
        target_file = str(self.output_dir / default_name)

    content_mapper.process_data(
        csv_file=prediction_csv_path,
        shp_file=boundary_shp_path,
        output_file=target_file,
        resolution=resolution,
        show_sample_points=False,
    )

    print(f"增强分布图已保存: {target_file}")
    return target_file
|
||
|
||
def generate_glint_deglint_previews(self, work_dir: str,
                                    output_subdir: str = "glint_deglint_previews",
                                    generate_glint: bool = True,
                                    generate_deglint: bool = True) -> Dict[str, str]:
    """
    Create PNG previews for the rasters in 2_Glint_Detection and 3_deglint.

    Each enabled folder under ``work_dir`` is handed to
    ``_process_image_folder`` (tagged ``'glint'`` or ``'deglint'``), which
    does the rendering. Previews are written to
    ``self.output_dir / output_subdir``.

    Args:
        work_dir: Working directory containing the two folders.
        output_subdir: Name of the output sub-directory.
        generate_glint: Whether to process ``2_Glint_Detection``.
        generate_deglint: Whether to process ``3_deglint``.

    Returns:
        Mapping of source file name to preview PNG path. Empty when GDAL is
        unavailable or ``work_dir`` does not exist. Keys are bare file
        names, so equally named files in both folders share one entry.
    """
    if not GDAL_AVAILABLE:
        print("警告: GDAL未安装,无法生成影像预览图")
        return {}

    work_dir_path = Path(work_dir)
    if not work_dir_path.exists():
        print(f"错误: 工作目录不存在: {work_dir}")
        return {}

    output_dir = self.output_dir / output_subdir
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print("生成耀斑分析影像预览图")
    print(f"{'='*60}")
    print(f"输出目录: {output_dir}")

    # Build the work list: (folder, type tag, start message, missing message).
    glint_dir = work_dir_path / "2_Glint_Detection"
    deglint_dir = work_dir_path / "3_deglint"
    jobs = []
    if generate_glint:
        jobs.append((glint_dir, "glint",
                     f"正在处理2_Glint_Detection文件夹: {glint_dir}",
                     f"警告: 2_Glint_Detection文件夹不存在: {glint_dir}"))
    if generate_deglint:
        jobs.append((deglint_dir, "deglint",
                     f"正在处理3_deglint文件夹: {deglint_dir}",
                     f"警告: 3_deglint文件夹不存在: {deglint_dir}"))

    preview_paths = {}
    processed_count = 0
    for folder, folder_type, start_message, missing_message in jobs:
        if not folder.exists():
            print(missing_message)
            continue
        print(start_message)
        # The helper fills preview_paths in place and returns its own share.
        folder_previews = self._process_image_folder(
            folder, output_dir, folder_type, preview_paths
        )
        processed_count += len(folder_previews)

    print(f"\n影像预览图生成完成,共处理 {processed_count} 个文件")
    print(f"预览图保存至: {output_dir}")

    return preview_paths
|
||
|
||
def _process_image_folder(self, input_dir: Path, output_dir: Path,
|
||
folder_type: str, preview_paths: Dict[str, str]) -> Dict[str, str]:
|
||
"""
|
||
处理指定文件夹中的影像文件并生成预览图
|
||
|
||
Args:
|
||
input_dir: 输入文件夹路径
|
||
output_dir: 输出文件夹路径
|
||
folder_type: 文件夹类型 ('glint' 或 'deglint')
|
||
preview_paths: 存储预览图路径的字典(会原地修改)
|
||
|
||
Returns:
|
||
处理后的预览图路径字典
|
||
"""
|
||
if not input_dir.exists():
|
||
return {}
|
||
|
||
# 支持的影像文件扩展名
|
||
supported_extensions = {'.dat', '.bsq', '.tif', '.tiff', '.bil', '.img'}
|
||
|
||
processed = {}
|
||
|
||
for file_path in input_dir.iterdir():
|
||
if file_path.is_file() and file_path.suffix.lower() in supported_extensions:
|
||
try:
|
||
png_path = self._generate_image_preview_for_visualization(
|
||
str(file_path), output_dir, folder_type
|
||
)
|
||
|
||
if png_path:
|
||
preview_paths[file_path.name] = png_path
|
||
processed[file_path.name] = png_path
|
||
print(f" ✓ 已生成: {file_path.name} -> {Path(png_path).name}")
|
||
|
||
except Exception as e:
|
||
print(f" ✗ 处理文件 {file_path.name} 时出错: {e}")
|
||
|
||
return processed
|
||
|
||
def _generate_image_preview_for_visualization(self, img_path: str,
                                              output_dir: Path,
                                              folder_type: str) -> Optional[str]:
    """
    Render one raster to ``<output_dir>/<folder_type>_<stem>_preview.png``.

    Rendering rules:
    - ``folder_type == 'glint'``: band 1 is treated as a mask; pixels with a
      value above 0.5 are white, everything else black.
    - Any other single-band raster: grey scale, divided by its maximum.
    - Multi-band rasters: RGB composite of three bands, each with values
      <= 0 dropped and a 2%-98% linear stretch applied.

    An existing preview file is returned as-is without re-rendering.

    Args:
        img_path: Raster file to read with GDAL.
        output_dir: Directory for the PNG.
        folder_type: ``'glint'`` or ``'deglint'``.

    Returns:
        The PNG path, or None when the raster cannot be opened or rendering
        fails (the error is printed).
    """
    def _percent_stretch(band, low_percent=2, high_percent=98):
        # Linear stretch between two percentiles of the valid (non-NaN) data.
        valid = band[~np.isnan(band)]
        if len(valid) == 0:
            return np.zeros_like(band)
        low_val = np.percentile(valid, low_percent)
        high_val = np.percentile(valid, high_percent)
        if high_val - low_val < 1e-10:
            return np.zeros_like(band)
        return np.clip((band - low_val) / (high_val - low_val), 0, 1)

    dataset = None
    fig = None
    try:
        output_path = output_dir / f"{folder_type}_{Path(img_path).stem}_preview.png"

        # Reuse a preview generated by an earlier run.
        if output_path.exists():
            return str(output_path)

        dataset = gdal.Open(img_path)
        if dataset is None:
            print(f" 警告: 无法打开影像文件: {img_path}")
            return None

        band_count = dataset.RasterCount

        if band_count == 1 or folder_type == 'glint':
            first_band = dataset.GetRasterBand(1).ReadAsArray().astype(np.float32)

            if folder_type == 'glint':
                # Float RGB images must stay within 0..1 for imshow, so
                # glint pixels are set to 1.0 (white) on a black background.
                rgb_image = np.zeros(first_band.shape + (3,), dtype=np.float32)
                rgb_image[first_band > 0.5] = 1.0
            else:
                peak = first_band.max()
                if peak > 0:
                    first_band = first_band / (peak + 1e-10)
                # Keep the grey image inside the valid float range too.
                grey = np.clip(np.nan_to_num(first_band, nan=0.0), 0.0, 1.0)
                rgb_image = np.stack([grey, grey, grey], axis=2)
        else:
            if band_count >= 3:
                bands = self._select_rgb_bands(img_path, band_count)
            else:
                # Two bands: show the first one as grey.
                bands = [0, 0, 0]

            # Read and stretch each distinct band once, even if it is used
            # for more than one channel.
            stretched = {}
            for band_index in bands:
                if band_index in stretched:
                    continue
                band_data = dataset.GetRasterBand(band_index + 1).ReadAsArray().astype(np.float32)
                band_data[band_data <= 0] = np.nan  # values <= 0 are treated as invalid
                stretched[band_index] = _percent_stretch(band_data)

            rgb_image = np.stack([stretched[i] for i in bands], axis=2)
            rgb_image = np.nan_to_num(rgb_image, nan=0.0)

        fig, ax = plt.subplots(figsize=(12, 10))
        ax.grid(False)   # the global seaborn style would otherwise draw a grid
        ax.imshow(rgb_image)
        ax.axis('off')

        fig.tight_layout()
        fig.savefig(str(output_path), dpi=150, bbox_inches='tight', pad_inches=0.05)
        return str(output_path)

    except Exception as e:
        print(f" 生成预览图时出错 {img_path}: {e}")
        return None
    finally:
        # Close only the figure created here; other open figures belong to
        # the caller.
        if fig is not None:
            plt.close(fig)
        # Drop the GDAL handle explicitly on every exit path.
        dataset = None
|
||
def _select_rgb_bands(self, img_path: str, band_count: int) -> List[int]:
|
||
"""
|
||
选择RGB波段(优先使用波长查找,失败则使用默认索引)
|
||
|
||
Args:
|
||
img_path: 影像文件路径
|
||
band_count: 总波段数
|
||
|
||
Returns:
|
||
[R, G, B] 波段索引列表
|
||
"""
|
||
try:
|
||
# 尝试使用pipeline中的find_band_number函数
|
||
from src.utils.util import find_band_number
|
||
target_wavelengths = {'R': 650.0, 'G': 550.0, 'B': 460.0}
|
||
bands = []
|
||
|
||
for color, target_wl in target_wavelengths.items():
|
||
try:
|
||
band_idx = find_band_number(target_wl, img_path)
|
||
band_idx = max(0, min(band_idx, band_count - 1))
|
||
bands.append(band_idx)
|
||
except:
|
||
# 回退到基于索引的选择
|
||
if color == 'R':
|
||
bands.append(min(band_count - 1, int(band_count * 0.25)))
|
||
elif color == 'G':
|
||
bands.append(min(band_count - 1, int(band_count * 0.15)))
|
||
else:
|
||
bands.append(min(band_count - 1, int(band_count * 0.05)))
|
||
|
||
return bands if len(bands) == 3 else [int(band_count*0.25), int(band_count*0.15), int(band_count*0.05)]
|
||
|
||
except ImportError:
|
||
# 如果无法导入,使用基于索引的选择
|
||
if band_count >= 3:
|
||
return [min(band_count - 1, int(band_count * 0.25)),
|
||
min(band_count - 1, int(band_count * 0.15)),
|
||
min(band_count - 1, int(band_count * 0.05))]
|
||
else:
|
||
return [0, 0, 0]
|
||
except Exception:
|
||
return [min(band_count - 1, int(band_count * 0.25)) if band_count > 0 else 0,
|
||
min(band_count - 1, int(band_count * 0.15)) if band_count > 1 else 0,
|
||
min(band_count - 1, int(band_count * 0.05)) if band_count > 2 else 0]
|
||
|
||
def generate_sampling_point_map(self, hyperspectral_path: Optional[str] = None,
|
||
csv_path: Optional[str] = None,
|
||
output_subdir: str = "sampling_maps") -> str:
|
||
"""
|
||
生成采样点地图 - 在高光谱假彩色影像上标注采样点
|
||
|
||
Args:
|
||
hyperspectral_path: 高光谱影像路径(如果为None则自动查找)
|
||
csv_path: 采样点CSV文件路径(如果为None则自动查找4_processed_data中的CSV)
|
||
output_subdir: 输出子目录名称
|
||
|
||
Returns:
|
||
生成的地图文件路径
|
||
"""
|
||
try:
|
||
from src.postprocessing.point_map import SamplingPointMap
|
||
|
||
# 如果没有提供路径,自动查找
|
||
work_dir = self.output_dir.parent # 14_visualization的父目录就是工作目录
|
||
|
||
if hyperspectral_path is None:
|
||
# 查找高光谱影像
|
||
hyperspectral_files = []
|
||
for ext in ['*.dat', '*.bsq', '*.tif', '*.tiff']:
|
||
hyperspectral_files.extend(list(work_dir.glob(f"**/{ext}")))
|
||
if hyperspectral_files:
|
||
hyperspectral_path = str(hyperspectral_files[0])
|
||
else:
|
||
print("警告: 未找到高光谱影像文件")
|
||
return ""
|
||
|
||
if csv_path is None:
|
||
# 查找4_processed_data中的CSV文件
|
||
processed_dir = work_dir / "4_processed_data"
|
||
if processed_dir.exists():
|
||
csv_files = list(processed_dir.glob("*.csv"))
|
||
if csv_files:
|
||
csv_path = str(csv_files[0])
|
||
else:
|
||
print(f"警告: 在 {processed_dir} 中未找到CSV文件")
|
||
return ""
|
||
else:
|
||
print("警告: 4_processed_data目录不存在")
|
||
return ""
|
||
|
||
print(f"生成采样点地图 - 高光谱: {Path(hyperspectral_path).name}, CSV: {Path(csv_path).name}")
|
||
|
||
# 创建采样点地图生成器
|
||
map_generator = SamplingPointMap(output_dir=str(self.output_dir / output_subdir))
|
||
|
||
map_path = map_generator.create_sampling_point_map(
|
||
hyperspectral_path=hyperspectral_path,
|
||
csv_path=csv_path,
|
||
point_color='red',
|
||
point_size=100,
|
||
point_alpha=0.9,
|
||
show_north_arrow=True,
|
||
show_scale_bar=True,
|
||
show_legend=True
|
||
)
|
||
|
||
print(f"采样点地图已生成: {map_path}")
|
||
return map_path
|
||
|
||
except Exception as e:
|
||
print(f"生成采样点地图时出错: {e}")
|
||
return ""
|
||
|
||
def generate_all_visualizations(self, work_dir: Optional[str] = None) -> Dict[str, str]:
|
||
"""
|
||
生成所有可视化结果,包括掩膜缩略图、采样点地图等
|
||
|
||
Args:
|
||
work_dir: 工作目录(如果为None则使用output_dir的父目录)
|
||
|
||
Returns:
|
||
生成的文件路径字典
|
||
"""
|
||
if work_dir is None:
|
||
work_dir = str(self.output_dir.parent)
|
||
|
||
results = {}
|
||
|
||
# 生成掩膜和耀斑缩略图
|
||
try:
|
||
preview_paths = self.generate_glint_deglint_previews(work_dir=work_dir)
|
||
results['glint_deglint_previews'] = preview_paths
|
||
except Exception as e:
|
||
print(f"生成掩膜缩略图时出错: {e}")
|
||
|
||
# 生成采样点地图
|
||
try:
|
||
map_path = self.generate_sampling_point_map()
|
||
if map_path:
|
||
results['sampling_map'] = map_path
|
||
except Exception as e:
|
||
print(f"生成采样点地图时出错: {e}")
|
||
|
||
# 生成航线图
|
||
try:
|
||
flight_path = self.generate_flight_path_map(work_dir=work_dir)
|
||
if flight_path:
|
||
results['flight_path'] = flight_path
|
||
except Exception as e:
|
||
print(f"生成航线图时出错: {e}")
|
||
|
||
return results
|
||
|
||
def generate_flight_path_map(self,
|
||
work_dir: Optional[str] = None,
|
||
gps_folder: Optional[str] = None,
|
||
hyperspectral_path: Optional[str] = None,
|
||
output_subdir: str = "flight_paths") -> str:
|
||
"""
|
||
生成飞行轨迹航线图 - 在高光谱影像上绘制多架次飞行轨迹
|
||
|
||
Args:
|
||
work_dir: 工作目录(如果为None则使用output_dir的父目录)
|
||
gps_folder: GPS数据文件夹路径(如果为None则自动查找)
|
||
hyperspectral_path: 高光谱影像路径(如果为None则自动查找)
|
||
output_subdir: 输出子目录名称
|
||
|
||
Returns:
|
||
生成的航线图文件路径
|
||
"""
|
||
try:
|
||
from src.postprocessing.flight_path import FlightPathVisualizer
|
||
|
||
# 如果没有提供路径,自动查找
|
||
if work_dir is None:
|
||
work_dir = str(self.output_dir.parent)
|
||
work_path = Path(work_dir)
|
||
|
||
# 查找GPS文件夹
|
||
if gps_folder is None:
|
||
# 首先查找常见的GPS数据文件夹
|
||
possible_gps_dirs = ['gps', 'GPS', 'flight', '轨迹', '航线']
|
||
for gps_dir_name in possible_gps_dirs:
|
||
gps_dir = work_path / gps_dir_name
|
||
if gps_dir.exists() and list(gps_dir.glob("**/*.gps")):
|
||
gps_folder = str(gps_dir)
|
||
print(f"找到GPS文件夹: {gps_folder}")
|
||
break
|
||
|
||
# 如果没找到,查找任何包含.gps文件的文件夹
|
||
if gps_folder is None:
|
||
gps_files = list(work_path.glob("**/*.gps"))
|
||
if gps_files:
|
||
gps_folder = str(gps_files[0].parent)
|
||
print(f"使用包含GPS文件的文件夹: {gps_folder}")
|
||
|
||
if gps_folder is None or not Path(gps_folder).exists():
|
||
print("警告: 未找到GPS数据文件夹")
|
||
return ""
|
||
|
||
# 查找高光谱影像 - 优先使用3_deglint
|
||
if hyperspectral_path is None:
|
||
# 首先查找3_deglint文件夹
|
||
deglint_dir = work_path / "3_deglint"
|
||
if deglint_dir.exists():
|
||
hyperspectral_files = []
|
||
for ext in ['*.dat', '*.bsq', '*.tif', '*.tiff']:
|
||
hyperspectral_files.extend(list(deglint_dir.glob(ext)))
|
||
if hyperspectral_files:
|
||
hyperspectral_path = str(hyperspectral_files[0])
|
||
print(f"使用3_deglint中的高光谱影像: {hyperspectral_path}")
|
||
|
||
# 如果没找到,再查找整个工作目录
|
||
if hyperspectral_path is None:
|
||
hyperspectral_files = []
|
||
for ext in ['*.dat', '*.bsq', '*.tif', '*.tiff']:
|
||
hyperspectral_files.extend(list(work_path.glob(f"**/{ext}")))
|
||
if hyperspectral_files:
|
||
hyperspectral_path = str(hyperspectral_files[0])
|
||
print(f"使用找到的高光谱影像: {hyperspectral_path}")
|
||
|
||
if hyperspectral_path is None or not Path(hyperspectral_path).exists():
|
||
print("警告: 未找到高光谱影像文件")
|
||
return ""
|
||
|
||
print(f"生成航线图 - GPS: {Path(gps_folder).name}, 影像: {Path(hyperspectral_path).name}")
|
||
|
||
# 创建航线图生成器
|
||
flight_visualizer = FlightPathVisualizer(
|
||
output_dir=str(self.output_dir / output_subdir)
|
||
)
|
||
|
||
map_path = flight_visualizer.create_flight_path_map(
|
||
gps_folder=gps_folder,
|
||
hyperspectral_path=hyperspectral_path,
|
||
line_width=2,
|
||
show_north_arrow=True,
|
||
show_scale_bar=True,
|
||
dpi=300
|
||
)
|
||
|
||
print(f"航线图已生成: {map_path}")
|
||
return map_path
|
||
|
||
except ImportError as e:
|
||
print(f"无法导入flight_path模块: {e}")
|
||
return ""
|
||
except Exception as e:
|
||
print(f"生成航线图时出错: {e}")
|
||
return ""
|
||
|
||
def batch_generate_flight_paths(self,
|
||
work_dir: Optional[str] = None,
|
||
gps_parent_folder: Optional[str] = None) -> Dict[str, str]:
|
||
"""
|
||
批量生成多个飞行任务的航线图
|
||
|
||
Args:
|
||
work_dir: 工作目录
|
||
gps_parent_folder: 包含多个GPS子文件夹的父文件夹
|
||
|
||
Returns:
|
||
生成的航线图文件路径字典
|
||
"""
|
||
try:
|
||
from src.postprocessing.flight_path import FlightPathVisualizer
|
||
|
||
if work_dir is None:
|
||
work_dir = str(self.output_dir.parent)
|
||
work_path = Path(work_dir)
|
||
|
||
# 查找GPS父文件夹
|
||
if gps_parent_folder is None:
|
||
# 查找常见的GPS数据文件夹
|
||
possible_gps_dirs = ['gps', 'GPS', 'flight', 'flights', '轨迹', '航线']
|
||
for gps_dir_name in possible_gps_dirs:
|
||
gps_dir = work_path / gps_dir_name
|
||
if gps_dir.exists():
|
||
gps_parent_folder = str(gps_dir)
|
||
break
|
||
|
||
if gps_parent_folder is None or not Path(gps_parent_folder).exists():
|
||
print(f"警告: 未找到GPS数据文件夹: {gps_parent_folder}")
|
||
return {}
|
||
|
||
# 查找高光谱影像
|
||
hyperspectral_files = []
|
||
deglint_dir = work_path / "3_deglint"
|
||
if deglint_dir.exists():
|
||
for ext in ['*.dat', '*.bsq', '*.tif', '*.tiff']:
|
||
hyperspectral_files.extend(list(deglint_dir.glob(ext)))
|
||
|
||
if not hyperspectral_files:
|
||
for ext in ['*.dat', '*.bsq', '*.tif', '*.tiff']:
|
||
hyperspectral_files.extend(list(work_path.glob(f"**/{ext}")))
|
||
|
||
if not hyperspectral_files:
|
||
print("警告: 未找到高光谱影像文件")
|
||
return {}
|
||
|
||
hyperspectral_path = str(hyperspectral_files[0])
|
||
|
||
print(f"批量生成航线图 - GPS父文件夹: {gps_parent_folder}")
|
||
|
||
# 批量生成
|
||
flight_visualizer = FlightPathVisualizer(
|
||
output_dir=str(self.output_dir / "flight_paths")
|
||
)
|
||
|
||
map_paths = flight_visualizer.batch_create_maps(
|
||
gps_folder=gps_parent_folder,
|
||
hyperspectral_folder=str(Path(hyperspectral_path).parent),
|
||
output_subdir="batch_flight_paths"
|
||
)
|
||
|
||
print(f"批量航线图生成完成,共生成 {len(map_paths)} 个")
|
||
return map_paths
|
||
|
||
except Exception as e:
|
||
print(f"批量生成航线图时出错: {e}")
|
||
return {}
|
||
|
||
|
||
class ReportGenerator:
|
||
"""报告生成类"""
|
||
|
||
def __init__(self, output_dir: str = "./reports"):
|
||
"""
|
||
初始化报告生成类
|
||
|
||
Args:
|
||
output_dir: 输出目录
|
||
"""
|
||
self.output_dir = Path(output_dir)
|
||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
def generate_training_summary(self, models_dir: str, output_path: Optional[str] = None) -> str:
    """
    Write the model training summary report (training_summary.csv).

    Every ``*.joblib`` / ``*.pkl`` file found recursively below ``models_dir``
    becomes one row. The target parameter is the name of the file's direct
    parent directory; the preprocessing method and model name are parsed from
    the file name, which is expected to look like
    ``{target}_{preprocess}_{model}``. Metrics are read from the ``metadata``
    entry of the loaded object; any metric that cannot be read is reported as
    ``'N/A'``. When no model file exists a single placeholder row is written.

    Args:
        models_dir: Root directory that holds the saved model files.
        output_path: Destination CSV path. When None, the report is written
            to ``training_summary.csv`` inside ``self.output_dir``.

    Returns:
        Path of the written CSV file, as a string.
    """
    # (metadata key, CSV column) pairs, in the column order of the report.
    metric_columns = (
        ('train_r2', '训练集R²'),
        ('test_r2', '测试集R²'),
        ('test_rmse', '测试集RMSE'),
        ('train_rmse', '训练集RMSE'),
        ('train_mae', '训练集MAE'),
        ('test_mae', '测试集MAE'),
        ('cv_mean', 'CV均值'),
    )

    models_path = Path(models_dir)
    # Sort each group so the row order does not depend on the filesystem.
    model_files = sorted(models_path.rglob("*.joblib")) + sorted(models_path.rglob("*.pkl"))

    summary_data = []
    for model_file in model_files:
        target_name = model_file.parent.name
        stem = model_file.stem

        # Strip the known target prefix first so that targets containing
        # underscores (e.g. "Chl_a") do not shift the remaining fields.
        prefix = f"{target_name}_"
        if stem.startswith(prefix) and len(stem) > len(prefix):
            preprocess, separator, model_name_str = stem[len(prefix):].partition('_')
            if not separator:
                model_name_str = stem
        else:
            # Fallback: file name does not start with the directory name.
            parts = stem.split('_', 2)
            preprocess = parts[1] if len(parts) > 1 else 'Unknown'
            model_name_str = parts[2] if len(parts) > 2 else stem

        metrics = {}
        try:
            # NOTE(security): joblib.load unpickles the file and can execute
            # arbitrary code; only point this at trusted model directories.
            # ``joblib`` is the module-level import of this file.
            data = joblib.load(model_file)
            metadata = data.get('metadata', {})
            metrics = {key: metadata.get(key, 'N/A') for key, _ in metric_columns}
        except Exception as exc:
            # Best effort: an unreadable model still gets a row, with every
            # metric reported as N/A, but the failure is no longer silent.
            print(f"警告: 读取模型元数据失败 {model_file}: {exc}")

        row = {
            '目标参数': target_name,
            '预处理方法': preprocess,
            '模型名称': model_name_str,
            '模型文件': str(model_file),
        }
        for key, column in metric_columns:
            row[column] = metrics.get(key, 'N/A')
        summary_data.append(row)

    if not summary_data:
        print("警告:未找到模型文件,生成空摘要")
        placeholder = {
            '目标参数': 'No Data',
            '预处理方法': 'N/A',
            '模型名称': 'N/A',
            '模型文件': 'N/A',
        }
        placeholder.update({column: 'N/A' for _, column in metric_columns})
        summary_data = [placeholder]

    df_summary = pd.DataFrame(summary_data)

    if output_path is None:
        output_path = self.output_dir / "training_summary.csv"
    else:
        output_path = Path(output_path)

    # utf-8-sig keeps the Chinese headers readable when opened in Excel.
    df_summary.to_csv(output_path, index=False, encoding='utf-8-sig')
    print(f"训练摘要报告已保存: {output_path}")
    return str(output_path)
|
||
|
||
def generate_prediction_report(self, prediction_csv_paths: Dict[str, str],
                               output_path: Optional[str] = None) -> str:
    """
    Write the inversion result report with per-target prediction statistics.

    For each target the prediction CSV is read and descriptive statistics
    (count, mean, standard deviation, min, max, median, quartiles) are
    computed over the non-missing predictions. The prediction column is the
    one named ``prediction`` when present, otherwise the last column. A file
    that cannot be processed still produces a row, with a sample count of 0
    and the error text in the ``错误`` column.

    Args:
        prediction_csv_paths: Mapping of target parameter name to the path of
            its prediction CSV file.
        output_path: Destination CSV path. When None, the report is written
            to ``prediction_report.csv`` inside ``self.output_dir``.

    Returns:
        Path of the written CSV file, as a string.
    """
    # Fixed column order, so the report schema does not depend on which file
    # happened to be processed first, or on whether any file was given.
    report_columns = [
        '目标参数', '样本数量', '均值', '标准差', '最小值', '最大值',
        '中位数', '25%分位数', '75%分位数', '文件路径',
    ]

    report_data = []
    has_errors = False

    for target_name, csv_path in prediction_csv_paths.items():
        try:
            df = pd.read_csv(csv_path)

            pred_col = 'prediction' if 'prediction' in df.columns else df.columns[-1]
            predictions = df[pred_col].dropna()

            report_data.append({
                '目标参数': target_name,
                '样本数量': len(predictions),
                '均值': predictions.mean(),
                '标准差': predictions.std(),
                '最小值': predictions.min(),
                '最大值': predictions.max(),
                '中位数': predictions.median(),
                '25%分位数': predictions.quantile(0.25),
                '75%分位数': predictions.quantile(0.75),
                '文件路径': csv_path,
            })
        except Exception as e:
            # Per-file boundary: one unreadable file must not abort the whole
            # report, so the failure is recorded in the row instead.
            print(f"处理文件 {csv_path} 时出错: {e}")
            has_errors = True
            report_data.append({
                '目标参数': target_name,
                '样本数量': 0,
                '文件路径': csv_path,
                '错误': str(e),
            })

    if has_errors:
        report_columns.append('错误')

    # Passing the columns explicitly also gives an empty input a proper
    # header row instead of an unparseable blank file.
    df_report = pd.DataFrame(report_data, columns=report_columns)

    if output_path is None:
        output_path = self.output_dir / "prediction_report.csv"
    else:
        output_path = Path(output_path)

    df_report.to_csv(output_path, index=False, encoding='utf-8-sig', float_format='%.6f')
    print(f"预测结果报告已保存: {output_path}")
    return str(output_path)
|
||
|
||
def generate_batch_inference_summary(self, pipeline_info: Dict,
                                     output_path: Optional[str] = None) -> str:
    """
    Write the batch processing summary (batch_inference_summary.json).

    The summary records the generation time, the working directory, the
    status and output file of every ``stepN`` entry found in
    ``pipeline_info`` (in numeric order), the model directory with its
    training parameters, the prediction files and the output file list.

    Args:
        pipeline_info: Pipeline description. Keys read here are ``work_dir``,
            ``step<N>`` (normally a dict with ``status`` / ``output_file``),
            ``models_dir``, ``training_params``, ``prediction_files`` and
            ``output_files``; all of them are optional.
        output_path: Destination JSON path. When None, the summary is written
            to ``batch_inference_summary.json`` inside ``self.output_dir``.

    Returns:
        Path of the written JSON file, as a string.

    Raises:
        TypeError: If ``pipeline_info`` holds a value that is neither JSON
            serializable nor one of the types converted below.
    """
    def _json_default(value):
        """Convert the non-JSON value types a pipeline commonly records."""
        if isinstance(value, Path):
            return str(value)
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, datetime):
            return value.isoformat()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    summary = {
        '执行时间': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        '工作目录': str(pipeline_info.get('work_dir', 'Unknown')),
        '步骤执行情况': {},
        '模型训练': {},
        '预测结果': {},
        '输出文件': {}
    }

    # Collect every "step<N>" key rather than a fixed step1..step9 list, so
    # longer pipelines are not silently truncated; sort by the number so
    # step10 follows step9.
    step_keys = sorted(
        (key for key in pipeline_info
         if isinstance(key, str) and key.startswith('step') and key[4:].isdecimal()),
        key=lambda key: int(key[4:]),
    )
    for step in step_keys:
        entry = pipeline_info[step]
        if isinstance(entry, dict):
            status = entry.get('status', 'completed')
            output_file = entry.get('output_file', 'N/A')
        else:
            # A bare value (e.g. a status string) is reported as the status
            # instead of failing on the missing .get().
            status = str(entry)
            output_file = 'N/A'
        summary['步骤执行情况'][step] = {
            '状态': status,
            '输出文件': output_file
        }

    if 'models_dir' in pipeline_info:
        summary['模型训练']['模型目录'] = pipeline_info['models_dir']
        summary['模型训练']['训练参数'] = pipeline_info.get('training_params', {})

    if 'prediction_files' in pipeline_info:
        summary['预测结果'] = {
            '目标参数数量': len(pipeline_info['prediction_files']),
            '预测文件': pipeline_info['prediction_files']
        }

    summary['输出文件'] = pipeline_info.get('output_files', {})

    if output_path is None:
        output_path = self.output_dir / "batch_inference_summary.json"
    else:
        output_path = Path(output_path)

    # Serialize before opening the file: a serialization error must not
    # leave a truncated, half-written summary behind.
    payload = json.dumps(summary, ensure_ascii=False, indent=2, default=_json_default)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(payload)

    print(f"批量处理摘要已保存: {output_path}")
    return str(output_path)
|
||
|