feat: Step1~Step14 面板单步按钮 EventBus 解耦 + Handler 补全(Step8~Step14)+ 旧上帝类删除
- 9 个面板(step1~step6/step8_ml_train/step8_qaa/step9_ml_predict/step10)单步执行按钮从 parent 链上溯改为 global_event_bus.publish('RequestRunSingleStep')
- PipelineExecutor 新增 _on_request_run_single_step 订阅
- 新增 Handler: step8_ml_train / step9_ml_predict / step10_qaa_inversion / step11_concentration / step12_kriging / step13_visualization / step14_report
- 删除旧 water_quality_inversion_pipeline_GUI.py(上帝类已肢解完毕)
This commit is contained in:
@ -18,6 +18,13 @@ from src.core.handlers.step4_sampling import Step4SamplingHandler
|
||||
from src.core.handlers.step5_process_csv import Step5ProcessCsvHandler
|
||||
from src.core.handlers.step6_extract_spectra import Step6ExtractSpectraHandler
|
||||
from src.core.handlers.step7_calc_indices import Step7CalcIndicesHandler
|
||||
from src.core.handlers.step8_ml_train import Step8MlTrainHandler
|
||||
from src.core.handlers.step9_ml_predict import Step9MlPredictHandler
|
||||
from src.core.handlers.step10_qaa_inversion import Step10QaaInversionHandler
|
||||
from src.core.handlers.step11_concentration import Step11ConcentrationHandler
|
||||
from src.core.handlers.step12_kriging import Step12KrigingHandler
|
||||
from src.core.handlers.step13_visualization import Step13VisualizationHandler
|
||||
from src.core.handlers.step14_report import Step14ReportHandler
|
||||
|
||||
__all__ = [
|
||||
'BaseStepHandler',
|
||||
@ -29,4 +36,11 @@ __all__ = [
|
||||
'Step5ProcessCsvHandler',
|
||||
'Step6ExtractSpectraHandler',
|
||||
'Step7CalcIndicesHandler',
|
||||
'Step8MlTrainHandler',
|
||||
'Step9MlPredictHandler',
|
||||
'Step10QaaInversionHandler',
|
||||
'Step11ConcentrationHandler',
|
||||
'Step12KrigingHandler',
|
||||
'Step13VisualizationHandler',
|
||||
'Step14ReportHandler',
|
||||
]
|
||||
|
||||
@ -74,6 +74,11 @@ class PipelineContext:
|
||||
self.training_csv_path: Optional[str] = None
|
||||
self.indices_path: Optional[str] = None
|
||||
self.custom_regression_path: Optional[str] = None
|
||||
self.sampling_csv_path: Optional[str] = None
|
||||
self.prediction_files: Dict[str, str] = {}
|
||||
self.distribution_map_path: Optional[str] = None
|
||||
self.qaa_output_path: Optional[str] = None
|
||||
self.concentration_output_path: Optional[str] = None
|
||||
|
||||
# ── 计时 ──
|
||||
self.step_timings: Dict[str, dict] = {}
|
||||
|
||||
@ -18,6 +18,13 @@ from src.core.handlers.step4_sampling import Step4SamplingHandler
|
||||
from src.core.handlers.step5_process_csv import Step5ProcessCsvHandler
|
||||
from src.core.handlers.step6_extract_spectra import Step6ExtractSpectraHandler
|
||||
from src.core.handlers.step7_calc_indices import Step7CalcIndicesHandler
|
||||
from src.core.handlers.step8_ml_train import Step8MlTrainHandler
|
||||
from src.core.handlers.step9_ml_predict import Step9MlPredictHandler
|
||||
from src.core.handlers.step10_qaa_inversion import Step10QaaInversionHandler
|
||||
from src.core.handlers.step11_concentration import Step11ConcentrationHandler
|
||||
from src.core.handlers.step12_kriging import Step12KrigingHandler
|
||||
from src.core.handlers.step13_visualization import Step13VisualizationHandler
|
||||
from src.core.handlers.step14_report import Step14ReportHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.core.handlers.pipeline_scheduler import PipelineScheduler
|
||||
@ -41,3 +48,10 @@ def register_all_handlers(scheduler: PipelineScheduler):
|
||||
scheduler.register_handler(Step5ProcessCsvHandler())
|
||||
scheduler.register_handler(Step6ExtractSpectraHandler())
|
||||
scheduler.register_handler(Step7CalcIndicesHandler())
|
||||
scheduler.register_handler(Step8MlTrainHandler())
|
||||
scheduler.register_handler(Step9MlPredictHandler())
|
||||
scheduler.register_handler(Step10QaaInversionHandler())
|
||||
scheduler.register_handler(Step11ConcentrationHandler())
|
||||
scheduler.register_handler(Step12KrigingHandler())
|
||||
scheduler.register_handler(Step13VisualizationHandler())
|
||||
scheduler.register_handler(Step14ReportHandler())
|
||||
|
||||
137
src/core/handlers/step10_qaa_inversion.py
Normal file
137
src/core/handlers/step10_qaa_inversion.py
Normal file
@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step10 处理器:QAA 准解析算法反演
|
||||
|
||||
将原 WaterQualityInversionPipeline.step8_qaa_inversion() 方法
|
||||
剥离为独立的 Step10QaaInversionHandler。
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
|
||||
|
||||
class Step10QaaInversionHandler(BaseStepHandler):
|
||||
"""步骤10:QAA 准解析算法反演(非经验模型)。
|
||||
|
||||
对应 config key: 'step10_qaa'
|
||||
直接使用 QAABaselineSolver 进行物理推导。
|
||||
"""
|
||||
|
||||
step_key = 'step10_qaa'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
from src.core.algorithms.qaa.qaas_baseline import QAABaselineSolver
|
||||
from src.utils.water_owt_config import get_lambda_0
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
lake_name = config.get('lake_name', 'Unknown')
|
||||
lambda_0 = config.get('lambda_0', get_lambda_0(lake_name))
|
||||
output_dir = os.path.join(context.work_dir, "10_QAA_Inversion")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_path = config.get('output_path') or os.path.join(output_dir, "a_lambda_results.csv")
|
||||
|
||||
spectrum_csv = config.get('spectrum_csv_path')
|
||||
if not spectrum_csv:
|
||||
spectrum_csv = context.training_csv_path
|
||||
if not spectrum_csv or not os.path.exists(spectrum_csv):
|
||||
fallback_candidates = []
|
||||
step6_dir = os.path.join(context.work_dir, "6_Spectral_Feature_Extraction")
|
||||
if os.path.isdir(step6_dir):
|
||||
for f in sorted(os.listdir(step6_dir)):
|
||||
if f.lower().endswith('.csv'):
|
||||
fallback_candidates.append(os.path.join(step6_dir, f))
|
||||
if fallback_candidates:
|
||||
spectrum_csv = fallback_candidates[0]
|
||||
context.notify('step10_qaa', 'info',
|
||||
f'spectrum_csv_path 为空,已自动回退到 step6 产物: {spectrum_csv}')
|
||||
else:
|
||||
msg = f'训练光谱 CSV 不存在或路径为空: {spectrum_csv}'
|
||||
context.notify('step10_qaa', 'error', msg)
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤10: QAA 反演", step_start_time, step_end_time,
|
||||
status="failed", error=msg
|
||||
)
|
||||
return {'error': msg}
|
||||
|
||||
try:
|
||||
df = pd.read_csv(spectrum_csv, encoding="utf-8-sig")
|
||||
col_names = df.columns.tolist()
|
||||
|
||||
wavelength_col_idx = None
|
||||
for i, col in enumerate(col_names):
|
||||
try:
|
||||
float(col)
|
||||
wavelength_col_idx = i
|
||||
break
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if wavelength_col_idx is None:
|
||||
msg = "无法从 CSV 列名中识别波长信息"
|
||||
context.notify('step10_qaa', 'error', msg)
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤10: QAA 反演", step_start_time, step_end_time,
|
||||
status="failed", error=msg
|
||||
)
|
||||
return {'error': msg}
|
||||
|
||||
meta_df = df.iloc[:, :wavelength_col_idx].copy()
|
||||
wavelengths = np.array([float(c) for c in col_names[wavelength_col_idx:]], dtype=np.float64)
|
||||
data_matrix = df.iloc[:, wavelength_col_idx:].values.astype(np.float64)
|
||||
if data_matrix.ndim == 1:
|
||||
data_matrix = data_matrix[np.newaxis, :]
|
||||
|
||||
solver = QAABaselineSolver()
|
||||
raw_result = solver.run_inversion(wavelengths, data_matrix, lambda_0)
|
||||
|
||||
if isinstance(raw_result, list):
|
||||
sample_results = raw_result
|
||||
else:
|
||||
sample_results = [raw_result]
|
||||
|
||||
rows_out = []
|
||||
for i, sample_result in enumerate(sample_results):
|
||||
wl_arr = wavelengths
|
||||
a_arr = sample_result['a_lambda']
|
||||
bb_arr = sample_result['bb_lambda']
|
||||
meta_row = meta_df.iloc[i].to_dict() if i < len(meta_df) else {}
|
||||
for j, wl in enumerate(wl_arr):
|
||||
rows_out.append({
|
||||
'sample_id': f"sample_{i}",
|
||||
'Wavelength': wl,
|
||||
'a_lambda': a_arr[j],
|
||||
'bb_lambda': bb_arr[j],
|
||||
**meta_row,
|
||||
})
|
||||
|
||||
result_df = pd.DataFrame(rows_out)
|
||||
result_df.to_csv(output_path, index=False, float_format='%.8f')
|
||||
|
||||
context.qaa_output_path = output_path
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤10: QAA 反演", step_start_time, step_end_time
|
||||
)
|
||||
context.notify('step10_qaa', 'completed',
|
||||
f"QAA 反演完毕,水域={lake_name},λ₀={lambda_0}nm")
|
||||
|
||||
return {'qaa_output_path': output_path}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤10: QAA 反演", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
71
src/core/handlers/step11_concentration.py
Normal file
71
src/core/handlers/step11_concentration.py
Normal file
@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step11 处理器:浓度反演
|
||||
|
||||
将原 WaterQualityInversionPipeline.step9_concentration_inversion() 方法
|
||||
剥离为独立的 Step11ConcentrationHandler。
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
|
||||
|
||||
class Step11ConcentrationHandler(BaseStepHandler):
|
||||
"""步骤11:浓度反演(基于 QAA Step10 输出的 a_lambda/bb_lambda)。
|
||||
|
||||
对应 config key: 'step11_concentration'
|
||||
直接使用 ConcentrationPipeline 进行浓度反演。
|
||||
"""
|
||||
|
||||
step_key = 'step11_concentration'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
from src.core.algorithms.concentration_inversion import ConcentrationPipeline
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
input_csv = config.get('input_csv') or context.qaa_output_path
|
||||
output_csv = config.get('output_csv')
|
||||
lake_case = config.get('lake_case', 'medium')
|
||||
|
||||
if not input_csv or not os.path.exists(input_csv):
|
||||
msg = f"QAA 结果文件不存在或路径为空: {input_csv}"
|
||||
context.notify('step11_concentration', 'error', msg)
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤11: 浓度反演", step_start_time, step_end_time,
|
||||
status="failed", error=msg
|
||||
)
|
||||
return {'error': msg}
|
||||
|
||||
if not output_csv:
|
||||
output_dir = os.path.join(context.work_dir, "11_Concentration")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_csv = os.path.join(output_dir, "final_concentrations.csv")
|
||||
|
||||
try:
|
||||
pipeline = ConcentrationPipeline(lake_case=lake_case)
|
||||
result_csv = pipeline.run_pipeline(input_csv, output_csv)
|
||||
|
||||
context.concentration_output_path = result_csv
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤11: 浓度反演", step_start_time, step_end_time
|
||||
)
|
||||
context.notify('step11_concentration', 'completed',
|
||||
f"浓度反演完毕,结果保存于: {result_csv}")
|
||||
|
||||
return {'concentration_output_path': result_csv}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤11: 浓度反演", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
81
src/core/handlers/step12_kriging.py
Normal file
81
src/core/handlers/step12_kriging.py
Normal file
@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step12 处理器:克里金空间插值与分布图生成
|
||||
|
||||
将原 WaterQualityInversionPipeline.step10_map() 方法
|
||||
剥离为独立的 Step12KrigingHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.mapping_step import MappingStep
|
||||
|
||||
|
||||
class Step12KrigingHandler(BaseStepHandler):
|
||||
"""步骤12:克里金空间插值与分布图生成。
|
||||
|
||||
对应 config key: 'step12_kriging'
|
||||
委托类: MappingStep.generate_distribution_map()
|
||||
"""
|
||||
|
||||
step_key = 'step12_kriging'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
prediction_csv_path = config.get('prediction_csv_path')
|
||||
boundary_shp_path = config.get('boundary_shp_path')
|
||||
|
||||
# 强制输出到 visualization_dir
|
||||
csv_name = Path(prediction_csv_path).stem if prediction_csv_path else "distribution"
|
||||
forced_image_path = str(context.visualization_dir / f"{csv_name}_distribution.png")
|
||||
viz_dir_resolved = str(context.visualization_dir)
|
||||
|
||||
output_image_path = config.get('output_image_path')
|
||||
if output_image_path and output_image_path != forced_image_path:
|
||||
norm_user = output_image_path.replace('\\', '/').rstrip('/')
|
||||
norm_viz = viz_dir_resolved.replace('\\', '/').rstrip('/')
|
||||
if not norm_user.startswith(norm_viz + '/') and norm_user != norm_viz:
|
||||
output_image_path = forced_image_path
|
||||
else:
|
||||
output_image_path = forced_image_path
|
||||
|
||||
try:
|
||||
result = MappingStep.generate_distribution_map(
|
||||
prediction_csv_path=prediction_csv_path,
|
||||
boundary_shp_path=boundary_shp_path,
|
||||
output_image_path=output_image_path,
|
||||
resolution=config.get('resolution', 30),
|
||||
input_crs=config.get('input_crs', 'EPSG:32651'),
|
||||
output_crs=config.get('output_crs', 'EPSG:4326'),
|
||||
show_sample_points=config.get('show_sample_points', False),
|
||||
base_map_tif=config.get('base_map_tif'),
|
||||
use_distance_diffusion=config.get('use_distance_diffusion', True),
|
||||
max_diffusion_distance=config.get('max_diffusion_distance'),
|
||||
diffusion_power=config.get('diffusion_power', 2),
|
||||
diffusion_n_neighbors=config.get('diffusion_n_neighbors', 15),
|
||||
cmap=config.get('cmap'),
|
||||
expand_ratio=config.get('expand_ratio', 0.05),
|
||||
output_dir=str(context.visualization_dir),
|
||||
)
|
||||
|
||||
context.distribution_map_path = result
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤12: 克里金插值与分布图", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'distribution_map_path': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤12: 克里金插值与分布图", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
349
src/core/handlers/step13_visualization.py
Normal file
349
src/core/handlers/step13_visualization.py
Normal file
@ -0,0 +1,349 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step13 处理器:可视化成图
|
||||
|
||||
将原 WaterQualityInversionPipeline 中的可视化方法
|
||||
(散点图、箱型图、光谱曲线、统计图表、耀斑预览)
|
||||
剥离为独立的 Step13VisualizationHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
|
||||
|
||||
class Step13VisualizationHandler(BaseStepHandler):
|
||||
"""步骤13:可视化成图。
|
||||
|
||||
对应 config key: 'step13_visualization'
|
||||
包含:散点图、箱型图、光谱曲线、统计图表、耀斑预览。
|
||||
"""
|
||||
|
||||
step_key = 'step13_visualization'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
output_files: Dict[str, Any] = {}
|
||||
|
||||
try:
|
||||
# ── 散点图 ──
|
||||
if config.get('generate_scatter', True):
|
||||
if context.training_csv_path and context.models_dir.exists():
|
||||
try:
|
||||
scatter_config = config.get('scatter_config', {})
|
||||
scatter_paths = self._generate_scatter_plots(context, scatter_config)
|
||||
output_files['scatter_plots'] = scatter_paths
|
||||
except Exception as e:
|
||||
context.notify('step13_visualization', 'warning',
|
||||
f"生成散点图时出错: {e}")
|
||||
|
||||
# ── 箱型图 ──
|
||||
if config.get('generate_boxplots', True):
|
||||
if context.processed_csv_path:
|
||||
try:
|
||||
boxplot_config = config.get('boxplot_config', {})
|
||||
boxplot_paths = self._generate_boxplots(context, boxplot_config)
|
||||
output_files['boxplots'] = boxplot_paths
|
||||
except Exception as e:
|
||||
context.notify('step13_visualization', 'warning',
|
||||
f"生成箱型图时出错: {e}")
|
||||
|
||||
# ── 光谱曲线 ──
|
||||
if config.get('generate_spectrum', True):
|
||||
if context.training_csv_path:
|
||||
try:
|
||||
spectrum_paths = self._generate_spectrum_plots(context, config)
|
||||
output_files['spectrum_plots'] = spectrum_paths
|
||||
except Exception as e:
|
||||
context.notify('step13_visualization', 'warning',
|
||||
f"生成光谱曲线图时出错: {e}")
|
||||
|
||||
# ── 统计图表 ──
|
||||
if config.get('generate_statistics', True):
|
||||
if context.processed_csv_path:
|
||||
try:
|
||||
stat_charts = self._generate_statistics(context)
|
||||
output_files['statistical_charts'] = stat_charts
|
||||
except Exception as e:
|
||||
context.notify('step13_visualization', 'warning',
|
||||
f"生成统计图表时出错: {e}")
|
||||
|
||||
# ── 耀斑预览 ──
|
||||
if config.get('generate_glint_previews', True):
|
||||
try:
|
||||
glint_config = config.get('glint_preview_config', {})
|
||||
preview_paths = context.visualizer.generate_glint_deglint_previews(
|
||||
work_dir=glint_config.get('work_dir') or str(context.work_dir),
|
||||
output_subdir=glint_config.get('output_subdir', 'glint_deglint_previews'),
|
||||
generate_glint=glint_config.get('generate_glint', True),
|
||||
generate_deglint=glint_config.get('generate_deglint', True),
|
||||
)
|
||||
output_files['glint_deglint_previews'] = preview_paths
|
||||
except Exception as e:
|
||||
context.notify('step13_visualization', 'warning',
|
||||
f"生成耀斑预览图时出错: {e}")
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤13: 可视化成图", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'visualization_outputs': output_files}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤13: 可视化成图", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
# ── 散点图 ──
|
||||
|
||||
def _generate_scatter_plots(self, context: PipelineContext,
|
||||
scatter_config: dict) -> Dict[str, str]:
|
||||
training_csv_path = context.training_csv_path
|
||||
models_dir = str(context.models_dir)
|
||||
metric = scatter_config.get('metric', 'test_r2')
|
||||
use_enhanced = scatter_config.get('use_enhanced', True)
|
||||
feature_start_column = scatter_config.get('feature_start_column', 13)
|
||||
test_size = scatter_config.get('test_size', 0.2)
|
||||
random_state = scatter_config.get('random_state', 42)
|
||||
|
||||
scatter_paths = {}
|
||||
|
||||
if use_enhanced:
|
||||
try:
|
||||
results = context.scatter_batch.batch_plot_scatter(
|
||||
models_root_dir=models_dir,
|
||||
csv_path=training_csv_path,
|
||||
output_dir=str(context.visualization_dir / "scatter_plots"),
|
||||
metric=metric,
|
||||
target_column=None,
|
||||
feature_start_column=feature_start_column,
|
||||
test_size=test_size,
|
||||
random_state=random_state,
|
||||
)
|
||||
for target_name, result in results.items():
|
||||
if result.get('status') == 'success':
|
||||
scatter_paths[target_name] = result.get('save_path', '')
|
||||
except Exception:
|
||||
use_enhanced = False
|
||||
|
||||
if not use_enhanced or not scatter_paths:
|
||||
from src.core.prediction.inference_batch import WaterQualityInference
|
||||
models_path = Path(models_dir)
|
||||
for target_folder in models_path.iterdir():
|
||||
if not target_folder.is_dir():
|
||||
continue
|
||||
target_name = target_folder.name
|
||||
try:
|
||||
inferencer = WaterQualityInference(str(target_folder))
|
||||
eval_result = inferencer.evaluate_with_split(
|
||||
data_csv_path=training_csv_path,
|
||||
split_method="spxy",
|
||||
test_size=test_size,
|
||||
random_state=random_state,
|
||||
metric=metric,
|
||||
)
|
||||
predictions = eval_result.get('predictions', {})
|
||||
if predictions:
|
||||
y_train_true = predictions.get('y_train_true')
|
||||
y_train_pred = predictions.get('y_train_pred')
|
||||
y_test_true = predictions.get('y_test_true')
|
||||
y_test_pred = predictions.get('y_test_pred')
|
||||
metrics = eval_result.get('test_metrics', {})
|
||||
if y_train_true is not None and y_test_true is not None:
|
||||
y_all_true = np.concatenate([y_train_true, y_test_true])
|
||||
y_all_pred = np.concatenate([y_train_pred, y_test_pred])
|
||||
train_indices = np.arange(len(y_train_true))
|
||||
test_indices = np.arange(len(y_train_true), len(y_all_true))
|
||||
scatter_path = context.visualizer.plot_scatter_true_vs_pred(
|
||||
y_true=y_all_true,
|
||||
y_pred=y_all_pred,
|
||||
target_name=target_name,
|
||||
train_indices=train_indices,
|
||||
test_indices=test_indices,
|
||||
metrics={
|
||||
'train_r2': eval_result.get('train_metrics', {}).get('r2', 0),
|
||||
'test_r2': metrics.get('r2', 0),
|
||||
'train_rmse': eval_result.get('train_metrics', {}).get('rmse', 0),
|
||||
'test_rmse': metrics.get('rmse', 0),
|
||||
}
|
||||
)
|
||||
scatter_paths[target_name] = scatter_path
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return scatter_paths
|
||||
|
||||
# ── 箱型图 ──
|
||||
|
||||
def _generate_boxplots(self, context: PipelineContext,
|
||||
boxplot_config: dict) -> Dict[str, str]:
|
||||
csv_path = context.processed_csv_path
|
||||
parameter_columns = boxplot_config.get('parameter_columns')
|
||||
data_start_column = boxplot_config.get('data_start_column', 4)
|
||||
save_individual = boxplot_config.get('save_individual', True)
|
||||
use_seaborn = boxplot_config.get('use_seaborn', True)
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
if parameter_columns is None:
|
||||
data_columns = df.iloc[:, data_start_column:]
|
||||
parameter_columns = list(data_columns.columns)
|
||||
else:
|
||||
parameter_columns = [col for col in parameter_columns if col in df.columns]
|
||||
|
||||
if not parameter_columns:
|
||||
return {}
|
||||
|
||||
boxplot_dir = context.visualization_dir / "boxplots"
|
||||
boxplot_dir.mkdir(parents=True, exist_ok=True)
|
||||
boxplot_paths = {}
|
||||
|
||||
if save_individual:
|
||||
for column in parameter_columns:
|
||||
if column not in df.columns:
|
||||
continue
|
||||
clean_data = df[column].dropna()
|
||||
if len(clean_data) == 0:
|
||||
continue
|
||||
try:
|
||||
plt.figure(figsize=(8, 6))
|
||||
if use_seaborn:
|
||||
plot_data = pd.DataFrame({'参数': [column] * len(clean_data), '数值': clean_data})
|
||||
sns.boxplot(data=plot_data, x='参数', y='数值', palette='Set2')
|
||||
sns.stripplot(data=plot_data, x='参数', y='数值',
|
||||
color='red', alpha=0.6, size=5, jitter=True)
|
||||
else:
|
||||
box_plot = plt.boxplot([clean_data], labels=[column],
|
||||
patch_artist=True, showfliers=False)
|
||||
box_plot['boxes'][0].set_facecolor('lightblue')
|
||||
box_plot['boxes'][0].set_alpha(0.7)
|
||||
x_pos = np.random.normal(1, 0.04, size=len(clean_data))
|
||||
plt.scatter(x_pos, clean_data, alpha=0.6, s=30, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
plt.title(f'{column} - 箱型图', fontsize=14, fontweight='bold')
|
||||
plt.xlabel('参数', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
stats_text = (f'数据点数: {len(clean_data)}\n'
|
||||
f'均值: {clean_data.mean():.2f}\n'
|
||||
f'中位数: {clean_data.median():.2f}\n'
|
||||
f'标准差: {clean_data.std():.2f}')
|
||||
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
|
||||
verticalalignment='top',
|
||||
bbox=dict(boxstyle='round',
|
||||
facecolor='wheat' if not use_seaborn else 'lightgreen',
|
||||
alpha=0.8))
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
plt.tight_layout()
|
||||
safe_name = column.replace('/', '_').replace('\\', '_').replace(':', '_')
|
||||
save_path = boxplot_dir / f'{safe_name}_boxplot.png'
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
boxplot_paths[column] = str(save_path)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 综合箱型图
|
||||
try:
|
||||
plt.figure(figsize=(max(12, len(parameter_columns) * 0.8), 8))
|
||||
box_data = []
|
||||
labels = []
|
||||
for column in parameter_columns:
|
||||
if column in df.columns:
|
||||
clean_data = df[column].dropna()
|
||||
if len(clean_data) > 0:
|
||||
box_data.append(clean_data)
|
||||
labels.append(column)
|
||||
if box_data:
|
||||
if use_seaborn:
|
||||
melted_data = pd.melt(df[labels], var_name='参数', value_name='数值')
|
||||
melted_data = melted_data.dropna()
|
||||
sns.boxplot(data=melted_data, x='参数', y='数值', palette='Set3')
|
||||
sns.stripplot(data=melted_data, x='参数', y='数值',
|
||||
color='red', alpha=0.6, size=4, jitter=True)
|
||||
else:
|
||||
box_plot = plt.boxplot(box_data, labels=labels, patch_artist=True, showfliers=False)
|
||||
colors = plt.cm.Set3(np.linspace(0, 1, len(box_data)))
|
||||
for patch, color in zip(box_plot['boxes'], colors):
|
||||
patch.set_facecolor(color)
|
||||
patch.set_alpha(0.7)
|
||||
for i, data in enumerate(box_data):
|
||||
x_pos = np.random.normal(i + 1, 0.04, size=len(data))
|
||||
plt.scatter(x_pos, data, alpha=0.6, s=20, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
plt.title('水质参数箱型图(综合)', fontsize=16, fontweight='bold')
|
||||
plt.xlabel('参数', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
plt.tight_layout()
|
||||
combined_path = boxplot_dir / 'all_parameters_boxplot.png'
|
||||
plt.savefig(combined_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
boxplot_paths['all_parameters'] = str(combined_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return boxplot_paths
|
||||
|
||||
# ── 光谱曲线 ──
|
||||
|
||||
def _generate_spectrum_plots(self, context: PipelineContext,
|
||||
config: dict) -> Dict[str, str]:
|
||||
csv_path = context.training_csv_path
|
||||
wavelength_start_column = config.get('feature_start_column', 'UTM_Y')
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
if isinstance(wavelength_start_column, str):
|
||||
try:
|
||||
wavelength_start_idx = df.columns.get_loc(wavelength_start_column)
|
||||
except KeyError:
|
||||
wavelength_start_idx = 13
|
||||
else:
|
||||
wavelength_start_idx = wavelength_start_column
|
||||
|
||||
parameter_columns = list(df.columns[:wavelength_start_idx])
|
||||
if len(parameter_columns) > 2:
|
||||
parameter_columns = parameter_columns[2:]
|
||||
|
||||
spectrum_paths = {}
|
||||
for param_col in parameter_columns:
|
||||
if param_col not in df.columns:
|
||||
continue
|
||||
try:
|
||||
spectrum_path = context.visualizer.plot_spectrum_by_parameter(
|
||||
csv_path=csv_path,
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wavelength_start_column,
|
||||
n_groups=5,
|
||||
)
|
||||
spectrum_paths[param_col] = spectrum_path
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return spectrum_paths
|
||||
|
||||
# ── 统计图表 ──
|
||||
|
||||
def _generate_statistics(self, context: PipelineContext) -> Dict[str, str]:
|
||||
csv_path = context.processed_csv_path
|
||||
df = pd.read_csv(csv_path)
|
||||
parameter_columns = list(df.columns[2:])
|
||||
parameter_columns = [col for col in parameter_columns
|
||||
if df[col].dtype in [np.float64, np.int64]]
|
||||
|
||||
return context.visualizer.plot_statistical_charts(
|
||||
csv_path=csv_path,
|
||||
parameter_columns=parameter_columns,
|
||||
)
|
||||
142
src/core/handlers/step14_report.py
Normal file
142
src/core/handlers/step14_report.py
Normal file
@ -0,0 +1,142 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step14 处理器:报告生成
|
||||
|
||||
将原 WaterQualityInversionPipeline.generate_pipeline_report() 方法
|
||||
剥离为独立的 Step14ReportHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
|
||||
|
||||
class Step14ReportHandler(BaseStepHandler):
|
||||
"""步骤14:流程执行报告生成。
|
||||
|
||||
对应 config key: 'step14_report'
|
||||
生成 CSV 和 TXT 格式的流程执行报告。
|
||||
"""
|
||||
|
||||
step_key = 'step14_report'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
try:
|
||||
output_path = config.get('output_path')
|
||||
if output_path is None:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
output_path = str(context.reports_dir / f"pipeline_report_{timestamp}.csv")
|
||||
|
||||
report_data = []
|
||||
total_time = 0.0
|
||||
|
||||
step_order = [
|
||||
"步骤1: 水域掩膜生成",
|
||||
"步骤2: 耀斑区域检测",
|
||||
"步骤3: 耀斑去除",
|
||||
"步骤4: 数据预处理",
|
||||
"步骤5: 光谱提取",
|
||||
"步骤6: 水质光谱指数计算",
|
||||
"步骤7: 机器学习建模与训练",
|
||||
"步骤8: 非经验模型训练",
|
||||
"步骤9: 自定义回归",
|
||||
"步骤10: 采样点生成",
|
||||
"步骤11: 参数预测",
|
||||
"步骤12: 分布图生成",
|
||||
]
|
||||
|
||||
for step_name in step_order:
|
||||
if step_name in context.step_timings:
|
||||
timing_info = context.step_timings[step_name]
|
||||
report_data.append({
|
||||
'步骤': step_name,
|
||||
'开始时间': timing_info['start_time'],
|
||||
'结束时间': timing_info['end_time'],
|
||||
'耗时(秒)': f"{timing_info['elapsed_seconds']:.2f}",
|
||||
'耗时(格式化)': timing_info['elapsed_formatted'],
|
||||
'状态': timing_info['status'],
|
||||
'错误信息': timing_info.get('error', '')
|
||||
})
|
||||
if timing_info['status'] == 'completed':
|
||||
total_time += timing_info['elapsed_seconds']
|
||||
|
||||
if context.pipeline_start_time and context.pipeline_end_time:
|
||||
pipeline_total = context.pipeline_end_time - context.pipeline_start_time
|
||||
report_data.append({
|
||||
'步骤': '总计',
|
||||
'开始时间': datetime.fromtimestamp(context.pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'结束时间': datetime.fromtimestamp(context.pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'耗时(秒)': f"{pipeline_total:.2f}",
|
||||
'耗时(格式化)': context._format_time(pipeline_total),
|
||||
'状态': 'completed',
|
||||
'错误信息': ''
|
||||
})
|
||||
|
||||
df_report = pd.DataFrame(report_data)
|
||||
df_report.to_csv(output_path, index=False, encoding='utf-8-sig')
|
||||
|
||||
txt_output_path = str(Path(output_path).with_suffix('.txt'))
|
||||
with open(txt_output_path, 'w', encoding='utf-8') as f:
|
||||
f.write("=" * 80 + "\n")
|
||||
f.write("水质参数反演流程执行报告\n")
|
||||
f.write("=" * 80 + "\n\n")
|
||||
|
||||
if context.pipeline_start_time and context.pipeline_end_time:
|
||||
f.write(f"流程开始时间: {datetime.fromtimestamp(context.pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"流程结束时间: {datetime.fromtimestamp(context.pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"总耗时: {context._format_time(context.pipeline_end_time - context.pipeline_start_time)}\n\n")
|
||||
|
||||
f.write("-" * 80 + "\n")
|
||||
f.write("各步骤执行详情:\n")
|
||||
f.write("-" * 80 + "\n\n")
|
||||
|
||||
for step_name in step_order:
|
||||
if step_name in context.step_timings:
|
||||
timing_info = context.step_timings[step_name]
|
||||
f.write(f"{step_name}\n")
|
||||
f.write(f" 开始时间: {timing_info['start_time']}\n")
|
||||
f.write(f" 结束时间: {timing_info['end_time']}\n")
|
||||
f.write(f" 耗时: {timing_info['elapsed_formatted']} ({timing_info['elapsed_seconds']:.2f}秒)\n")
|
||||
f.write(f" 状态: {timing_info['status']}\n")
|
||||
if timing_info.get('error'):
|
||||
f.write(f" 错误: {timing_info['error']}\n")
|
||||
f.write("\n")
|
||||
|
||||
f.write("-" * 80 + "\n")
|
||||
f.write("统计摘要:\n")
|
||||
f.write("-" * 80 + "\n")
|
||||
completed_steps = [s for s in context.step_timings.values() if s['status'] == 'completed']
|
||||
failed_steps = [s for s in context.step_timings.values() if s['status'] == 'failed']
|
||||
skipped_steps = [s for s in context.step_timings.values() if s['status'] == 'skipped']
|
||||
f.write(f"成功完成的步骤: {len(completed_steps)}\n")
|
||||
f.write(f"失败的步骤: {len(failed_steps)}\n")
|
||||
f.write(f"跳过的步骤: {len(skipped_steps)}\n")
|
||||
if completed_steps:
|
||||
completed_times = [s['elapsed_seconds'] for s in completed_steps]
|
||||
f.write(f"平均耗时: {context._format_time(np.mean(completed_times))}\n")
|
||||
f.write(f"最长耗时: {context._format_time(np.max(completed_times))}\n")
|
||||
f.write(f"最短耗时: {context._format_time(np.min(completed_times))}\n")
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤14: 报告生成", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'report_csv': output_path, 'report_txt': txt_output_path}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤14: 报告生成", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
58
src/core/handlers/step8_ml_train.py
Normal file
58
src/core/handlers/step8_ml_train.py
Normal file
@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8 处理器:机器学习建模与训练
|
||||
|
||||
将原 WaterQualityInversionPipeline.step8_train_ml() 方法
|
||||
剥离为独立的 Step8MlTrainHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.modeling_step import ModelingStep
|
||||
|
||||
|
||||
class Step8MlTrainHandler(BaseStepHandler):
|
||||
"""步骤8:机器学习建模与训练。
|
||||
|
||||
对应 config key: 'step8_ml_train'
|
||||
委托类: ModelingStep.train_models()
|
||||
"""
|
||||
|
||||
step_key = 'step8_ml_train'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
training_csv_path = self._resolve_path(
|
||||
config.get('training_csv_path'), context.training_csv_path, 'training_csv'
|
||||
)
|
||||
|
||||
try:
|
||||
result = ModelingStep.train_models(
|
||||
feature_start_column=config.get('feature_start_column', '374.285004'),
|
||||
preprocessing_methods=config.get('preprocessing_methods'),
|
||||
model_names=config.get('model_names'),
|
||||
split_methods=config.get('split_methods'),
|
||||
cv_folds=config.get('cv_folds', 5),
|
||||
training_csv_path=training_csv_path,
|
||||
output_dir=str(context.models_dir),
|
||||
_report_generator=context.report_generator,
|
||||
)
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤8: 机器学习建模与训练", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'models_dir': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤8: 机器学习建模与训练", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
64
src/core/handlers/step9_ml_predict.py
Normal file
64
src/core/handlers/step9_ml_predict.py
Normal file
@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step9 处理器:机器学习推理预测
|
||||
|
||||
将原 WaterQualityInversionPipeline.step9_predict_ml() 方法
|
||||
剥离为独立的 Step9MlPredictHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.prediction_step import PredictionStep
|
||||
|
||||
|
||||
class Step9MlPredictHandler(BaseStepHandler):
|
||||
"""步骤9:机器学习推理预测。
|
||||
|
||||
对应 config key: 'step9_ml_predict'
|
||||
委托类: PredictionStep.predict_water_quality()
|
||||
"""
|
||||
|
||||
step_key = 'step9_ml_predict'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
sampling_csv_path = self._resolve_path(
|
||||
config.get('sampling_csv_path'), context.sampling_csv_path, 'sampling_csv'
|
||||
)
|
||||
|
||||
models_dir = config.get('models_dir') or str(context.models_dir)
|
||||
|
||||
try:
|
||||
result = PredictionStep.predict_water_quality(
|
||||
sampling_csv_path=sampling_csv_path,
|
||||
models_dir=models_dir,
|
||||
metric=config.get('metric', 'test_r2'),
|
||||
prediction_column=config.get('prediction_column', 'prediction'),
|
||||
output_dir=str(context.prediction_dir / "9_ML_Prediction"),
|
||||
_report_generator=context.report_generator,
|
||||
_external_model=config.get('_external_model'),
|
||||
_external_model_path=config.get('_external_model_path'),
|
||||
_external_models_dict=config.get('_external_models_dict'),
|
||||
_external_model_dir=config.get('_external_model_dir'),
|
||||
)
|
||||
|
||||
context.prediction_files.update(result)
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤9: 机器学习推理预测", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'prediction_files': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤9: 机器学习推理预测", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user