fix: 修复工作目录与步骤名不对应、回归预测虚数报错、模型加载及预处理名称转换问题,重构可视化并修正勾选联动
This commit is contained in:
@ -281,11 +281,13 @@ class WaterQualityVisualization:
|
||||
def plot_statistical_charts(self, csv_path: str, parameter_columns: List[str],
|
||||
output_dir: Optional[str] = None) -> Dict[str, str]:
|
||||
"""
|
||||
绘制统计图表:箱线图、直方图、相关性热力图
|
||||
绘制统计图表:**只针对水质参数列**(数值型,排除波长列)
|
||||
- 水质参数列(如浓度、含量等数值型参数)使用箱线图/直方图/相关性热力图
|
||||
- 排除光谱波长列(虽然也是数值型,但不是水质参数)
|
||||
|
||||
Args:
|
||||
csv_path: CSV文件路径
|
||||
parameter_columns: 参数列名列表
|
||||
parameter_columns: **水质参数**列名列表(数值型,已排除波长列)
|
||||
output_dir: 输出目录
|
||||
|
||||
Returns:
|
||||
@ -301,12 +303,16 @@ class WaterQualityVisualization:
|
||||
|
||||
output_paths = {}
|
||||
|
||||
# 水质参数统计图表(针对数值型参数,排除波长列)
|
||||
# 假设传入的 parameter_columns 已经是过滤后的水质参数列
|
||||
numeric_cols = [col for col in parameter_columns if col in df.columns and pd.api.types.is_numeric_dtype(df[col])]
|
||||
|
||||
# 1. 箱线图
|
||||
if len(parameter_columns) > 0:
|
||||
if len(numeric_cols) > 0:
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
data_for_boxplot = [df[col].dropna() for col in parameter_columns if col in df.columns]
|
||||
data_for_boxplot = [df[col].dropna() for col in numeric_cols]
|
||||
if data_for_boxplot:
|
||||
ax.boxplot(data_for_boxplot, labels=[col for col in parameter_columns if col in df.columns])
|
||||
ax.boxplot(data_for_boxplot, labels=numeric_cols)
|
||||
ax.set_ylabel('数值', fontsize=12, fontweight='bold')
|
||||
ax.set_title('水质参数箱线图', fontsize=14, fontweight='bold')
|
||||
ax.grid(True, alpha=0.3, axis='y')
|
||||
@ -318,51 +324,51 @@ class WaterQualityVisualization:
|
||||
plt.close()
|
||||
output_paths['boxplot'] = str(boxplot_path)
|
||||
|
||||
# 2. 直方图
|
||||
for col in parameter_columns:
|
||||
if col not in df.columns:
|
||||
continue
|
||||
# 2. 直方图 (每个水质参数列)
|
||||
for col in numeric_cols:
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
data = df[col].dropna()
|
||||
ax.hist(data, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
|
||||
ax.set_xlabel(f'{col} 数值', fontsize=12, fontweight='bold')
|
||||
ax.set_ylabel('频数', fontsize=12, fontweight='bold')
|
||||
ax.set_title(f'{col} 分布直方图', fontsize=14, fontweight='bold')
|
||||
ax.grid(True, alpha=0.3, axis='y')
|
||||
|
||||
# 添加统计信息
|
||||
mean_val = data.mean()
|
||||
std_val = data.std()
|
||||
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'均值: {mean_val:.4f}')
|
||||
ax.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
safe_name = "".join(c for c in col if c.isalnum() or c in ('-', '_', '.'))
|
||||
hist_path = output_dir / f"{safe_name}_histogram.png"
|
||||
plt.savefig(hist_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
output_paths[f'histogram_{col}'] = str(hist_path)
|
||||
|
||||
# 3. 相关性热力图
|
||||
if len(parameter_columns) >= 2:
|
||||
valid_cols = [col for col in parameter_columns if col in df.columns]
|
||||
if len(valid_cols) >= 2:
|
||||
corr_matrix = df[valid_cols].corr()
|
||||
if len(data) > 1:
|
||||
ax.hist(data, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
|
||||
ax.set_xlabel(f'{col} 数值', fontsize=12, fontweight='bold')
|
||||
ax.set_ylabel('频数', fontsize=12, fontweight='bold')
|
||||
ax.set_title(f'{col} 分布直方图', fontsize=14, fontweight='bold')
|
||||
ax.grid(True, alpha=0.3, axis='y')
|
||||
|
||||
# 添加统计信息
|
||||
mean_val = data.mean()
|
||||
std_val = data.std() if len(data) > 1 else 0
|
||||
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'均值: {mean_val:.4f}')
|
||||
ax.legend()
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm',
|
||||
center=0, square=True, linewidths=1, cbar_kws={"shrink": 0.8},
|
||||
ax=ax, vmin=-1, vmax=1)
|
||||
ax.set_title('水质参数相关性热力图', fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
|
||||
heatmap_path = output_dir / "correlation_heatmap.png"
|
||||
plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
|
||||
safe_name = "".join(c for c in col if c.isalnum() or c in ('-', '_', '.'))
|
||||
hist_path = output_dir / f"{safe_name}_histogram.png"
|
||||
plt.savefig(hist_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
output_paths['heatmap'] = str(heatmap_path)
|
||||
output_paths[f'histogram_{col}'] = str(hist_path)
|
||||
|
||||
print(f"统计图表已保存到: {output_dir}")
|
||||
# 3. 相关性热力图
|
||||
if len(numeric_cols) >= 2:
|
||||
corr_matrix = df[numeric_cols].corr()
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm',
|
||||
center=0, square=True, linewidths=1, cbar_kws={"shrink": 0.8},
|
||||
ax=ax, vmin=-1, vmax=1)
|
||||
ax.set_title('水质参数相关性热力图', fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
|
||||
heatmap_path = output_dir / "correlation_heatmap.png"
|
||||
plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
output_paths['heatmap'] = str(heatmap_path)
|
||||
|
||||
if not output_paths:
|
||||
print("警告: 没有生成任何统计图表(可能无合适的水质参数列)")
|
||||
else:
|
||||
print(f"统计图表已保存到: {output_dir},共 {len(output_paths)} 个文件")
|
||||
return output_paths
|
||||
|
||||
def plot_distribution_map_enhanced(self, prediction_csv_path: str,
|
||||
|
||||
Reference in New Issue
Block a user