# -*- coding: utf-8 -*-
"""
Prediction steps.

Contains step7_generate_sampling_points, step8_predict_water_quality,
step8_5_predict_with_non_empirical_models and
step8_75_predict_with_custom_regression.
"""
import time
from pathlib import Path
from typing import Optional, List, Union, Callable, Dict


class PredictionStep:
    """Static entry points for the prediction stage (steps 7, 8, 8.5 and 8.75)."""

    # ---- Shared private helpers ----
    @staticmethod
    def _make_notifier(step_label: str, callback: Optional[Callable]) -> Callable[..., None]:
        """Return a ``notify(status, msg="")`` function bound to *step_label*.

        The returned function calls ``callback(step_label, status, msg)`` and
        is a no-op when *callback* is falsy.
        """

        def notify(status, msg=""):
            if callback:
                callback(step_label, status, msg)

        return notify

    @staticmethod
    def _print_banner(title: str) -> None:
        """Print *title* framed by two 80-character ``=`` rules."""
        print("\n" + "=" * 80)
        print(title)
        print("=" * 80)

    @staticmethod
    def _resolve_deglint_image(deglint_img_path: str, output_dir: Path) -> Path:
        """Return the deglint image path to use for sampling.

        The given path is used as-is when it exists and does not end in
        ``.dat``. Otherwise a ``3_deglint`` directory is searched for ``*.bsq``
        files (the path's own parent if it is named ``3_deglint``, else the
        sibling ``3_deglint`` of *output_dir*); the first match in sorted order
        is returned. If nothing is found, the original path with its suffix
        replaced by ``.bsq`` is returned.
        """
        original_path = Path(deglint_img_path)
        if original_path.exists() and original_path.suffix.lower() != ".dat":
            return original_path

        print(f"🔍 智能探测:输入去耀斑路径不存在或为 .dat 占位符 ({original_path}),正在向上搜索真实产物...")

        # Locate the expected 3_deglint directory.
        candidate_dir = original_path.parent
        if candidate_dir.name != "3_deglint" and output_dir.parent.exists():
            candidate_dir = output_dir.parent / "3_deglint"

        if candidate_dir.exists():
            # glob() yields entries in arbitrary order; sort so the chosen
            # image is the same on every run.
            bsq_files = sorted(candidate_dir.glob("*.bsq"))
            if bsq_files:
                print(f"💡 智能拦截成功:自动寻回底层真实去耀斑影像: {bsq_files[0]}")
                return bsq_files[0]

        return original_path.with_suffix(".bsq")

    @staticmethod
    def _scan_existing_predictions(prediction_dir: Path) -> Dict[str, str]:
        """Map target names to the ``*.csv`` files already in *prediction_dir*.

        The target name is the file stem with ``_prediction`` removed, or else
        with ``_pred`` removed, or else the stem itself.
        """
        found = {}
        for csv_file in sorted(prediction_dir.glob("*.csv")):
            stem = csv_file.stem
            if "_prediction" in stem:
                target_name = stem.replace("_prediction", "")
            elif "_pred" in stem:
                target_name = stem.replace("_pred", "")
            else:
                target_name = stem
            found[target_name] = str(csv_file)
        return found

    # ---- Step 7: generate sampling points and extract spectra ----
    @staticmethod
    def generate_sampling_points(
        deglint_img_path: Optional[str] = None,
        interval: int = 50,
        sample_radius: int = 5,
        chunk_size: int = 1000,
        water_mask_path: Optional[str] = None,
        glint_mask_path: Optional[str] = None,
        output_dir: Union[str, Path] = "./4_sampling",
        callback: Optional[Callable] = None,
        use_adaptive_sampling: bool = True,
    ) -> str:
        """Generate sampling points and write their spectra to a CSV file.

        The work is delegated to
        ``src.utils.sampling.get_spectral_sampling_points_chunked``; the result
        is written to ``<output_dir>/sampling_spectra.csv``. If that file
        already exists the step is skipped and its path is returned.

        Args:
            deglint_img_path: Path of the deglinted image. Required.
            interval: Forwarded unchanged to the sampling helper.
            sample_radius: Forwarded unchanged to the sampling helper.
            chunk_size: Forwarded unchanged to the sampling helper.
            water_mask_path: Water mask path, forwarded to the sampling helper.
            glint_mask_path: Glint mask path, forwarded to the sampling helper;
                may be ``None``.
            output_dir: Directory for the output CSV; created if missing.
            callback: Optional ``callback(step, status, msg)`` progress hook.
            use_adaptive_sampling: Forwarded unchanged to the sampling helper.

        Returns:
            Path of the sampling spectra CSV, as a string.

        Raises:
            ValueError: If ``deglint_img_path`` is ``None``.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / "sampling_spectra.csv")

        notify = PredictionStep._make_notifier("步骤7", callback)
        PredictionStep._print_banner("步骤7: 生成预测采样点并提取光谱")

        if deglint_img_path is None:
            raise ValueError("必须提供 deglint_img_path 参数")

        # Check for a cached result first, so a skip neither searches the
        # filesystem for the image nor needs the sampling module.
        if Path(output_path).exists():
            print(f"检测到已存在的采样点光谱数据文件,直接使用: {output_path}")
            notify("skipped", f"采样点光谱数据已设置: {output_path}")
            return output_path

        deglint_img_str = str(
            PredictionStep._resolve_deglint_image(deglint_img_path, output_dir)
        )

        if glint_mask_path is None:
            print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")

        # Imported lazily: only needed when sampling actually runs.
        from src.utils.sampling import get_spectral_sampling_points_chunked

        # The first four arguments are positional (image, water mask, glint
        # mask, output CSV). NOTE(review): confirm this order against the
        # helper's signature.
        get_spectral_sampling_points_chunked(
            deglint_img_str,
            water_mask_path,
            glint_mask_path,
            output_path,
            interval=interval,
            sample_radius=sample_radius,
            chunk_size=chunk_size,
            use_adaptive_sampling=use_adaptive_sampling,
        )

        notify("completed", f"采样点光谱数据已保存: {output_path}")
        return output_path

    # ---- Step 8: predict water-quality parameters with ML models ----
    @staticmethod
    def predict_water_quality(
        sampling_csv_path: str,
        models_dir: Optional[str] = None,
        metric: str = "test_r2",
        prediction_column: str = "prediction",
        output_dir: Union[str, Path] = "./9_ML_Prediction",
        callback: Optional[Callable] = None,
        _report_generator=None,
        _external_model=None,
        _external_model_path=None,
        _external_models_dict=None,
        _external_model_dir=None,
    ) -> Dict[str, str]:
        """Apply trained models to the sampling spectra and save predictions.

        If ``_external_models_dict`` is non-empty, each ``{target: model}``
        entry is run through its own ``WaterQualityInference`` instance and
        written to ``<output_dir>/<target>.csv``. Otherwise
        ``batch_inference_multi_models`` is run over ``models_dir``.

        Before running, existing ``*.csv`` files in ``output_dir`` are
        collected; if every sub-directory name of ``models_dir`` already has a
        matching file, the step is skipped and that mapping is returned.

        Args:
            sampling_csv_path: CSV with the sampling point spectra.
            models_dir: Root directory of the trained models. Required unless
                ``_external_models_dict`` is supplied.
            metric: Forwarded to the inference calls.
            prediction_column: Forwarded to the inference calls.
            output_dir: Directory for prediction CSVs; created if missing.
            callback: Optional ``callback(step, status, msg)`` progress hook.
            _report_generator: Optional object whose
                ``generate_prediction_report(prediction_files)`` is called at
                the end; failures there are printed, not raised.
            _external_model: Forwarded to ``WaterQualityInference`` and
                ``batch_inference_multi_models`` in the directory-scan branch.
            _external_model_path: Forwarded alongside ``_external_model``.
            _external_models_dict: Mapping of target name to model object.
            _external_model_dir: Passed as ``external_model_path`` in the
                external-dict branch (empty string when falsy).

        Returns:
            Mapping of target name to prediction CSV path. In the external-dict
            branch a failed target maps to ``None``.

        Raises:
            ValueError: If neither ``models_dir`` nor a non-empty
                ``_external_models_dict`` is given.
        """
        notify = PredictionStep._make_notifier("步骤8", callback)
        PredictionStep._print_banner("步骤8: 预测水质参数")
        print(f"[PredictionStep] 准备执行预测,字典状态: {'Yes' if _external_models_dict else 'No'}"
              f", 单模型状态: {'Yes' if _external_model else 'No'}")

        # The external-dict branch below already tolerates a missing
        # models_dir (it falls back to "./"), so only require it when that
        # branch will not be taken.
        if models_dir is None and not _external_models_dict:
            raise ValueError("必须提供 models_dir 参数")

        ml_prediction_dir = Path(output_dir)
        ml_prediction_dir.mkdir(parents=True, exist_ok=True)

        prediction_files = PredictionStep._scan_existing_predictions(ml_prediction_dir)

        # Skip when every target folder under models_dir has a prediction file.
        if prediction_files and models_dir is not None:
            models_root = Path(models_dir)
            if models_root.exists():
                target_folders = [d.name for d in models_root.iterdir() if d.is_dir()]
                missing_targets = [t for t in target_folders if t not in prediction_files]
                if not missing_targets:
                    print(f"检测到已存在的预测结果文件,直接使用: {ml_prediction_dir}")
                    notify("skipped", f"预测结果已设置: {ml_prediction_dir}")
                    return prediction_files
                print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")

        # Imported lazily: only needed when inference actually runs.
        from src.core.prediction.inference_batch import WaterQualityInference

        all_results = {}

        if _external_models_dict:
            # The external dict takes precedence: its keys are the targets and
            # each model gets its own inference instance.
            print(f"\n使用外部导入模型字典({len(_external_models_dict)} 个模型)...")
            for target_name, model_obj in _external_models_dict.items():
                try:
                    output_file = ml_prediction_dir / f"{target_name}.csv"
                    model_inferencer = WaterQualityInference(
                        models_dir or "./",
                        external_model=model_obj,
                        external_model_path=_external_model_dir or "",
                    )
                    predictions, result_df = model_inferencer.inference_pipeline(
                        sampling_csv_path=sampling_csv_path,
                        output_csv_path=str(output_file),
                        metric=metric,
                        prediction_column=prediction_column,
                    )
                    prediction_files[target_name] = str(output_file)
                    all_results[target_name] = {
                        "status": "success",
                        "output_file": str(output_file),
                        "sample_count": len(predictions),
                    }
                    print(f" ✓ {target_name}: {len(predictions)} 个预测值")
                except Exception as e:
                    # Best effort per target: record the failure and continue.
                    print(f" ✗ {target_name}: 失败 — {type(e).__name__}: {e}")
                    prediction_files[target_name] = None
                    all_results[target_name] = {"status": "error", "error": str(e)}
        else:
            # No external dict: scan the sub-directories of models_dir.
            inferencer = WaterQualityInference(
                models_dir,
                external_model=_external_model,
                external_model_path=_external_model_path,
            )
            all_results = inferencer.batch_inference_multi_models(
                models_root_dir=models_dir,
                sampling_csv_path=sampling_csv_path,
                output_dir=str(ml_prediction_dir),
                metric=metric,
                prediction_column=prediction_column,
                output_format="csv",
                external_model=_external_model,
                external_model_path=_external_model_path,
                external_models_dict=_external_models_dict,
            )
            # Per the original author's note, batch_inference_multi_models
            # always returns a dict; the truthiness check also covers None.
            if all_results:
                for target_name, result in all_results.items():
                    if result.get("status") == "success":
                        prediction_files[target_name] = result["output_file"]

        print(f"预测完成,结果保存在: {ml_prediction_dir}")

        if _report_generator is not None:
            try:
                report_path = _report_generator.generate_prediction_report(prediction_files)
                print(f"预测结果报告已生成: {report_path}")
            except Exception as e:
                # A report failure must not discard the predictions.
                print(f"生成预测结果报告时出错: {e}")

        notify("completed", f"预测完成: {ml_prediction_dir}")
        return prediction_files

    # ---- Step 8.5: prediction with non-empirical models ----
    @staticmethod
    def predict_with_non_empirical_models(
        sampling_csv_path: str,
        non_empirical_models_dir: Optional[str] = None,
        output_dir: Optional[str] = None,
        metric: str = "Average Accuracy(%)",
        prediction_column: str = "prediction",
        enabled: bool = True,
        callback: Optional[Callable] = None,
        work_dir: Union[str, Path] = "./work_dir",
    ) -> Dict[str, str]:
        """Predict parameters with the best non-empirical model per algorithm.

        Reads ``non_empirical_models_summary.csv`` from the models directory,
        picks one row per ``Algorithm Name`` (largest ``metric`` when that
        column exists, otherwise the first row) and calls
        ``non_empirical_retrieval`` for each. An algorithm whose output file
        already exists is not recomputed; an algorithm that fails is reported
        on stdout and left out of the result.

        Args:
            sampling_csv_path: CSV with the sampling point spectra.
            non_empirical_models_dir: Models directory; defaults to
                ``<work_dir>/8_Non_Empirical_Regression``.
            output_dir: Output directory; defaults to
                ``<work_dir>/11_12_13_predictions/Non_Empirical_Prediction``.
            metric: Summary column used to rank rows within an algorithm.
            prediction_column: Used in the output file name.
            enabled: When ``False`` the step is skipped and ``{}`` returned.
            callback: Optional ``callback(step, status, msg)`` progress hook.
            work_dir: Base directory for the default paths.

        Returns:
            Mapping of algorithm name to prediction CSV path.

        Raises:
            ValueError: If the default models directory or the summary file
                does not exist.
        """
        notify = PredictionStep._make_notifier("步骤8.5", callback)
        PredictionStep._print_banner("步骤8.5: 使用非经验模型进行参数预测")

        if not enabled:
            print("已设置跳过非经验模型预测(enabled=False)。")
            notify("skipped", "跳过非经验模型预测")
            return {}

        if non_empirical_models_dir is not None:
            final_models_dir = non_empirical_models_dir
        else:
            default_models_dir = str(Path(work_dir) / "8_Non_Empirical_Regression")
            if not Path(default_models_dir).exists():
                raise ValueError("请先执行步骤6.5: 非经验模型训练,或提供 non_empirical_models_dir 参数")
            final_models_dir = default_models_dir

        if output_dir is not None:
            non_empirical_prediction_dir = Path(output_dir)
        else:
            non_empirical_prediction_dir = (
                Path(work_dir) / "11_12_13_predictions" / "Non_Empirical_Prediction"
            )
        non_empirical_prediction_dir.mkdir(parents=True, exist_ok=True)

        summary_path = Path(final_models_dir) / "non_empirical_models_summary.csv"
        if not summary_path.exists():
            raise ValueError(f"未找到非经验模型汇总文件: {summary_path}")

        import pandas as pd

        df_summary = pd.read_csv(summary_path)
        has_metric = metric in df_summary.columns

        best_models = {}
        # dropna(): a missing algorithm name would select an empty group and
        # crash on .values[0] below.
        for algorithm in df_summary["Algorithm Name"].dropna().unique():
            algorithm_df = df_summary[df_summary["Algorithm Name"] == algorithm]
            if has_metric:
                best_model_row = algorithm_df.nlargest(1, metric)
            else:
                best_model_row = algorithm_df.iloc[[0]]

            best_preprocess = best_model_row["Preprocessing Method"].values[0]
            best_accuracy = best_model_row[metric].values[0] if has_metric else "N/A"
            best_models[algorithm] = {
                "model_path": best_model_row["Model File"].values[0],
                "preprocess_method": best_preprocess,
                "accuracy": best_accuracy,
            }
            print(f"算法 {algorithm}: 选择 {best_preprocess} (准确率: {best_accuracy})")

        # Read only to validate that the sampling CSV can be parsed.
        pd.read_csv(sampling_csv_path)

        prediction_files = {}
        for algorithm, model_info in best_models.items():
            print(f"\n使用 {algorithm} 算法进行预测...")
            output_path = str(
                non_empirical_prediction_dir
                / f"non_empirical_{algorithm}_{prediction_column}.csv"
            )

            if Path(output_path).exists():
                print(f"检测到已存在的预测结果文件,直接使用: {output_path}")
                prediction_files[algorithm] = output_path
                continue

            try:
                from src.core.non_empirical_retrieval import non_empirical_retrieval

                non_empirical_retrieval(
                    algorithm=algorithm,
                    model_info_path=model_info["model_path"],
                    coor_spectral_path=sampling_csv_path,
                    output_path=output_path,
                    wave_radius=5,
                )
                prediction_files[algorithm] = output_path
                print(f"预测完成: {output_path}")
            except Exception as e:
                # Best effort per algorithm: report and move on to the next.
                print(f"使用 {algorithm} 算法预测时出错: {e}")
                continue

        notify("completed", f"非经验模型预测完成: {non_empirical_prediction_dir}")
        return prediction_files

    # ---- Step 8.75: prediction with custom regression models ----
    @staticmethod
    def predict_with_custom_regression(
        sampling_csv_path: str,
        custom_regression_dir: Optional[str] = None,
        formula_csv_path: Optional[str] = None,
        coordinate_columns: Optional[List[str]] = None,
        output_dir: Optional[str] = None,
        filename_prefix: str = "custom_regression_prediction",
        enabled: bool = True,
        callback: Optional[Callable] = None,
        work_dir: Union[str, Path] = "./work_dir",
    ) -> Dict[str, str]:
        """Predict parameters with ``CustomRegressionPredictor``.

        Args:
            sampling_csv_path: CSV with the sampling point spectra; must exist.
            custom_regression_dir: Regression results directory; defaults to
                ``<work_dir>/13_Custom_Regression``.
            formula_csv_path: Passed to the predictor constructor.
            coordinate_columns: Passed to ``run_batch_prediction``.
            output_dir: Output directory; defaults to
                ``<work_dir>/13_Custom_Regression/Custom_Regression_Prediction``.
                Created if missing.
            filename_prefix: Passed to ``run_batch_prediction``.
            enabled: When ``False`` the step is skipped and ``{}`` returned.
            callback: Optional ``callback(step, status, msg)`` progress hook.
            work_dir: Base directory for the default paths.

        Returns:
            The mapping returned by ``run_batch_prediction`` (parameter name
            to file path, as printed below).

        Raises:
            FileNotFoundError: If ``sampling_csv_path`` does not exist.
            ValueError: If the default regression directory does not exist.
        """
        notify = PredictionStep._make_notifier("步骤8.75", callback)
        PredictionStep._print_banner("步骤8.75: 使用自定义回归模型进行参数预测")

        if not enabled:
            print("已设置跳过自定义回归模型预测(enabled=False)。")
            notify("skipped", "跳过自定义回归预测")
            return {}

        if not Path(sampling_csv_path).exists():
            raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")

        if custom_regression_dir is not None:
            final_regression_dir = custom_regression_dir
        else:
            final_regression_dir = str(Path(work_dir) / "13_Custom_Regression")
            if not Path(final_regression_dir).exists():
                raise ValueError(
                    "请先执行步骤6.75: 自定义回归分析,或提供 custom_regression_dir 参数"
                )

        if output_dir is None:
            default_prediction_dir = (
                Path(work_dir) / "13_Custom_Regression" / "Custom_Regression_Prediction"
            )
            default_prediction_dir.mkdir(parents=True, exist_ok=True)
            prediction_output_dir = str(default_prediction_dir)
        else:
            # Create a caller-supplied directory too, as the other steps do.
            Path(output_dir).mkdir(parents=True, exist_ok=True)
            prediction_output_dir = output_dir

        from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor

        predictor = CustomRegressionPredictor(
            regression_csv_dir=final_regression_dir,
            formula_csv_path=formula_csv_path,
        )

        print("开始使用自定义回归模块进行批量预测...")
        print(f" 采样点数据: {sampling_csv_path}")
        print(f" 回归模型目录: {final_regression_dir}")
        print(f" 输出目录: {prediction_output_dir}")

        # NOTE(review): prediction_output_dir is printed but never handed to
        # the predictor, so where the files land is decided by
        # CustomRegressionPredictor. Confirm whether run_batch_prediction
        # should receive it.
        saved_files = predictor.run_batch_prediction(
            input_csv_path=sampling_csv_path,
            coordinate_columns=coordinate_columns,
            filename_prefix=filename_prefix,
        )

        print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
        for param_name, filepath in saved_files.items():
            print(f" {param_name}: {filepath}")

        notify("completed", f"自定义回归预测完成: {len(saved_files)} 个文件")
        return saved_files