fix: 修复工作目录与步骤名不对应、回归预测虚数报错、模型加载及预处理名称转换问题,重构可视化并修正勾选联动

This commit is contained in:
2026-04-14 17:41:38 +08:00
parent b0a94ba1e7
commit 9b7bcfadd1
17 changed files with 12470 additions and 3113 deletions

View File

@ -281,11 +281,13 @@ class WaterQualityVisualization:
def plot_statistical_charts(self, csv_path: str, parameter_columns: List[str],
output_dir: Optional[str] = None) -> Dict[str, str]:
"""
绘制统计图表:箱线图、直方图、相关性热力图
绘制统计图表:**只针对水质参数列**(数值型,排除波长列)
- 水质参数列(如浓度、含量等数值型参数)使用箱线图/直方图/相关性热力图
- 排除光谱波长列(虽然也是数值型,但不是水质参数)
Args:
csv_path: CSV文件路径
parameter_columns: 参数列名列表
parameter_columns: **水质参数**列名列表(数值型,已排除波长列)
output_dir: 输出目录
Returns:
@ -301,12 +303,16 @@ class WaterQualityVisualization:
output_paths = {}
# 水质参数统计图表(针对数值型参数,排除波长列)
# 假设传入的 parameter_columns 已经是过滤后的水质参数列
numeric_cols = [col for col in parameter_columns if col in df.columns and pd.api.types.is_numeric_dtype(df[col])]
# 1. 箱线图
if len(parameter_columns) > 0:
if len(numeric_cols) > 0:
fig, ax = plt.subplots(figsize=(12, 6))
data_for_boxplot = [df[col].dropna() for col in parameter_columns if col in df.columns]
data_for_boxplot = [df[col].dropna() for col in numeric_cols]
if data_for_boxplot:
ax.boxplot(data_for_boxplot, labels=[col for col in parameter_columns if col in df.columns])
ax.boxplot(data_for_boxplot, labels=numeric_cols)
ax.set_ylabel('数值', fontsize=12, fontweight='bold')
ax.set_title('水质参数箱线图', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')
@ -318,51 +324,51 @@ class WaterQualityVisualization:
plt.close()
output_paths['boxplot'] = str(boxplot_path)
# 2. 直方图
for col in parameter_columns:
if col not in df.columns:
continue
# 2. 直方图 (每个水质参数列)
for col in numeric_cols:
fig, ax = plt.subplots(figsize=(10, 6))
data = df[col].dropna()
ax.hist(data, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
ax.set_xlabel(f'{col} 数值', fontsize=12, fontweight='bold')
ax.set_ylabel('频数', fontsize=12, fontweight='bold')
ax.set_title(f'{col} 分布直方图', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')
# 添加统计信息
mean_val = data.mean()
std_val = data.std()
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'均值: {mean_val:.4f}')
ax.legend()
plt.tight_layout()
safe_name = "".join(c for c in col if c.isalnum() or c in ('-', '_', '.'))
hist_path = output_dir / f"{safe_name}_histogram.png"
plt.savefig(hist_path, dpi=300, bbox_inches='tight')
plt.close()
output_paths[f'histogram_{col}'] = str(hist_path)
# 3. 相关性热力图
if len(parameter_columns) >= 2:
valid_cols = [col for col in parameter_columns if col in df.columns]
if len(valid_cols) >= 2:
corr_matrix = df[valid_cols].corr()
if len(data) > 1:
ax.hist(data, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
ax.set_xlabel(f'{col} 数值', fontsize=12, fontweight='bold')
ax.set_ylabel('频数', fontsize=12, fontweight='bold')
ax.set_title(f'{col} 分布直方图', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')
# 添加统计信息
mean_val = data.mean()
std_val = data.std() if len(data) > 1 else 0
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'均值: {mean_val:.4f}')
ax.legend()
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm',
center=0, square=True, linewidths=1, cbar_kws={"shrink": 0.8},
ax=ax, vmin=-1, vmax=1)
ax.set_title('水质参数相关性热力图', fontsize=14, fontweight='bold')
plt.tight_layout()
heatmap_path = output_dir / "correlation_heatmap.png"
plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
safe_name = "".join(c for c in col if c.isalnum() or c in ('-', '_', '.'))
hist_path = output_dir / f"{safe_name}_histogram.png"
plt.savefig(hist_path, dpi=300, bbox_inches='tight')
plt.close()
output_paths['heatmap'] = str(heatmap_path)
output_paths[f'histogram_{col}'] = str(hist_path)
print(f"统计图表已保存到: {output_dir}")
# 3. 相关性热力图
if len(numeric_cols) >= 2:
corr_matrix = df[numeric_cols].corr()
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm',
center=0, square=True, linewidths=1, cbar_kws={"shrink": 0.8},
ax=ax, vmin=-1, vmax=1)
ax.set_title('水质参数相关性热力图', fontsize=14, fontweight='bold')
plt.tight_layout()
heatmap_path = output_dir / "correlation_heatmap.png"
plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
plt.close()
output_paths['heatmap'] = str(heatmap_path)
if not output_paths:
print("警告: 没有生成任何统计图表(可能无合适的水质参数列)")
else:
print(f"统计图表已保存到: {output_dir},共 {len(output_paths)} 个文件")
return output_paths
def plot_distribution_map_enhanced(self, prediction_csv_path: str,