fix(viz_reports): plot_scatter_true_vs_pred NaN 容错 + subplots_adjust 替换 tight_layout
This commit is contained in:
@ -82,6 +82,10 @@ class WaterQualityVisualization:
|
|||||||
"""
|
"""
|
||||||
fig, ax = plt.subplots(figsize=(10, 8))
|
fig, ax = plt.subplots(figsize=(10, 8))
|
||||||
|
|
||||||
|
# 强制处理 NaN
|
||||||
|
y_true = np.nan_to_num(y_true, nan=0.0)
|
||||||
|
y_pred = np.nan_to_num(y_pred, nan=0.0)
|
||||||
|
|
||||||
# 计算所有数据的R²和RMSE
|
# 计算所有数据的R²和RMSE
|
||||||
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
|
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
|
||||||
|
|
||||||
@ -133,7 +137,7 @@ class WaterQualityVisualization:
|
|||||||
ax.set_xlabel(f'真实值 ({target_name})', fontsize=14, fontweight='bold')
|
ax.set_xlabel(f'真实值 ({target_name})', fontsize=14, fontweight='bold')
|
||||||
ax.set_ylabel(f'预测值 ({target_name})', fontsize=14, fontweight='bold')
|
ax.set_ylabel(f'预测值 ({target_name})', fontsize=14, fontweight='bold')
|
||||||
ax.set_title(f'{target_name} - 真实值 vs 预测值', fontsize=16, fontweight='bold')
|
ax.set_title(f'{target_name} - 真实值 vs 预测值', fontsize=16, fontweight='bold')
|
||||||
ax.legend(loc='upper left', fontsize=11, bbox_to_anchor=(1.02, 1), borderaxespad=0)
|
ax.legend(loc='upper left', fontsize=11, bbox_to_anchor=(1.05, 1), borderaxespad=0)
|
||||||
ax.grid(True, alpha=0.3)
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
# 添加指标文本框
|
# 添加指标文本框
|
||||||
@ -141,7 +145,7 @@ class WaterQualityVisualization:
|
|||||||
verticalalignment='top', bbox=dict(boxstyle='round',
|
verticalalignment='top', bbox=dict(boxstyle='round',
|
||||||
facecolor='wheat', alpha=0.8), fontsize=10)
|
facecolor='wheat', alpha=0.8), fontsize=10)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.subplots_adjust(right=0.85)
|
||||||
|
|
||||||
# 保存图片
|
# 保存图片
|
||||||
if output_path is None:
|
if output_path is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user