refactor(step4): 剥离 Steps 层 - step4~step9 业务逻辑下沉到独立模块
This commit is contained in:
323
src/core/steps/prediction_step.py
Normal file
323
src/core/steps/prediction_step.py
Normal file
@ -0,0 +1,323 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
预测步骤
|
||||
|
||||
包含 step7_generate_sampling_points, step8_predict_water_quality,
|
||||
step8_5_predict_with_non_empirical_models, step8_75_predict_with_custom_regression
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable, Dict
|
||||
|
||||
|
||||
class PredictionStep:
    """Prediction steps (7, 8, 8.5 and 8.75) exposed as static methods.

    Every method prints a banner to stdout, optionally reports progress
    through ``callback(step_label, status, message)`` and reuses result files
    that already exist on disk instead of recomputing them.
    """

    # ---- Step 7: generate sampling points and extract spectra ----

    @staticmethod
    def generate_sampling_points(
        deglint_img_path: Optional[str] = None,
        interval: int = 50,
        sample_radius: int = 5,
        chunk_size: int = 1000,
        water_mask_path: Optional[str] = None,
        glint_mask_path: Optional[str] = None,
        output_dir: Union[str, Path] = "./10_sampling",
        callback: Optional[Callable] = None,
    ) -> str:
        """Generate sampling points inside the water mask and outside the
        glint mask, and collect their averaged spectra into a CSV file.

        Args:
            deglint_img_path: Path of the deglinted image. Required.
            interval: Forwarded unchanged to
                ``get_spectral_sampling_points_chunked``.
            sample_radius: Forwarded unchanged to the same helper.
            chunk_size: Forwarded unchanged to the same helper.
            water_mask_path: Optional water mask path, forwarded as-is.
            glint_mask_path: Optional glint mask path. When ``None`` a notice
                is printed and ``None`` is forwarded to the helper.
            output_dir: Directory that receives ``sampling_spectra.csv``;
                created if missing.
            callback: Optional callable invoked as
                ``callback("步骤7", status, message)``.

        Returns:
            Path (as ``str``) of ``sampling_spectra.csv`` inside ``output_dir``.

        Raises:
            ValueError: If ``deglint_img_path`` is ``None``.
        """

        def notify(status, msg=""):
            if callback:
                callback("步骤7", status, msg)

        print("\n" + "=" * 80)
        print("步骤7: 生成预测采样点并提取光谱")
        print("=" * 80)

        # Validate before touching the file system so that an invalid call
        # does not leave an empty output directory behind.
        if deglint_img_path is None:
            raise ValueError("必须提供 deglint_img_path 参数")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / "sampling_spectra.csv")

        # Reuse an existing result instead of re-sampling the image.
        if Path(output_path).exists():
            print(f"检测到已存在的采样点光谱数据文件,直接使用: {output_path}")
            notify("skipped", f"采样点光谱数据已设置: {output_path}")
            return output_path

        # An explicitly passed glint mask is used as-is; there is no internal
        # default to fall back on.
        glint_mask_to_use = glint_mask_path
        if glint_mask_to_use is None:
            print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")

        # Imported lazily: only needed when sampling actually has to run.
        from src.utils.sampling import get_spectral_sampling_points_chunked

        get_spectral_sampling_points_chunked(
            deglint_img_path, water_mask_path, glint_mask_to_use,
            output_path, interval, sample_radius, chunk_size
        )

        notify("completed", f"采样点光谱数据已保存: {output_path}")
        return output_path

    # ---- Step 8: predict water-quality parameters with ML models ----

    @staticmethod
    def predict_water_quality(
        sampling_csv_path: str,
        models_dir: Optional[str] = None,
        metric: str = "test_r2",
        prediction_column: str = "prediction",
        output_dir: Union[str, Path] = "./11_12_13_predictions/Machine_Learning_Prediction",
        callback: Optional[Callable] = None,
        _report_generator=None,
    ) -> Dict[str, str]:
        """Apply the trained machine-learning models to the sampled spectra.

        Existing ``*.csv`` files in ``output_dir`` are treated as earlier
        predictions. If every sub-directory of ``models_dir`` (one per target)
        already has such a file, the cached mapping is returned without
        running inference.

        Args:
            sampling_csv_path: CSV with the sampled spectra.
            models_dir: Root directory of the trained models. Required.
            metric: Forwarded to ``batch_inference_multi_models``.
            prediction_column: Forwarded to ``batch_inference_multi_models``.
            output_dir: Directory for the prediction CSV files; created if
                missing.
            callback: Optional callable invoked as
                ``callback("步骤8", status, message)``.
            _report_generator: Optional object whose
                ``generate_prediction_report(prediction_files)`` is called
                after inference; failures there are printed, not raised.

        Returns:
            Mapping of target name to prediction CSV path.

        Raises:
            ValueError: If ``models_dir`` is ``None``.
        """

        def notify(status, msg=""):
            if callback:
                callback("步骤8", status, msg)

        print("\n" + "=" * 80)
        print("步骤8: 预测水质参数")
        print("=" * 80)

        if models_dir is None:
            raise ValueError("必须提供 models_dir 参数")

        ml_prediction_dir = Path(output_dir)
        ml_prediction_dir.mkdir(parents=True, exist_ok=True)

        # Collect prediction files left by an earlier run. The target name is
        # the file stem with the "_prediction" / "_pred" marker removed.
        prediction_files = {}
        for csv_file in ml_prediction_dir.glob("*.csv"):
            file_stem = csv_file.stem
            if "_prediction" in file_stem:
                target_name = file_stem.replace("_prediction", "")
            elif "_pred" in file_stem:
                target_name = file_stem.replace("_pred", "")
            else:
                target_name = file_stem
            prediction_files[target_name] = str(csv_file)

        # Skip inference only when every target folder already has a file.
        if prediction_files:
            models_path_obj = Path(models_dir)
            if models_path_obj.exists():
                target_folders = [d.name for d in models_path_obj.iterdir() if d.is_dir()]
                missing_targets = [t for t in target_folders if t not in prediction_files]
                if not missing_targets:
                    print(f"检测到已存在的预测结果文件,直接使用: {ml_prediction_dir}")
                    notify("skipped", f"预测结果已设置: {ml_prediction_dir}")
                    return prediction_files
                print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")

        # Imported lazily: not needed on the validation and cache-hit paths.
        from src.core.prediction.inference_batch import WaterQualityInference

        inferencer = WaterQualityInference(models_dir)
        all_results = inferencer.batch_inference_multi_models(
            models_root_dir=models_dir,
            sampling_csv_path=sampling_csv_path,
            output_dir=str(ml_prediction_dir),
            metric=metric,
            prediction_column=prediction_column,
            output_format="csv",
        )

        # Only successful targets are recorded; they override cached entries.
        for target_name, result in all_results.items():
            if result.get("status") == "success":
                prediction_files[target_name] = result["output_file"]

        print(f"预测完成,结果保存在: {ml_prediction_dir}")

        if _report_generator is not None:
            # The report is best-effort: a failure must not discard the
            # predictions that were just produced.
            try:
                report_path = _report_generator.generate_prediction_report(prediction_files)
                print(f"预测结果报告已生成: {report_path}")
            except Exception as e:
                print(f"生成预测结果报告时出错: {e}")

        notify("completed", f"预测完成: {ml_prediction_dir}")
        return prediction_files

    # ---- Step 8.5: prediction with non-empirical models ----

    @staticmethod
    def predict_with_non_empirical_models(
        sampling_csv_path: str,
        non_empirical_models_dir: Optional[str] = None,
        output_dir: Optional[str] = None,
        metric: str = "Average Accuracy(%)",
        prediction_column: str = "prediction",
        enabled: bool = True,
        callback: Optional[Callable] = None,
        work_dir: Union[str, Path] = "./work_dir",
    ) -> Dict[str, str]:
        """Predict parameters with the non-empirical regression models.

        For every algorithm listed in ``non_empirical_models_summary.csv`` the
        row with the largest ``metric`` value is selected (the first row when
        the column is absent) and ``non_empirical_retrieval`` is run with it.

        Args:
            sampling_csv_path: CSV with the sampled spectra.
            non_empirical_models_dir: Directory containing the summary CSV.
                Defaults to ``<work_dir>/8_Regression_Modeling``.
            output_dir: Directory for prediction files. Defaults to
                ``<work_dir>/11_12_13_predictions/Non_Empirical_Prediction``.
            metric: Summary column used to rank the models of an algorithm.
            prediction_column: Used as the suffix of the output file name.
            enabled: When ``False`` the step is skipped and ``{}`` returned.
            callback: Optional callable invoked as
                ``callback("步骤8.5", status, message)``.
            work_dir: Base directory for the default locations above.

        Returns:
            Mapping of algorithm name to prediction CSV path. Algorithms whose
            prediction failed are omitted.

        Raises:
            ValueError: If no models directory can be resolved or the summary
                CSV is missing.
        """

        def notify(status, msg=""):
            if callback:
                callback("步骤8.5", status, msg)

        print("\n" + "=" * 80)
        print("步骤8.5: 使用非经验模型进行参数预测")
        print("=" * 80)

        if not enabled:
            print("已设置跳过非经验模型预测(enabled=False)。")
            notify("skipped", "跳过非经验模型预测")
            return {}

        if non_empirical_models_dir is not None:
            final_models_dir = non_empirical_models_dir
        else:
            default_models_dir = str(Path(work_dir) / "8_Regression_Modeling")
            if not Path(default_models_dir).exists():
                raise ValueError("请先执行步骤6.5: 非经验模型训练,或提供 non_empirical_models_dir 参数")
            final_models_dir = default_models_dir

        if output_dir is not None:
            non_empirical_prediction_dir = Path(output_dir)
        else:
            non_empirical_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Non_Empirical_Prediction"
        non_empirical_prediction_dir.mkdir(parents=True, exist_ok=True)

        prediction_files = {}
        summary_path = Path(final_models_dir) / "non_empirical_models_summary.csv"
        if not summary_path.exists():
            raise ValueError(f"未找到非经验模型汇总文件: {summary_path}")

        import pandas as pd

        df_summary = pd.read_csv(summary_path)

        # Pick one model per algorithm: best by `metric` when that column
        # exists, otherwise simply the first listed row.
        best_models = {}
        for algorithm in df_summary["Algorithm Name"].unique():
            algorithm_df = df_summary[df_summary["Algorithm Name"] == algorithm]
            has_metric = metric in algorithm_df.columns
            if has_metric:
                best_model_row = algorithm_df.nlargest(1, metric)
            else:
                best_model_row = algorithm_df.iloc[[0]]

            best_model_path = best_model_row["Model File"].values[0]
            best_preprocess = best_model_row["Preprocessing Method"].values[0]
            best_accuracy = best_model_row[metric].values[0] if has_metric else "N/A"

            best_models[algorithm] = {
                "model_path": best_model_path,
                "preprocess_method": best_preprocess,
                "accuracy": best_accuracy,
            }
            print(f"算法 {algorithm}: 选择 {best_preprocess} (准确率: {best_accuracy})")

        # Parsed only to fail early on a missing or unreadable sampling CSV;
        # the data itself is re-read by the retrieval routine.
        pd.read_csv(sampling_csv_path)

        for algorithm, model_info in best_models.items():
            print(f"\n使用 {algorithm} 算法进行预测...")
            output_path = str(non_empirical_prediction_dir / f"non_empirical_{algorithm}_{prediction_column}.csv")

            if Path(output_path).exists():
                print(f"检测到已存在的预测结果文件,直接使用: {output_path}")
                prediction_files[algorithm] = output_path
                continue

            # Best-effort per algorithm: a failure (including a failing
            # import) is reported and the remaining algorithms still run.
            try:
                from src.core.non_empirical_retrieval import non_empirical_retrieval

                non_empirical_retrieval(
                    algorithm=algorithm,
                    model_info_path=model_info["model_path"],
                    coor_spectral_path=sampling_csv_path,
                    output_path=output_path,
                    wave_radius=5,
                )
                prediction_files[algorithm] = output_path
                print(f"预测完成: {output_path}")
            except Exception as e:
                print(f"使用 {algorithm} 算法预测时出错: {e}")
                continue

        notify("completed", f"非经验模型预测完成: {non_empirical_prediction_dir}")
        return prediction_files

    # ---- Step 8.75: prediction with custom regression models ----

    @staticmethod
    def predict_with_custom_regression(
        sampling_csv_path: str,
        custom_regression_dir: Optional[str] = None,
        formula_csv_path: Optional[str] = None,
        coordinate_columns: Optional[List[str]] = None,
        output_dir: Optional[str] = None,
        filename_prefix: str = "custom_regression_prediction",
        enabled: bool = True,
        callback: Optional[Callable] = None,
        work_dir: Union[str, Path] = "./work_dir",
    ) -> Dict[str, str]:
        """Predict parameters with the custom regression models.

        Args:
            sampling_csv_path: CSV with the sampled spectra; must exist.
            custom_regression_dir: Directory with the regression results.
                Defaults to ``<work_dir>/9_Custom_Regression_Modeling``, which
                must then exist.
            formula_csv_path: Forwarded to ``CustomRegressionPredictor``.
            coordinate_columns: Forwarded to ``run_batch_prediction``.
            output_dir: Output directory; created if missing. Defaults to
                ``<work_dir>/11_12_13_predictions/Custom_Regression_Prediction``.
            filename_prefix: Forwarded to ``run_batch_prediction``.
            enabled: When ``False`` the step is skipped and ``{}`` returned.
            callback: Optional callable invoked as
                ``callback("步骤8.75", status, message)``.
            work_dir: Base directory for the default locations above.

        Returns:
            The mapping returned by ``run_batch_prediction`` (parameter name
            to file path, judging by how it is printed below).

        Raises:
            FileNotFoundError: If ``sampling_csv_path`` does not exist.
            ValueError: If the default regression directory does not exist.
        """

        def notify(status, msg=""):
            if callback:
                callback("步骤8.75", status, msg)

        print("\n" + "=" * 80)
        print("步骤8.75: 使用自定义回归模型进行参数预测")
        print("=" * 80)

        if not enabled:
            print("已设置跳过自定义回归模型预测(enabled=False)。")
            notify("skipped", "跳过自定义回归预测")
            return {}

        if not Path(sampling_csv_path).exists():
            raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")

        if custom_regression_dir is not None:
            final_regression_dir = custom_regression_dir
        else:
            final_regression_dir = str(Path(work_dir) / "9_Custom_Regression_Modeling")
            if not Path(final_regression_dir).exists():
                raise ValueError(
                    "请先执行步骤6.75: 自定义回归分析,或提供 custom_regression_dir 参数"
                )

        if output_dir is None:
            prediction_output_dir = str(
                Path(work_dir) / "11_12_13_predictions" / "Custom_Regression_Prediction"
            )
        else:
            prediction_output_dir = output_dir
        # Create the directory in both cases, matching the other steps; the
        # original only created the default location.
        Path(prediction_output_dir).mkdir(parents=True, exist_ok=True)

        from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor

        predictor = CustomRegressionPredictor(
            regression_csv_dir=final_regression_dir,
            formula_csv_path=formula_csv_path,
        )

        print("开始使用自定义回归模块进行批量预测...")
        print(f"  采样点数据: {sampling_csv_path}")
        print(f"  回归模型目录: {final_regression_dir}")
        print(f"  输出目录: {prediction_output_dir}")

        # NOTE(review): prediction_output_dir is only printed, never handed to
        # the predictor, so the files may be written elsewhere. Confirm
        # against the CustomRegressionPredictor API whether an output
        # directory argument should be passed here.
        saved_files = predictor.run_batch_prediction(
            input_csv_path=sampling_csv_path,
            coordinate_columns=coordinate_columns,
            filename_prefix=filename_prefix,
        )

        print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
        for param_name, filepath in saved_files.items():
            print(f"  {param_name}: {filepath}")

        notify("completed", f"自定义回归预测完成: {len(saved_files)} 个文件")
        return saved_files
|
||||
Reference in New Issue
Block a user