diff --git a/src/postprocessing/visualization_reports.py b/src/postprocessing/visualization_reports.py index 65b5a78..2ce00c7 100644 --- a/src/postprocessing/visualization_reports.py +++ b/src/postprocessing/visualization_reports.py @@ -81,7 +81,11 @@ class WaterQualityVisualization: 保存的文件路径 """ fig, ax = plt.subplots(figsize=(10, 8)) - + + # 强制处理 NaN + y_true = np.nan_to_num(y_true, nan=0.0) + y_pred = np.nan_to_num(y_pred, nan=0.0) + # 计算所有数据的R²和RMSE from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error @@ -133,7 +137,7 @@ class WaterQualityVisualization: ax.set_xlabel(f'真实值 ({target_name})', fontsize=14, fontweight='bold') ax.set_ylabel(f'预测值 ({target_name})', fontsize=14, fontweight='bold') ax.set_title(f'{target_name} - 真实值 vs 预测值', fontsize=16, fontweight='bold') - ax.legend(loc='upper left', fontsize=11, bbox_to_anchor=(1.02, 1), borderaxespad=0) + ax.legend(loc='upper left', fontsize=11, bbox_to_anchor=(1.05, 1), borderaxespad=0) ax.grid(True, alpha=0.3) # 添加指标文本框 @@ -141,7 +145,7 @@ class WaterQualityVisualization: verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8), fontsize=10) - plt.tight_layout() + plt.subplots_adjust(right=0.85) # 保存图片 if output_path is None: