Compare commits
11 Commits
605ec86108
...
5a55be286f
| Author | SHA1 | Date | |
|---|---|---|---|
| 5a55be286f | |||
| 9ba39a7bff | |||
| d15a7a1e2b | |||
| 6d4d802ffe | |||
| abac272b31 | |||
| 95d30d8d81 | |||
| 375fea77b9 | |||
| 8c7c995985 | |||
| f96c55f361 | |||
| 14278739bf | |||
| d0eb458392 |
20
src/core/steps/__init__.py
Normal file
20
src/core/steps/__init__.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""业务步骤层模块"""
|
||||||
|
|
||||||
|
from src.core.steps.water_mask_step import WaterMaskStep
|
||||||
|
from src.core.steps.glint_detection_step import GlintDetectionStep
|
||||||
|
from src.core.steps.glint_removal_step import GlintRemovalStep
|
||||||
|
from src.core.steps.data_preparation_step import DataPreparationStep
|
||||||
|
from src.core.steps.modeling_step import ModelingStep
|
||||||
|
from src.core.steps.prediction_step import PredictionStep
|
||||||
|
from src.core.steps.mapping_step import MappingStep
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"WaterMaskStep",
|
||||||
|
"GlintDetectionStep",
|
||||||
|
"GlintRemovalStep",
|
||||||
|
"DataPreparationStep",
|
||||||
|
"ModelingStep",
|
||||||
|
"PredictionStep",
|
||||||
|
"MappingStep",
|
||||||
|
]
|
||||||
184
src/core/steps/data_preparation_step.py
Normal file
184
src/core/steps/data_preparation_step.py
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
数据准备步骤
|
||||||
|
|
||||||
|
包含 step4_process_csv, step5_extract_training_spectra, step5_5_calculate_water_quality_indices
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Union, Callable, Dict
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class DataPreparationStep:
|
||||||
|
"""数据准备步骤"""
|
||||||
|
|
||||||
|
# ---- Step 4: 处理CSV文件 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_csv(
|
||||||
|
csv_path: str,
|
||||||
|
output_dir: Union[str, Path] = "./4_processed_data",
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
) -> str:
|
||||||
|
"""处理CSV文件(筛选剔除异常值)"""
|
||||||
|
from src.preprocessing.process_water_quality_data import process_water_quality_data
|
||||||
|
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
output_path = str(output_dir / "processed_data.csv")
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤4", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤4: 处理CSV文件,筛选剔除异常值")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if Path(output_path).exists():
|
||||||
|
print(f"检测到已存在的处理后CSV文件,直接使用: {output_path}")
|
||||||
|
notify("skipped", f"处理后的CSV文件已设置: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
process_water_quality_data(csv_path, output_path)
|
||||||
|
notify("completed", f"处理后的CSV文件已保存: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
# ---- Step 5: 提取训练样本点光谱 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_training_spectra(
|
||||||
|
deglint_img_path: Optional[str] = None,
|
||||||
|
radius: int = 5,
|
||||||
|
source_epsg: int = 4326,
|
||||||
|
csv_path: Optional[str] = None,
|
||||||
|
boundary_path: Optional[str] = None,
|
||||||
|
glint_mask_path: Optional[str] = None,
|
||||||
|
water_mask_path: Optional[str] = None,
|
||||||
|
output_dir: Union[str, Path] = "./5_training_spectra",
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
) -> str:
|
||||||
|
"""根据采样点坐标在去耀斑影像中提取平均光谱"""
|
||||||
|
from src.core.glint_removal.get_spectral import get_spectral_in_coor
|
||||||
|
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
output_path = str(output_dir / "training_spectra.csv")
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤5", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤5: 提取训练样本点的平均光谱")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if deglint_img_path is None:
|
||||||
|
raise ValueError("必须提供 deglint_img_path 参数")
|
||||||
|
if csv_path is None:
|
||||||
|
raise ValueError("必须提供 csv_path 参数")
|
||||||
|
|
||||||
|
if Path(output_path).exists():
|
||||||
|
print(f"检测到已存在的训练光谱数据文件,直接使用: {output_path}")
|
||||||
|
notify("skipped", f"训练光谱数据已设置: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
# 确保水体掩膜存在
|
||||||
|
final_boundary_path = boundary_path
|
||||||
|
if final_boundary_path is None and water_mask_path is not None:
|
||||||
|
final_boundary_path = water_mask_path
|
||||||
|
|
||||||
|
# 【新增安全防护】智能拦截矢量 .shp,强制替换为步骤 1 生成的栅格 .dat
|
||||||
|
if final_boundary_path and str(final_boundary_path).lower().endswith('.shp'):
|
||||||
|
# 向上追溯查找 1_water_mask 目录下的 dat 替身
|
||||||
|
possible_dat = Path(deglint_img_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat"
|
||||||
|
if not possible_dat.exists() and output_path:
|
||||||
|
possible_dat = Path(output_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat"
|
||||||
|
|
||||||
|
if possible_dat.exists():
|
||||||
|
print(f"💡 智能拦截:检测到输入掩膜为矢量 .shp,自动切换为已生成的栅格掩膜: {possible_dat}")
|
||||||
|
final_boundary_path = str(possible_dat)
|
||||||
|
else:
|
||||||
|
print(f"⚠️ 警告:检测到输入掩膜为 .shp 且未找到对应 .dat 替身,可能导致底层读取失败。")
|
||||||
|
|
||||||
|
flare_path = glint_mask_path
|
||||||
|
if flare_path:
|
||||||
|
print(f"光谱提取使用耀斑掩膜: {flare_path}")
|
||||||
|
|
||||||
|
get_spectral_in_coor(
|
||||||
|
deglint_img_path, csv_path, output_path,
|
||||||
|
radius=radius, flare_path=flare_path,
|
||||||
|
boundary_path=final_boundary_path, source_epsg=source_epsg
|
||||||
|
)
|
||||||
|
|
||||||
|
notify("completed", f"训练光谱数据已保存: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
# ---- Step 5.5: 计算水质光谱指数 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_water_quality_indices(
|
||||||
|
training_spectra_path: Optional[str] = None,
|
||||||
|
formula_csv_file: Optional[str] = None,
|
||||||
|
formula_names: Optional[List[str]] = None,
|
||||||
|
output_file: Optional[str] = None,
|
||||||
|
enabled: bool = True,
|
||||||
|
output_dir: Union[str, Path] = "./6_water_quality_indices",
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""根据训练光谱计算水质光谱指数(使用 band_math 方法)"""
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤5.5", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤5.5: 计算水质光谱指数(使用band_math方法)")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
print("已设置跳过水质指数计算(enabled=False)。")
|
||||||
|
notify("skipped", "跳过水质指数计算")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if training_spectra_path is None:
|
||||||
|
raise ValueError("必须提供 training_spectra_path 参数")
|
||||||
|
if formula_csv_file is None:
|
||||||
|
raise ValueError("必须提供 formula_csv_file 参数")
|
||||||
|
|
||||||
|
if output_file:
|
||||||
|
output_path = str(Path(output_file))
|
||||||
|
else:
|
||||||
|
output_path = str(output_dir / "water_quality_indices.csv")
|
||||||
|
|
||||||
|
if Path(output_path).exists():
|
||||||
|
print(f"检测到已存在的水质指数文件,直接使用: {output_path}")
|
||||||
|
notify("skipped", f"水质指数数据已设置: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
from src.utils.band_math import BandMathCalculator
|
||||||
|
|
||||||
|
calculator = BandMathCalculator(training_spectra_path)
|
||||||
|
result_df = calculator.process_formulas_from_csv(
|
||||||
|
formula_csv_file=formula_csv_file,
|
||||||
|
formula_names=formula_names,
|
||||||
|
output_file=output_path
|
||||||
|
)
|
||||||
|
|
||||||
|
if result_df is None:
|
||||||
|
raise ValueError("计算水质指数失败,请检查公式CSV文件格式")
|
||||||
|
|
||||||
|
notify("completed", f"水质指数已保存: {output_path}")
|
||||||
|
return output_path
|
||||||
113
src/core/steps/glint_detection_step.py
Normal file
113
src/core/steps/glint_detection_step.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
步骤2: 耀斑区域检测
|
||||||
|
|
||||||
|
支持多种检测方法: otsu, zscore, percentile, iqr, adaptive, multi_band
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Union
|
||||||
|
|
||||||
|
|
||||||
|
class GlintDetectionStep:
|
||||||
|
"""耀斑区域检测步骤"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run(
|
||||||
|
img_path: str,
|
||||||
|
glint_wave: float = 750.0,
|
||||||
|
method: str = "otsu",
|
||||||
|
z_threshold: float = 2.5,
|
||||||
|
percentile: float = 95.0,
|
||||||
|
iqr_multiplier: float = 1.5,
|
||||||
|
window_size: int = 15,
|
||||||
|
multi_band_waves: Optional[List[float]] = None,
|
||||||
|
sub_method: str = "zscore",
|
||||||
|
weights: Optional[List[float]] = None,
|
||||||
|
max_area: Optional[int] = None,
|
||||||
|
buffer_size: Optional[int] = None,
|
||||||
|
water_mask_path: Optional[str] = None,
|
||||||
|
glint_dir: Union[str, Path] = "./2_glint",
|
||||||
|
callback: Optional[callable] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
执行耀斑区域检测
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_path: 输入影像文件路径
|
||||||
|
glint_wave: 用于耀斑检测的波段波长(nm)
|
||||||
|
method: 检测方法 ('otsu' | 'zscore' | 'percentile' | 'iqr' | 'adaptive' | 'multi_band')
|
||||||
|
z_threshold: Z-score 方法阈值(默认 2.5)
|
||||||
|
percentile: 百分位数阈值(默认 95.0)
|
||||||
|
iqr_multiplier: IQR 倍数(默认 1.5)
|
||||||
|
window_size: 自适应阈值窗口大小(默认 15)
|
||||||
|
multi_band_waves: 多波段方法的波长列表,如 [750, 800, 850]
|
||||||
|
sub_method: 多波段方法的子方法(默认 'zscore')
|
||||||
|
weights: 多波段方法的权重列表(None 表示等权重)
|
||||||
|
max_area: 最大连通域面积阈值(像素),超过则过滤
|
||||||
|
buffer_size: 岸边缓冲区大小(像素),用于去除岸边附近错误掩膜
|
||||||
|
water_mask_path: 水域掩膜文件路径(dat 格式优先)
|
||||||
|
glint_dir: 工作目录
|
||||||
|
callback: 回调函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
耀斑掩膜文件路径 (.dat)
|
||||||
|
"""
|
||||||
|
from src.utils.find_severe_glint_area import find_severe_glint_area
|
||||||
|
|
||||||
|
glint_dir = Path(glint_dir)
|
||||||
|
glint_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤2", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤2: 找到耀斑区域")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
# 确定水体掩膜路径
|
||||||
|
if water_mask_path is not None and Path(water_mask_path).exists():
|
||||||
|
final_water_mask_path = water_mask_path
|
||||||
|
else:
|
||||||
|
final_water_mask_path = None
|
||||||
|
|
||||||
|
output_path = str(glint_dir / "severe_glint_area.dat")
|
||||||
|
|
||||||
|
# 跳过已存在的文件
|
||||||
|
if Path(output_path).exists():
|
||||||
|
print(f"检测到已存在的耀斑掩膜文件,直接使用: {output_path}")
|
||||||
|
notify("skipped", f"耀斑掩膜已设置: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
# 构建检测参数字典
|
||||||
|
kwargs = {
|
||||||
|
"method": method,
|
||||||
|
"z_threshold": z_threshold,
|
||||||
|
"percentile": percentile,
|
||||||
|
"iqr_multiplier": iqr_multiplier,
|
||||||
|
"window_size": window_size,
|
||||||
|
}
|
||||||
|
if method == "multi_band":
|
||||||
|
if multi_band_waves is not None:
|
||||||
|
kwargs["multi_band_waves"] = multi_band_waves
|
||||||
|
if sub_method is not None:
|
||||||
|
kwargs["sub_method"] = sub_method
|
||||||
|
if weights is not None:
|
||||||
|
kwargs["weights"] = weights
|
||||||
|
if max_area is not None:
|
||||||
|
kwargs["max_area"] = max_area
|
||||||
|
if buffer_size is not None:
|
||||||
|
kwargs["buffer_size"] = buffer_size
|
||||||
|
|
||||||
|
glint_mask_path = find_severe_glint_area(
|
||||||
|
img_path, final_water_mask_path, glint_wave, output_path, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"耀斑掩膜已生成: {glint_mask_path}")
|
||||||
|
print(f"使用检测方法: {method}")
|
||||||
|
notify("completed", f"耀斑掩膜已生成: {glint_mask_path}")
|
||||||
|
return glint_mask_path
|
||||||
375
src/core/steps/glint_removal_step.py
Normal file
375
src/core/steps/glint_removal_step.py
Normal file
@ -0,0 +1,375 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
步骤3: 去除耀斑
|
||||||
|
|
||||||
|
支持多种方法: subtract_nir, regression_slope, oxygen_absorption, kutser, goodman, hedley, sugar
|
||||||
|
|
||||||
|
每种方法都会:
|
||||||
|
1. 准备水域掩膜(支持 shp 自动转 dat)
|
||||||
|
2. 调用对应的算法类执行处理
|
||||||
|
3. 复制 hdr 文件到输出影像
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Union, Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_rename(src_bsq: str, src_hdr: str, dest_bsq: str, dest_hdr: str) -> str:
|
||||||
|
"""将底层硬编码生成的 .bsq + .hdr 文件对重命名到用户指定的 output_path
|
||||||
|
|
||||||
|
使用 os.remove + os.rename 确保原子覆盖(不等 os.replace 的跨设备行为),
|
||||||
|
resolve() 断路防止同路径 self-rename 报错。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dest_bsq 路径
|
||||||
|
"""
|
||||||
|
src_bsq_p = Path(src_bsq)
|
||||||
|
src_hdr_p = Path(src_hdr)
|
||||||
|
dest_bsq_p = Path(dest_bsq)
|
||||||
|
dest_hdr_p = Path(dest_hdr)
|
||||||
|
|
||||||
|
if str(src_bsq_p.resolve()) == str(dest_bsq_p.resolve()):
|
||||||
|
return dest_bsq
|
||||||
|
|
||||||
|
if dest_bsq_p.exists():
|
||||||
|
os.remove(dest_bsq_p)
|
||||||
|
if dest_hdr_p.exists():
|
||||||
|
os.remove(dest_hdr_p)
|
||||||
|
|
||||||
|
if src_bsq_p.exists():
|
||||||
|
os.rename(src_bsq_p, dest_bsq_p)
|
||||||
|
if src_hdr_p.exists():
|
||||||
|
os.rename(src_hdr_p, dest_hdr_p)
|
||||||
|
|
||||||
|
return dest_bsq
|
||||||
|
|
||||||
|
|
||||||
|
class GlintRemovalStep:
|
||||||
|
"""去除耀斑步骤"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run(
|
||||||
|
img_path: str,
|
||||||
|
method: str = "subtract_nir",
|
||||||
|
start_wave: Optional[float] = None,
|
||||||
|
end_wave: Optional[float] = None,
|
||||||
|
json_path: Optional[str] = None,
|
||||||
|
left_shoulder_wave: Optional[float] = None,
|
||||||
|
valley_wave: Optional[float] = None,
|
||||||
|
right_shoulder_wave: Optional[float] = None,
|
||||||
|
water_mask: Optional[Union[str, np.ndarray]] = None,
|
||||||
|
interpolated_img_path: Optional[str] = None,
|
||||||
|
interpolate_zeros: bool = False,
|
||||||
|
interpolation_method: str = "nearest",
|
||||||
|
enabled: bool = True,
|
||||||
|
# Kutser 参数
|
||||||
|
kutser_shp_path: Optional[str] = None,
|
||||||
|
oxy_band: int = 38,
|
||||||
|
lower_oxy: int = 36,
|
||||||
|
upper_oxy: int = 49,
|
||||||
|
nir_band: int = 47,
|
||||||
|
# Goodman 参数
|
||||||
|
nir_lower: int = 25,
|
||||||
|
nir_upper: int = 37,
|
||||||
|
goodman_A: float = 0.000019,
|
||||||
|
goodman_B: float = 0.1,
|
||||||
|
# Hedley 参数
|
||||||
|
hedley_shp_path: Optional[str] = None,
|
||||||
|
hedley_nir_band: int = 47,
|
||||||
|
# SUGAR 参数
|
||||||
|
sugar_bounds: Optional[List[tuple]] = None,
|
||||||
|
sugar_sigma: float = 1.0,
|
||||||
|
sugar_estimate_background: bool = True,
|
||||||
|
sugar_glint_mask_method: str = "cdf",
|
||||||
|
sugar_iter: Optional[int] = 3,
|
||||||
|
sugar_termination_thresh: float = 20.0,
|
||||||
|
# 内部工具函数
|
||||||
|
_get_image_geo_info=None,
|
||||||
|
_load_image_as_array=None,
|
||||||
|
_save_bands_as_image=None,
|
||||||
|
_copy_hdr_info=None,
|
||||||
|
_prepare_water_mask_for_algorithm=None,
|
||||||
|
_interpolate_zero_pixels_batch=None,
|
||||||
|
deglint_dir: Union[str, Path] = "./3_deglint",
|
||||||
|
water_mask_dir: Union[str, Path] = "./1_water_mask",
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
执行去除耀斑处理
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_path: 输入影像文件路径
|
||||||
|
method: 去耀斑方法
|
||||||
|
...(其余参数同主类 step3_remove_glint)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
去除耀斑后的影像文件路径
|
||||||
|
"""
|
||||||
|
from src.core.glint_removal.Kutser import Kutser
|
||||||
|
from src.core.glint_removal.Goodman import Goodman
|
||||||
|
from src.core.glint_removal.Hedley import Hedley
|
||||||
|
from src.core.glint_removal.SUGAR import SUGAR, correction_iterative
|
||||||
|
from src.core.utils.gdal_helper import (
|
||||||
|
get_image_geo_info as _default_get_geo,
|
||||||
|
load_image_as_array as _default_load,
|
||||||
|
save_bands_as_image as _default_save_bands,
|
||||||
|
copy_hdr_info as _default_copy_hdr,
|
||||||
|
)
|
||||||
|
from src.core.utils.mask_converter import (
|
||||||
|
prepare_water_mask_for_algorithm as _default_prepare,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用提供的函数或默认函数
|
||||||
|
if _get_image_geo_info is None:
|
||||||
|
_get_image_geo_info = _default_get_geo
|
||||||
|
if _load_image_as_array is None:
|
||||||
|
_load_image_as_array = _default_load
|
||||||
|
if _save_bands_as_image is None:
|
||||||
|
_save_bands_as_image = _default_save_bands
|
||||||
|
if _copy_hdr_info is None:
|
||||||
|
_copy_hdr_info = _default_copy_hdr
|
||||||
|
if _prepare_water_mask_for_algorithm is None:
|
||||||
|
_prepare_water_mask_for_algorithm = _default_prepare
|
||||||
|
|
||||||
|
deglint_dir = Path(deglint_dir)
|
||||||
|
deglint_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤3", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤3: 去除耀斑")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
# 方法名标准化
|
||||||
|
raw_method = str(method).lower()
|
||||||
|
if "kutser" in raw_method:
|
||||||
|
method = "kutser"
|
||||||
|
elif "goodman" in raw_method:
|
||||||
|
method = "goodman"
|
||||||
|
elif "hedley" in raw_method:
|
||||||
|
method = "hedley"
|
||||||
|
elif "sugar" in raw_method:
|
||||||
|
method = "sugar"
|
||||||
|
|
||||||
|
# 如果未启用,直接返回原始影像
|
||||||
|
if not enabled:
|
||||||
|
print("已设置跳过去除耀斑(enabled=False),将直接使用原始影像。")
|
||||||
|
notify("skipped", "跳过去耀斑,使用原始影像")
|
||||||
|
return img_path
|
||||||
|
|
||||||
|
# ---- 确定水域掩膜 ----
|
||||||
|
final_water_mask = water_mask
|
||||||
|
if final_water_mask is not None and str(final_water_mask).lower().endswith(".shp"):
|
||||||
|
# shp 自动替换为 dat
|
||||||
|
dat_mask = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
|
||||||
|
if Path(dat_mask).exists():
|
||||||
|
print(f"检测到输入掩膜为 .shp,自动替换为栅格掩膜: {dat_mask}")
|
||||||
|
final_water_mask = dat_mask
|
||||||
|
|
||||||
|
if final_water_mask is None:
|
||||||
|
dat_mask_default = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
|
||||||
|
if Path(dat_mask_default).exists():
|
||||||
|
final_water_mask = dat_mask_default
|
||||||
|
print(f"使用步骤1生成的水域掩膜: {final_water_mask}")
|
||||||
|
|
||||||
|
# ---- 步骤3.1: 0值像素插值 ----
|
||||||
|
if interpolate_zeros:
|
||||||
|
print("\n" + "-" * 80)
|
||||||
|
print("步骤3.1: 对0值像素进行插值")
|
||||||
|
print("-" * 80)
|
||||||
|
interp_start_time = time.time()
|
||||||
|
|
||||||
|
if _interpolate_zero_pixels_batch is None:
|
||||||
|
from src.core.algorithms.interpolation.interpolator import (
|
||||||
|
interpolate_zero_pixels_batch as _interp_batch,
|
||||||
|
)
|
||||||
|
_interpolate_zero_pixels_batch = _interp_batch
|
||||||
|
|
||||||
|
interp_result, _ = _interpolate_zero_pixels_batch(
|
||||||
|
img_path=img_path,
|
||||||
|
interpolation_method=interpolation_method,
|
||||||
|
output_path=None,
|
||||||
|
water_mask=final_water_mask,
|
||||||
|
deglint_dir=str(deglint_dir),
|
||||||
|
callback_progress=lambda msg: print(f" {msg}"),
|
||||||
|
)
|
||||||
|
img_path = interp_result
|
||||||
|
interp_end_time = time.time()
|
||||||
|
print(f"插值完成,使用插值后的影像: {img_path}")
|
||||||
|
|
||||||
|
# ---- 获取影像信息 ----
|
||||||
|
geotransform, projection, width, height, n_bands = _get_image_geo_info(img_path)
|
||||||
|
print(f"影像尺寸: {width} x {height} x {n_bands}")
|
||||||
|
|
||||||
|
mask_for_algorithm = _prepare_water_mask_for_algorithm(
|
||||||
|
final_water_mask, (height, width), geotransform, projection, img_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== Kutser ====================
|
||||||
|
if method == "kutser":
|
||||||
|
print(f"使用方法: Kutser (氧吸收波段={oxy_band}, NIR波段={nir_band})")
|
||||||
|
hardcoded_bsq = str(deglint_dir / "deglint_kutser.bsq")
|
||||||
|
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||||
|
# 将用户指定的 output_path 标准化为 .bsq 路径
|
||||||
|
if output_path:
|
||||||
|
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||||
|
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||||
|
else:
|
||||||
|
final_bsq = hardcoded_bsq
|
||||||
|
final_hdr = hardcoded_hdr
|
||||||
|
|
||||||
|
if Path(hardcoded_bsq).exists():
|
||||||
|
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||||
|
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||||
|
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||||
|
|
||||||
|
kutser = Kutser(
|
||||||
|
img_path,
|
||||||
|
shp_path=None,
|
||||||
|
oxy_band=oxy_band,
|
||||||
|
lower_oxy=lower_oxy,
|
||||||
|
upper_oxy=upper_oxy,
|
||||||
|
NIR_band=nir_band,
|
||||||
|
water_mask=mask_for_algorithm,
|
||||||
|
output_path=hardcoded_bsq,
|
||||||
|
)
|
||||||
|
kutser.get_corrected_bands()
|
||||||
|
|
||||||
|
if Path(hardcoded_bsq).exists():
|
||||||
|
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||||
|
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||||
|
notify("completed", f"去耀斑影像已生成: {final}")
|
||||||
|
return final
|
||||||
|
raise RuntimeError(f"Kutser算法未生成输出文件: {hardcoded_bsq}")
|
||||||
|
|
||||||
|
# ==================== Goodman ====================
|
||||||
|
elif method == "goodman":
|
||||||
|
print(f"使用方法: Goodman (NIR波段范围: {nir_lower}-{nir_upper})")
|
||||||
|
hardcoded_bsq = str(deglint_dir / "deglint_goodman.bsq")
|
||||||
|
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||||
|
if output_path:
|
||||||
|
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||||
|
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||||
|
else:
|
||||||
|
final_bsq = hardcoded_bsq
|
||||||
|
final_hdr = hardcoded_hdr
|
||||||
|
|
||||||
|
if Path(hardcoded_bsq).exists():
|
||||||
|
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||||
|
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||||
|
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||||
|
|
||||||
|
goodman = Goodman(
|
||||||
|
img_path,
|
||||||
|
NIR_lower=nir_lower,
|
||||||
|
NIR_upper=nir_upper,
|
||||||
|
A=goodman_A,
|
||||||
|
B=goodman_B,
|
||||||
|
water_mask=mask_for_algorithm,
|
||||||
|
output_path=hardcoded_bsq,
|
||||||
|
)
|
||||||
|
corrected_bands = goodman.get_corrected_bands()
|
||||||
|
|
||||||
|
if not Path(hardcoded_bsq).exists():
|
||||||
|
_save_bands_as_image(corrected_bands, hardcoded_bsq, geotransform, projection)
|
||||||
|
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||||
|
del corrected_bands
|
||||||
|
|
||||||
|
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||||
|
notify("completed", f"去耀斑影像已生成: {final}")
|
||||||
|
return final
|
||||||
|
|
||||||
|
# ==================== Hedley ====================
|
||||||
|
elif method == "hedley":
|
||||||
|
print(f"使用方法: Hedley (NIR波段={hedley_nir_band})")
|
||||||
|
hardcoded_bsq = str(deglint_dir / "deglint_hedley.bsq")
|
||||||
|
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||||
|
if output_path:
|
||||||
|
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||||
|
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||||
|
else:
|
||||||
|
final_bsq = hardcoded_bsq
|
||||||
|
final_hdr = hardcoded_hdr
|
||||||
|
|
||||||
|
if Path(hardcoded_bsq).exists():
|
||||||
|
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||||
|
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||||
|
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||||
|
|
||||||
|
hedley = Hedley(
|
||||||
|
img_path,
|
||||||
|
shp_path=None,
|
||||||
|
NIR_band=hedley_nir_band,
|
||||||
|
water_mask=mask_for_algorithm,
|
||||||
|
output_path=hardcoded_bsq,
|
||||||
|
)
|
||||||
|
hedley.get_corrected_bands()
|
||||||
|
|
||||||
|
if Path(hardcoded_bsq).exists():
|
||||||
|
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||||
|
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||||
|
notify("completed", f"去耀斑影像已生成: {final}")
|
||||||
|
return final
|
||||||
|
raise RuntimeError(f"Hedley算法未生成输出文件: {hardcoded_bsq}")
|
||||||
|
|
||||||
|
# ==================== SUGAR ====================
|
||||||
|
elif method == "sugar":
|
||||||
|
glint_method_raw = str(sugar_glint_mask_method).lower()
|
||||||
|
if "cdf" in glint_method_raw or "累积" in glint_method_raw:
|
||||||
|
sugar_glint_mask_method_fixed = "cdf"
|
||||||
|
elif "otsu" in glint_method_raw or "大津" in glint_method_raw:
|
||||||
|
sugar_glint_mask_method_fixed = "otsu"
|
||||||
|
else:
|
||||||
|
sugar_glint_mask_method_fixed = "cdf"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"使用方法: SUGAR (迭代次数={sugar_iter}, 掩膜方法={sugar_glint_mask_method_fixed})"
|
||||||
|
)
|
||||||
|
hardcoded_bsq = str(deglint_dir / "deglint_sugar.bsq")
|
||||||
|
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||||
|
if output_path:
|
||||||
|
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||||
|
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||||
|
else:
|
||||||
|
final_bsq = hardcoded_bsq
|
||||||
|
final_hdr = hardcoded_hdr
|
||||||
|
|
||||||
|
if Path(hardcoded_bsq).exists():
|
||||||
|
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||||
|
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||||
|
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||||
|
|
||||||
|
if sugar_bounds is None:
|
||||||
|
sugar_bounds = [(1, 2)]
|
||||||
|
|
||||||
|
correction_iterative(
|
||||||
|
img_path,
|
||||||
|
iter=sugar_iter,
|
||||||
|
bounds=sugar_bounds,
|
||||||
|
estimate_background=sugar_estimate_background,
|
||||||
|
glint_mask_method=sugar_glint_mask_method_fixed,
|
||||||
|
termination_thresh=sugar_termination_thresh,
|
||||||
|
water_mask=mask_for_algorithm,
|
||||||
|
output_path=hardcoded_bsq,
|
||||||
|
)
|
||||||
|
|
||||||
|
if Path(hardcoded_bsq).exists():
|
||||||
|
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||||
|
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||||
|
notify("completed", f"去耀斑影像已生成: {final}")
|
||||||
|
return final
|
||||||
|
raise RuntimeError(f"SUGAR算法未生成输出文件: {hardcoded_bsq}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"不支持的方法: {method}。支持的方法: kutser, goodman, hedley, sugar"
|
||||||
|
)
|
||||||
109
src/core/steps/mapping_step.py
Normal file
109
src/core/steps/mapping_step.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
成图步骤
|
||||||
|
|
||||||
|
包含 step9_generate_distribution_map
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union, Callable
|
||||||
|
|
||||||
|
|
||||||
|
class MappingStep:
|
||||||
|
"""成图步骤"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_distribution_map(
|
||||||
|
prediction_csv_path: str,
|
||||||
|
boundary_shp_path: str,
|
||||||
|
output_image_path: Optional[str] = None,
|
||||||
|
resolution: float = 30,
|
||||||
|
input_crs: str = "EPSG:32651",
|
||||||
|
output_crs: str = "EPSG:4326",
|
||||||
|
show_sample_points: bool = False,
|
||||||
|
base_map_tif: Optional[str] = None,
|
||||||
|
use_distance_diffusion: bool = True,
|
||||||
|
max_diffusion_distance: Optional[float] = None,
|
||||||
|
diffusion_power: float = 2,
|
||||||
|
diffusion_n_neighbors: int = 15,
|
||||||
|
cmap: Optional[str] = None,
|
||||||
|
expand_ratio: float = 0.05,
|
||||||
|
output_dir: Union[str, Path] = "./14_visualization",
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
根据采样点的坐标和反演的实测参数,通过插值方法得到水质参数可视化分布图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prediction_csv_path: 预测结果CSV文件路径(前两列为经纬度,第三列为预测值)
|
||||||
|
boundary_shp_path: 边界shapefile文件路径
|
||||||
|
output_image_path: 输出图片路径(如果为None,自动生成)
|
||||||
|
resolution: 插值网格分辨率(米)
|
||||||
|
input_crs: 输入坐标系
|
||||||
|
output_crs: 输出坐标系
|
||||||
|
show_sample_points: 是否在图上显示采样点
|
||||||
|
base_map_tif: 底图TIF路径
|
||||||
|
use_distance_diffusion: 是否启用距离扩散补全边界
|
||||||
|
max_diffusion_distance: 距离扩散最大距离(米)
|
||||||
|
diffusion_power: 距离扩散幂参数
|
||||||
|
diffusion_n_neighbors: 距离扩散最近邻数量
|
||||||
|
cmap: 颜色映射名称(None表示自动识别)
|
||||||
|
expand_ratio: 边界外扩比例(0-1之间)
|
||||||
|
output_dir: 输出目录
|
||||||
|
callback: 回调函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
可视化分布图文件路径
|
||||||
|
"""
|
||||||
|
from src.postprocessing.map import ContentMapper
|
||||||
|
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤9", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤9: 生成水质参数可视化分布图")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if output_image_path is None:
|
||||||
|
csv_name = Path(prediction_csv_path).stem
|
||||||
|
output_image_path = str(output_dir / f"{csv_name}_distribution.png")
|
||||||
|
|
||||||
|
if Path(output_image_path).exists():
|
||||||
|
print(f"检测到已存在的分布图文件,直接使用: {output_image_path}")
|
||||||
|
notify("skipped", f"可视化分布图已设置: {output_image_path}")
|
||||||
|
return output_image_path
|
||||||
|
|
||||||
|
mapper = ContentMapper(input_crs=input_crs, output_crs=output_crs)
|
||||||
|
|
||||||
|
mapper_kwargs = {
|
||||||
|
"resolution": resolution,
|
||||||
|
"show_sample_points": show_sample_points,
|
||||||
|
"use_distance_diffusion": use_distance_diffusion,
|
||||||
|
"diffusion_power": diffusion_power,
|
||||||
|
"diffusion_n_neighbors": diffusion_n_neighbors,
|
||||||
|
"expand_ratio": expand_ratio,
|
||||||
|
}
|
||||||
|
|
||||||
|
optional_kwargs = {
|
||||||
|
"base_map_tif": base_map_tif,
|
||||||
|
"max_diffusion_distance": max_diffusion_distance,
|
||||||
|
"cmap": cmap,
|
||||||
|
}
|
||||||
|
mapper_kwargs.update({k: v for k, v in optional_kwargs.items() if v is not None})
|
||||||
|
|
||||||
|
mapper.process_data(
|
||||||
|
csv_file=prediction_csv_path,
|
||||||
|
shp_file=boundary_shp_path,
|
||||||
|
output_file=output_image_path,
|
||||||
|
**mapper_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
notify("completed", f"可视化分布图已保存: {output_image_path}")
|
||||||
|
return output_image_path
|
||||||
497
src/core/steps/modeling_step.py
Normal file
497
src/core/steps/modeling_step.py
Normal file
@ -0,0 +1,497 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
建模步骤
|
||||||
|
|
||||||
|
包含 step6_train_models, step6_5_non_empirical_modeling, step6_75_custom_regression
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Union, Callable, Dict
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 汉化 -> 英文 反向映射字典(UI 复选框显示文本 -> 底层算法键名)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# 模型名称:中文 (缩写) -> 英文键名
|
||||||
|
MODEL_NAME_MAP = {
|
||||||
|
"多元线性回归 (MLR)": "LinearRegression",
|
||||||
|
"岭回归 (Ridge)": "Ridge",
|
||||||
|
"套索回归 (Lasso)": "Lasso",
|
||||||
|
"弹性网络 (ElasticNet)": "ElasticNet",
|
||||||
|
"偏最小二乘 (PLSR)": "PLS",
|
||||||
|
"决策树 (CART)": "DecisionTree",
|
||||||
|
"随机森林 (RF)": "RF",
|
||||||
|
"极端随机树 (ET)": "ExtraTrees",
|
||||||
|
"极值梯度提升 (XGBoost)": "XGBoost",
|
||||||
|
"轻量梯度提升 (LightGBM)": "LightGBM",
|
||||||
|
"类别梯度提升 (CatBoost)": "CatBoost",
|
||||||
|
"梯度提升树 (GBDT)": "GradientBoosting",
|
||||||
|
"自适应提升 (AdaBoost)": "AdaBoost",
|
||||||
|
"支持向量回归 (SVR)": "SVR",
|
||||||
|
"K近邻回归 (KNN)": "KNN",
|
||||||
|
"多层感知机 (BP神经网络)": "MLP",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 预处理方法:各种可能的中文变体 -> 标准键名
|
||||||
|
PREPROC_NAME_MAP = {
|
||||||
|
# 无处理
|
||||||
|
"无 (None)": "None",
|
||||||
|
"None": "None",
|
||||||
|
# MMS
|
||||||
|
"最小-最大归一化 (MMS)": "MMS",
|
||||||
|
"MMS": "MMS",
|
||||||
|
# SS
|
||||||
|
"标度化 (SS)": "SS",
|
||||||
|
"SS": "SS",
|
||||||
|
# SNV
|
||||||
|
"标准正态变换 (SNV)": "SNV",
|
||||||
|
"SNV": "SNV",
|
||||||
|
# MA
|
||||||
|
"移动平均 (MA)": "MA",
|
||||||
|
"MA": "MA",
|
||||||
|
# SG
|
||||||
|
"Savitzky-Golay (SG)": "SG",
|
||||||
|
"SG": "SG",
|
||||||
|
# MSC
|
||||||
|
"多元散射校正 (MSC)": "MSC",
|
||||||
|
"MSC": "MSC",
|
||||||
|
# D1
|
||||||
|
"一阶导数 (D1)": "D1",
|
||||||
|
"D1": "D1",
|
||||||
|
# D2
|
||||||
|
"二阶导数 (D2)": "D2",
|
||||||
|
"D2": "D2",
|
||||||
|
# DT
|
||||||
|
"去趋势 (DT)": "DT",
|
||||||
|
"DT": "DT",
|
||||||
|
# CT
|
||||||
|
"中心化 (CT)": "CT",
|
||||||
|
"CT": "CT",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 数据划分方法:各种可能的中文变体 -> 标准键名
|
||||||
|
SPLIT_NAME_MAP = {
|
||||||
|
"SPXY 算法 (考量X-Y空间)": "spxy",
|
||||||
|
"spxy": "spxy",
|
||||||
|
"KS 算法 (考量X空间)": "ks",
|
||||||
|
"ks": "ks",
|
||||||
|
"随机划分 (Random)": "random",
|
||||||
|
"random": "random",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_model_names(model_names: List[str]) -> List[str]:
|
||||||
|
"""清洗模型名称列表:将汉化显示文本还原为英文键名"""
|
||||||
|
result = []
|
||||||
|
for name in model_names:
|
||||||
|
if name in MODEL_NAME_MAP:
|
||||||
|
result.append(MODEL_NAME_MAP[name])
|
||||||
|
else:
|
||||||
|
# 已经是英文键名,直接保留
|
||||||
|
result.append(name)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_preprocessing_methods(methods: List[str]) -> List[str]:
|
||||||
|
"""清洗预处理方法列表:将汉化显示文本还原为标准键名"""
|
||||||
|
result = []
|
||||||
|
for method in methods:
|
||||||
|
if method in PREPROC_NAME_MAP:
|
||||||
|
result.append(PREPROC_NAME_MAP[method])
|
||||||
|
else:
|
||||||
|
# 已经是标准键名,直接保留
|
||||||
|
result.append(method)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_split_methods(methods: List[str]) -> List[str]:
|
||||||
|
"""清洗数据划分方法列表:将汉化显示文本还原为标准键名"""
|
||||||
|
result = []
|
||||||
|
for method in methods:
|
||||||
|
if method in SPLIT_NAME_MAP:
|
||||||
|
result.append(SPLIT_NAME_MAP[method])
|
||||||
|
else:
|
||||||
|
# 已经是标准键名,直接保留
|
||||||
|
result.append(method)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ModelingStep:
|
||||||
|
"""建模步骤"""
|
||||||
|
|
||||||
|
# ---- Step 6: 训练机器学习模型 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def train_models(
|
||||||
|
feature_start_column: str = "374.285004",
|
||||||
|
preprocessing_methods: Optional[List[str]] = None,
|
||||||
|
model_names: Optional[List[str]] = None,
|
||||||
|
split_methods: Optional[List[str]] = None,
|
||||||
|
cv_folds: int = 5,
|
||||||
|
training_csv_path: Optional[str] = None,
|
||||||
|
output_dir: Union[str, Path] = "./7_Supervised_Model_Training",
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
_report_generator=None,
|
||||||
|
) -> str:
|
||||||
|
"""使用采样点光谱和实测值建立机器学习模型"""
|
||||||
|
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
|
||||||
|
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤6", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤6: 训练机器学习模型")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if training_csv_path is None:
|
||||||
|
raise ValueError("必须提供 training_csv_path 参数")
|
||||||
|
|
||||||
|
# 检查模型目录是否已有模型
|
||||||
|
if output_dir.exists() and any(output_dir.iterdir()):
|
||||||
|
has_models = False
|
||||||
|
for item in output_dir.iterdir():
|
||||||
|
if item.is_dir():
|
||||||
|
model_files = (
|
||||||
|
list(item.glob("*.pkl"))
|
||||||
|
+ list(item.glob("*.joblib"))
|
||||||
|
+ list(item.glob("*.h5"))
|
||||||
|
)
|
||||||
|
if model_files:
|
||||||
|
has_models = True
|
||||||
|
break
|
||||||
|
if has_models:
|
||||||
|
print(f"检测到已存在的模型文件,直接使用: {output_dir}")
|
||||||
|
notify("skipped", f"模型目录已设置: {output_dir}")
|
||||||
|
return str(output_dir)
|
||||||
|
|
||||||
|
if preprocessing_methods is None:
|
||||||
|
preprocessing_methods = ["None", "MMS", "SS", "SNV", "MA", "SG", "MSC", "D1", "D2", "DT", "CT"]
|
||||||
|
if model_names is None:
|
||||||
|
model_names = ["SVR", "RF", "Ridge", "Lasso"]
|
||||||
|
if split_methods is None:
|
||||||
|
split_methods = ["spxy", "ks", "random"]
|
||||||
|
|
||||||
|
# ---- 汉化清洗:将 UI 传来的中文/混合名称转换为底层英文键名 ----
|
||||||
|
preprocessing_methods = _normalize_preprocessing_methods(preprocessing_methods)
|
||||||
|
model_names = _normalize_model_names(model_names)
|
||||||
|
split_methods = _normalize_split_methods(split_methods)
|
||||||
|
|
||||||
|
print(f"[参数清洗] 预处理方法: {preprocessing_methods}")
|
||||||
|
print(f"[参数清洗] 模型名称: {model_names}")
|
||||||
|
print(f"[参数清洗] 划分方法: {split_methods}")
|
||||||
|
|
||||||
|
modeler = WaterQualityModelingBatch(str(output_dir))
|
||||||
|
modeler.train_models_batch(
|
||||||
|
csv_path=training_csv_path,
|
||||||
|
feature_start_column=feature_start_column,
|
||||||
|
preprocessing_methods=preprocessing_methods,
|
||||||
|
model_names=model_names,
|
||||||
|
split_methods=split_methods,
|
||||||
|
cv_folds=cv_folds,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"模型训练完成,结果保存在: {output_dir}")
|
||||||
|
|
||||||
|
if _report_generator is not None:
|
||||||
|
try:
|
||||||
|
summary_path = _report_generator.generate_training_summary(str(output_dir))
|
||||||
|
print(f"训练摘要报告已生成: {summary_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"生成训练摘要报告时出错: {e}")
|
||||||
|
|
||||||
|
notify("completed", f"模型训练完成: {output_dir}")
|
||||||
|
return str(output_dir)
|
||||||
|
|
||||||
|
# ---- Step 6.5: 非经验统计回归模型训练 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def train_non_empirical_models(
|
||||||
|
csv_path: Optional[str] = None,
|
||||||
|
preprocessing_methods: Optional[List[str]] = None,
|
||||||
|
algorithms: Optional[List[str]] = None,
|
||||||
|
value_cols: Union[int, Dict[str, int]] = 0,
|
||||||
|
spectral_start_col: int = 1,
|
||||||
|
spectral_end_col: Optional[int] = None,
|
||||||
|
window: int = 5,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
enabled: bool = True,
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""非经验统计回归模型训练"""
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤6.5", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤6.5: 非经验统计回归模型训练")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
print("已设置跳过非经验模型训练(enabled=False)。")
|
||||||
|
notify("skipped", "跳过的经验模型训练")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if csv_path is None:
|
||||||
|
raise ValueError("必须提供 csv_path 参数")
|
||||||
|
|
||||||
|
if output_dir is not None:
|
||||||
|
non_empirical_dir = Path(output_dir)
|
||||||
|
else:
|
||||||
|
non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
|
||||||
|
non_empirical_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if preprocessing_methods is None:
|
||||||
|
preprocessing_methods = ["None"]
|
||||||
|
if algorithms is None:
|
||||||
|
algorithms = ["chl_a", "nh3", "mno4", "tn", "tp", "tss"]
|
||||||
|
|
||||||
|
if isinstance(value_cols, int):
|
||||||
|
value_cols_dict = {algorithm: value_cols for algorithm in algorithms}
|
||||||
|
elif isinstance(value_cols, dict):
|
||||||
|
value_cols_dict = value_cols
|
||||||
|
else:
|
||||||
|
raise ValueError("value_cols 参数必须是整数或字典")
|
||||||
|
|
||||||
|
if spectral_end_col is None:
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
spectral_end_col = len(df.columns) - 1
|
||||||
|
|
||||||
|
all_model_results = {}
|
||||||
|
|
||||||
|
for preprocess in preprocessing_methods:
|
||||||
|
preprocess_dir = non_empirical_dir / preprocess
|
||||||
|
preprocess_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
processed_csv_path = _apply_preprocessing_internal(
|
||||||
|
csv_path, preprocess, preprocess_dir, spectral_start_col
|
||||||
|
)
|
||||||
|
|
||||||
|
for algorithm in algorithms:
|
||||||
|
algorithm_value_col = value_cols_dict[algorithm]
|
||||||
|
print(f"\n训练 {preprocess} + {algorithm} 模型 (实测值列: {algorithm_value_col})...")
|
||||||
|
|
||||||
|
model_outpath = str(preprocess_dir / f"{preprocess}_{algorithm}.json")
|
||||||
|
|
||||||
|
if Path(model_outpath).exists():
|
||||||
|
print(f"检测到已存在的模型文件,直接使用: {model_outpath}")
|
||||||
|
all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.core.non_empirical_model_correction import run_model_correction
|
||||||
|
run_model_correction(
|
||||||
|
algorithm=algorithm,
|
||||||
|
csv_file=processed_csv_path if Path(processed_csv_path).exists() else csv_path,
|
||||||
|
value_col=algorithm_value_col,
|
||||||
|
spectral_start=spectral_start_col,
|
||||||
|
spectral_end=spectral_end_col,
|
||||||
|
model_info_outpath=model_outpath,
|
||||||
|
window=window,
|
||||||
|
)
|
||||||
|
all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
|
||||||
|
print(f"模型训练完成: {model_outpath}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"训练 {preprocess}_{algorithm} 模型时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
summary_path = _generate_non_empirical_summary(all_model_results, non_empirical_dir)
|
||||||
|
notify("completed", f"非经验模型训练完成: {non_empirical_dir}")
|
||||||
|
return all_model_results
|
||||||
|
|
||||||
|
# ---- Step 6.75: 自定义回归分析 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def custom_regression(
|
||||||
|
csv_path: Optional[str] = None,
|
||||||
|
x_columns: Optional[Union[str, List[str]]] = None,
|
||||||
|
y_columns: Optional[Union[str, List[str]]] = None,
|
||||||
|
methods: Union[str, List[str]] = "all",
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
enabled: bool = True,
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
work_dir: Union[str, Path] = "./work_dir",
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""使用自定义回归方法分析指标与目标参数之间的关系"""
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤6.75", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤6.75: 自定义回归分析")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
print("已设置跳过自定义回归分析(enabled=False)。")
|
||||||
|
notify("skipped", "跳过自定义回归分析")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if csv_path is None:
|
||||||
|
raise ValueError("必须提供 csv_path 参数")
|
||||||
|
if y_columns is None:
|
||||||
|
raise ValueError("必须指定 y_columns")
|
||||||
|
if x_columns is None:
|
||||||
|
raise ValueError("必须指定 x_columns")
|
||||||
|
|
||||||
|
if isinstance(x_columns, str):
|
||||||
|
x_columns = [x_columns]
|
||||||
|
if isinstance(y_columns, str):
|
||||||
|
y_columns = [y_columns]
|
||||||
|
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
missing_x = [col for col in x_columns if col not in df.columns]
|
||||||
|
missing_y = [col for col in y_columns if col not in df.columns]
|
||||||
|
if missing_x:
|
||||||
|
raise ValueError(f"自变量列不存在: {missing_x}")
|
||||||
|
if missing_y:
|
||||||
|
raise ValueError(f"因变量列不存在: {missing_y}")
|
||||||
|
|
||||||
|
if output_dir is None:
|
||||||
|
custom_regression_dir = Path(work_dir) / "9_Custom_Regression_Modeling"
|
||||||
|
else:
|
||||||
|
custom_regression_dir = Path(work_dir) / output_dir
|
||||||
|
custom_regression_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
from src.core.modeling.regression import SingleVariableRegressionAnalysis
|
||||||
|
analyzer = SingleVariableRegressionAnalysis()
|
||||||
|
analyzer.batch_single_variable_regression(
|
||||||
|
data=df,
|
||||||
|
x_columns=x_columns,
|
||||||
|
y_columns=y_columns,
|
||||||
|
methods=methods,
|
||||||
|
output_dir=str(custom_regression_dir),
|
||||||
|
)
|
||||||
|
|
||||||
|
notify("completed", f"自定义回归结果已保存到目录: {custom_regression_dir}")
|
||||||
|
return str(custom_regression_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 内部辅助函数(供 ModelingStep 内部使用)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def _apply_preprocessing_internal(
|
||||||
|
csv_path: str,
|
||||||
|
preprocess_method: str,
|
||||||
|
output_dir: Path,
|
||||||
|
spectral_start_col: int = 4,
|
||||||
|
) -> str:
|
||||||
|
"""应用预处理到CSV数据(内部函数)"""
|
||||||
|
raw_p = str(preprocess_method).lower()
|
||||||
|
if raw_p == "none" or "无" in raw_p or "跳过" in raw_p:
|
||||||
|
preprocess_method = "None"
|
||||||
|
elif raw_p == "mms" or "minmax" in raw_p or "最大最小" in raw_p:
|
||||||
|
preprocess_method = "MMS"
|
||||||
|
elif raw_p == "ss" or "标准" in raw_p or "标准化" in raw_p:
|
||||||
|
preprocess_method = "SS"
|
||||||
|
elif raw_p == "snv" or "标准正态" in raw_p:
|
||||||
|
preprocess_method = "SNV"
|
||||||
|
elif raw_p == "ma" or "移动" in raw_p:
|
||||||
|
preprocess_method = "MA"
|
||||||
|
elif raw_p == "sg" or "savitzky" in raw_p or "平滑" in raw_p:
|
||||||
|
preprocess_method = "SG"
|
||||||
|
elif raw_p == "msc" or "多元散射" in raw_p:
|
||||||
|
preprocess_method = "MSC"
|
||||||
|
elif raw_p in ("d1", "d2", "dt"):
|
||||||
|
preprocess_method = {"d1": "D1", "d2": "D2", "dt": "DT"}.get(raw_p, raw_p.upper())
|
||||||
|
elif raw_p == "ct" or "去趋势" in raw_p:
|
||||||
|
preprocess_method = "CT"
|
||||||
|
|
||||||
|
if preprocess_method == "None":
|
||||||
|
return csv_path
|
||||||
|
|
||||||
|
output_filename = f"preprocessed_{preprocess_method}.csv"
|
||||||
|
output_path = str(output_dir / output_filename)
|
||||||
|
|
||||||
|
if Path(output_path).exists():
|
||||||
|
print(f"检测到已存在的预处理文件,直接使用: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
non_spectral_cols = df.iloc[:, :spectral_start_col]
|
||||||
|
spectral_data = df.iloc[:, spectral_start_col:]
|
||||||
|
|
||||||
|
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||||
|
|
||||||
|
save_path = None
|
||||||
|
if preprocess_method == "SS":
|
||||||
|
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training"
|
||||||
|
models_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_path = str(models_dir / "scaler_params.pkl")
|
||||||
|
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
||||||
|
|
||||||
|
processed_spectral = Preprocessing(preprocess_method, spectral_data, save_path=save_path)
|
||||||
|
|
||||||
|
if isinstance(processed_spectral, pd.DataFrame):
|
||||||
|
processed_df = pd.concat([non_spectral_cols, processed_spectral], axis=1)
|
||||||
|
else:
|
||||||
|
processed_spectral_df = pd.DataFrame(
|
||||||
|
processed_spectral, columns=spectral_data.columns, index=spectral_data.index
|
||||||
|
)
|
||||||
|
processed_df = pd.concat([non_spectral_cols, processed_spectral_df], axis=1)
|
||||||
|
|
||||||
|
processed_df.to_csv(output_path, index=False)
|
||||||
|
print(f"预处理完成: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_non_empirical_summary(model_results: Dict[str, str], output_dir: Path) -> str:
|
||||||
|
"""生成非经验模型训练结果汇总CSV"""
|
||||||
|
summary_path = str(output_dir / "non_empirical_models_summary.csv")
|
||||||
|
summary_data = []
|
||||||
|
|
||||||
|
for model_key, model_path in model_results.items():
|
||||||
|
try:
|
||||||
|
parts = model_key.split("_")
|
||||||
|
preprocess_method = parts[0]
|
||||||
|
algorithm_name = "_".join(parts[1:]) if len(parts) > 2 else parts[1]
|
||||||
|
|
||||||
|
with open(model_path, "r", encoding="utf-8") as f:
|
||||||
|
model_info = json.load(f)
|
||||||
|
|
||||||
|
accuracy_list = model_info.get("accuracy", [])
|
||||||
|
summary_row = {
|
||||||
|
"Preprocessing Method": preprocess_method,
|
||||||
|
"Algorithm Name": algorithm_name,
|
||||||
|
"Model Type": model_info.get("model_type", ""),
|
||||||
|
"Coefficient Count": len(model_info.get("model_info", [])),
|
||||||
|
"Average Accuracy(%)": np.mean(accuracy_list) if accuracy_list else 0,
|
||||||
|
"Min Accuracy(%)": np.min(accuracy_list) if accuracy_list else 0,
|
||||||
|
"Max Accuracy(%)": np.max(accuracy_list) if accuracy_list else 0,
|
||||||
|
"Sample Count": len(model_info.get("long", [])),
|
||||||
|
"Model File": model_path,
|
||||||
|
}
|
||||||
|
|
||||||
|
coefficients = model_info.get("model_info", [])
|
||||||
|
for i, coeff in enumerate(coefficients[:5]):
|
||||||
|
summary_row[f"系数_{i+1}"] = coeff
|
||||||
|
|
||||||
|
summary_data.append(summary_row)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"读取模型文件 {model_path} 时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if summary_data:
|
||||||
|
df_summary = pd.DataFrame(summary_data)
|
||||||
|
df_summary.to_csv(summary_path, index=False, encoding="utf-8-sig")
|
||||||
|
print(f"汇总文件已生成: {summary_path}")
|
||||||
|
else:
|
||||||
|
print("警告: 没有有效的模型数据可汇总")
|
||||||
|
summary_path = ""
|
||||||
|
|
||||||
|
return summary_path
|
||||||
350
src/core/steps/prediction_step.py
Normal file
350
src/core/steps/prediction_step.py
Normal file
@ -0,0 +1,350 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
预测步骤
|
||||||
|
|
||||||
|
包含 step7_generate_sampling_points, step8_predict_water_quality,
|
||||||
|
step8_5_predict_with_non_empirical_models, step8_75_predict_with_custom_regression
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Union, Callable, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionStep:
|
||||||
|
"""预测步骤"""
|
||||||
|
|
||||||
|
# ---- Step 7: 生成采样点并提取光谱 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_sampling_points(
|
||||||
|
deglint_img_path: Optional[str] = None,
|
||||||
|
interval: int = 50,
|
||||||
|
sample_radius: int = 5,
|
||||||
|
chunk_size: int = 1000,
|
||||||
|
water_mask_path: Optional[str] = None,
|
||||||
|
glint_mask_path: Optional[str] = None,
|
||||||
|
output_dir: Union[str, Path] = "./10_sampling",
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
) -> str:
|
||||||
|
"""生成水域掩膜内且耀斑掩膜外的采样点,统计平均光谱"""
|
||||||
|
from pathlib import Path
|
||||||
|
from src.utils.sampling import get_spectral_sampling_points_chunked
|
||||||
|
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
output_path = str(output_dir / "sampling_spectra.csv")
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤7", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤7: 生成预测采样点并提取光谱")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if deglint_img_path is None:
|
||||||
|
raise ValueError("必须提供 deglint_img_path 参数")
|
||||||
|
|
||||||
|
# 1. 初始归一化与安全转换
|
||||||
|
original_path = Path(deglint_img_path)
|
||||||
|
final_deglint_path = original_path
|
||||||
|
|
||||||
|
# 2. 智能回溯探测:如果当前路径不存在,或者后缀是前端死板的 .dat
|
||||||
|
if not final_deglint_path.exists() or final_deglint_path.suffix.lower() == '.dat':
|
||||||
|
print(f"🔍 智能探测:输入去耀斑路径不存在或为 .dat 占位符 ({final_deglint_path}),正在向上搜索真实产物...")
|
||||||
|
|
||||||
|
# 定位到预期的 3_deglint 根目录
|
||||||
|
possible_dir = original_path.parent
|
||||||
|
if possible_dir.name != '3_deglint' and Path(output_path).parent.parent.exists():
|
||||||
|
possible_dir = Path(output_path).parent.parent / "3_deglint"
|
||||||
|
|
||||||
|
if possible_dir.exists():
|
||||||
|
# 搜寻该目录下所有真实存在的 .bsq 文件(接管 goodman/sugar/kutser/hedley 的硬编码产物)
|
||||||
|
existing_bsqs = list(possible_dir.glob("*.bsq"))
|
||||||
|
if existing_bsqs:
|
||||||
|
final_deglint_path = existing_bsqs[0]
|
||||||
|
print(f"💡 智能拦截成功:自动寻回底层真实去耀斑影像: {final_deglint_path}")
|
||||||
|
else:
|
||||||
|
final_deglint_path = original_path.with_suffix('.bsq')
|
||||||
|
else:
|
||||||
|
final_deglint_path = original_path.with_suffix('.bsq')
|
||||||
|
|
||||||
|
deglint_img_str = str(final_deglint_path)
|
||||||
|
|
||||||
|
if Path(output_path).exists():
|
||||||
|
print(f"检测到已存在的采样点光谱数据文件,直接使用: {output_path}")
|
||||||
|
notify("skipped", f"采样点光谱数据已设置: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
glint_mask_to_use = glint_mask_path
|
||||||
|
if glint_mask_to_use is None:
|
||||||
|
print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")
|
||||||
|
|
||||||
|
# 传递极度安全的 deglint_img_str 进底层
|
||||||
|
get_spectral_sampling_points_chunked(
|
||||||
|
deglint_img_str, water_mask_path, glint_mask_to_use,
|
||||||
|
output_path, interval, sample_radius, chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
|
notify("completed", f"采样点光谱数据已保存: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
# ---- Step 8: 机器学习模型预测水质参数 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def predict_water_quality(
|
||||||
|
sampling_csv_path: str,
|
||||||
|
models_dir: Optional[str] = None,
|
||||||
|
metric: str = "test_r2",
|
||||||
|
prediction_column: str = "prediction",
|
||||||
|
output_dir: Union[str, Path] = "./11_12_13_predictions/Machine_Learning_Prediction",
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
_report_generator=None,
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""将训练好的最佳机器学习模型应用到采样点光谱上,预测水质参数"""
|
||||||
|
from src.core.prediction.inference_batch import WaterQualityInference
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤8", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤8: 预测水质参数")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if models_dir is None:
|
||||||
|
raise ValueError("必须提供 models_dir 参数")
|
||||||
|
|
||||||
|
ml_prediction_dir = Path(output_dir)
|
||||||
|
ml_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
prediction_files = {}
|
||||||
|
if ml_prediction_dir.exists():
|
||||||
|
csv_files = list(ml_prediction_dir.glob("*.csv"))
|
||||||
|
for csv_file in csv_files:
|
||||||
|
file_stem = csv_file.stem
|
||||||
|
if "_prediction" in file_stem:
|
||||||
|
target_name = file_stem.replace("_prediction", "")
|
||||||
|
elif "_pred" in file_stem:
|
||||||
|
target_name = file_stem.replace("_pred", "")
|
||||||
|
else:
|
||||||
|
target_name = file_stem
|
||||||
|
prediction_files[target_name] = str(csv_file)
|
||||||
|
|
||||||
|
# 检查是否所有目标参数都有预测文件
|
||||||
|
if prediction_files:
|
||||||
|
models_path_obj = Path(models_dir)
|
||||||
|
if models_path_obj.exists():
|
||||||
|
target_folders = [d.name for d in models_path_obj.iterdir() if d.is_dir()]
|
||||||
|
missing_targets = [t for t in target_folders if t not in prediction_files]
|
||||||
|
if not missing_targets:
|
||||||
|
print(f"检测到已存在的预测结果文件,直接使用: {ml_prediction_dir}")
|
||||||
|
notify("skipped", f"预测结果已设置: {ml_prediction_dir}")
|
||||||
|
return prediction_files
|
||||||
|
else:
|
||||||
|
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
|
||||||
|
|
||||||
|
inferencer = WaterQualityInference(models_dir)
|
||||||
|
all_results = inferencer.batch_inference_multi_models(
|
||||||
|
models_root_dir=models_dir,
|
||||||
|
sampling_csv_path=sampling_csv_path,
|
||||||
|
output_dir=str(ml_prediction_dir),
|
||||||
|
metric=metric,
|
||||||
|
prediction_column=prediction_column,
|
||||||
|
output_format="csv",
|
||||||
|
)
|
||||||
|
|
||||||
|
for target_name, result in all_results.items():
|
||||||
|
if result.get("status") == "success":
|
||||||
|
prediction_files[target_name] = result["output_file"]
|
||||||
|
|
||||||
|
print(f"预测完成,结果保存在: {ml_prediction_dir}")
|
||||||
|
|
||||||
|
if _report_generator is not None:
|
||||||
|
try:
|
||||||
|
report_path = _report_generator.generate_prediction_report(prediction_files)
|
||||||
|
print(f"预测结果报告已生成: {report_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"生成预测结果报告时出错: {e}")
|
||||||
|
|
||||||
|
notify("completed", f"预测完成: {ml_prediction_dir}")
|
||||||
|
return prediction_files
|
||||||
|
|
||||||
|
# ---- Step 8.5: 非经验模型预测 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def predict_with_non_empirical_models(
|
||||||
|
sampling_csv_path: str,
|
||||||
|
non_empirical_models_dir: Optional[str] = None,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
metric: str = "Average Accuracy(%)",
|
||||||
|
prediction_column: str = "prediction",
|
||||||
|
enabled: bool = True,
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
work_dir: Union[str, Path] = "./work_dir",
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""使用非经验统计回归模型进行参数预测"""
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤8.5", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤8.5: 使用非经验模型进行参数预测")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
print("已设置跳过非经验模型预测(enabled=False)。")
|
||||||
|
notify("skipped", "跳过非经验模型预测")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if non_empirical_models_dir is not None:
|
||||||
|
final_models_dir = non_empirical_models_dir
|
||||||
|
else:
|
||||||
|
default_models_dir = str(Path(work_dir) / "8_Regression_Modeling")
|
||||||
|
if Path(default_models_dir).exists():
|
||||||
|
final_models_dir = default_models_dir
|
||||||
|
else:
|
||||||
|
raise ValueError("请先执行步骤6.5: 非经验模型训练,或提供 non_empirical_models_dir 参数")
|
||||||
|
|
||||||
|
if output_dir is not None:
|
||||||
|
non_empirical_prediction_dir = Path(output_dir)
|
||||||
|
else:
|
||||||
|
non_empirical_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Non_Empirical_Prediction"
|
||||||
|
non_empirical_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
prediction_files = {}
|
||||||
|
summary_path = Path(final_models_dir) / "non_empirical_models_summary.csv"
|
||||||
|
if not summary_path.exists():
|
||||||
|
raise ValueError(f"未找到非经验模型汇总文件: {summary_path}")
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
df_summary = pd.read_csv(summary_path)
|
||||||
|
|
||||||
|
best_models = {}
|
||||||
|
for algorithm in df_summary["Algorithm Name"].unique():
|
||||||
|
algorithm_df = df_summary[df_summary["Algorithm Name"] == algorithm]
|
||||||
|
if metric in algorithm_df.columns:
|
||||||
|
best_model_row = algorithm_df.nlargest(1, metric)
|
||||||
|
else:
|
||||||
|
best_model_row = algorithm_df.iloc[[0]]
|
||||||
|
|
||||||
|
best_model_path = best_model_row["Model File"].values[0]
|
||||||
|
best_preprocess = best_model_row["Preprocessing Method"].values[0]
|
||||||
|
best_accuracy = best_model_row[metric].values[0] if metric in best_model_row.columns else "N/A"
|
||||||
|
|
||||||
|
best_models[algorithm] = {
|
||||||
|
"model_path": best_model_path,
|
||||||
|
"preprocess_method": best_preprocess,
|
||||||
|
"accuracy": best_accuracy,
|
||||||
|
}
|
||||||
|
print(f"算法 {algorithm}: 选择 {best_preprocess} (准确率: {best_accuracy})")
|
||||||
|
|
||||||
|
pd.read_csv(sampling_csv_path) # just to validate
|
||||||
|
|
||||||
|
for algorithm, model_info in best_models.items():
|
||||||
|
print(f"\n使用 {algorithm} 算法进行预测...")
|
||||||
|
output_path = str(non_empirical_prediction_dir / f"non_empirical_{algorithm}_{prediction_column}.csv")
|
||||||
|
|
||||||
|
if Path(output_path).exists():
|
||||||
|
print(f"检测到已存在的预测结果文件,直接使用: {output_path}")
|
||||||
|
prediction_files[algorithm] = output_path
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.core.non_empirical_retrieval import non_empirical_retrieval
|
||||||
|
non_empirical_retrieval(
|
||||||
|
algorithm=algorithm,
|
||||||
|
model_info_path=model_info["model_path"],
|
||||||
|
coor_spectral_path=sampling_csv_path,
|
||||||
|
output_path=output_path,
|
||||||
|
wave_radius=5,
|
||||||
|
)
|
||||||
|
prediction_files[algorithm] = output_path
|
||||||
|
print(f"预测完成: {output_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"使用 {algorithm} 算法预测时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
notify("completed", f"非经验模型预测完成: {non_empirical_prediction_dir}")
|
||||||
|
return prediction_files
|
||||||
|
|
||||||
|
# ---- Step 8.75: 自定义回归模型预测 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def predict_with_custom_regression(
|
||||||
|
sampling_csv_path: str,
|
||||||
|
custom_regression_dir: Optional[str] = None,
|
||||||
|
formula_csv_path: Optional[str] = None,
|
||||||
|
coordinate_columns: Optional[List[str]] = None,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
filename_prefix: str = "custom_regression_prediction",
|
||||||
|
enabled: bool = True,
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
work_dir: Union[str, Path] = "./work_dir",
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""使用自定义回归模型进行参数预测"""
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤8.75", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤8.75: 使用自定义回归模型进行参数预测")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
print("已设置跳过自定义回归模型预测(enabled=False)。")
|
||||||
|
notify("skipped", "跳过自定义回归预测")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if not Path(sampling_csv_path).exists():
|
||||||
|
raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")
|
||||||
|
|
||||||
|
if custom_regression_dir is not None:
|
||||||
|
final_regression_dir = custom_regression_dir
|
||||||
|
else:
|
||||||
|
final_regression_dir = str(Path(work_dir) / "9_Custom_Regression_Modeling")
|
||||||
|
if not Path(final_regression_dir).exists():
|
||||||
|
raise ValueError(
|
||||||
|
"请先执行步骤6.75: 自定义回归分析,或提供 custom_regression_dir 参数"
|
||||||
|
)
|
||||||
|
|
||||||
|
if output_dir is None:
|
||||||
|
custom_regression_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Custom_Regression_Prediction"
|
||||||
|
custom_regression_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
prediction_output_dir = str(custom_regression_prediction_dir)
|
||||||
|
else:
|
||||||
|
prediction_output_dir = output_dir
|
||||||
|
|
||||||
|
from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor
|
||||||
|
|
||||||
|
predictor = CustomRegressionPredictor(
|
||||||
|
regression_csv_dir=final_regression_dir,
|
||||||
|
formula_csv_path=formula_csv_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"开始使用自定义回归模块进行批量预测...")
|
||||||
|
print(f" 采样点数据: {sampling_csv_path}")
|
||||||
|
print(f" 回归模型目录: {final_regression_dir}")
|
||||||
|
print(f" 输出目录: {prediction_output_dir}")
|
||||||
|
|
||||||
|
saved_files = predictor.run_batch_prediction(
|
||||||
|
input_csv_path=sampling_csv_path,
|
||||||
|
coordinate_columns=coordinate_columns,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
|
||||||
|
for param_name, filepath in saved_files.items():
|
||||||
|
print(f" {param_name}: {filepath}")
|
||||||
|
|
||||||
|
notify("completed", f"自定义回归预测完成: {len(saved_files)} 个文件")
|
||||||
|
return saved_files
|
||||||
148
src/core/steps/water_mask_step.py
Normal file
148
src/core/steps/water_mask_step.py
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
步骤1: 水域掩膜生成
|
||||||
|
|
||||||
|
支持三种方式:
|
||||||
|
1. 基于 shp 文件栅格化
|
||||||
|
2. 使用现有栅格格式掩膜文件 (.dat/.tif)
|
||||||
|
3. 基于 NDWI 从影像自动生成水体掩膜
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Callable, Union
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class WaterMaskStep:
|
||||||
|
"""水域掩膜生成步骤"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run(
|
||||||
|
mask_path: Optional[str] = None,
|
||||||
|
img_path: Optional[str] = None,
|
||||||
|
ndwi_threshold: float = 0.4,
|
||||||
|
use_ndwi: bool = False,
|
||||||
|
generate_png: bool = True,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
water_mask_dir: Union[str, Path] = "./1_water_mask",
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
执行水域掩膜生成
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask_path: 水体掩膜文件路径,支持 .shp(需 img_path)或 .dat/.tif(直接使用)
|
||||||
|
img_path: 输入影像文件路径(当 mask_path 为 shp 或 use_ndwi=True 时必须提供)
|
||||||
|
ndwi_threshold: NDWI 阈值(use_ndwi=True 时使用)
|
||||||
|
use_ndwi: 是否使用 NDWI 方法从影像生成水体掩膜
|
||||||
|
generate_png: 是否生成 PNG 预览图(默认 True)
|
||||||
|
output_path: 指定输出掩膜文件的保存路径(可选)
|
||||||
|
water_mask_dir: 工作目录
|
||||||
|
callback: 回调函数,签名为 callback(step, status, message)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dat 格式的水域掩膜文件路径
|
||||||
|
"""
|
||||||
|
from src.utils.extract_water_area import rasterize_shp, ndwi
|
||||||
|
from src.core.utils.preview_generator import (
|
||||||
|
generate_image_preview,
|
||||||
|
generate_water_mask_overlay,
|
||||||
|
)
|
||||||
|
|
||||||
|
water_mask_dir = Path(water_mask_dir)
|
||||||
|
water_mask_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤1", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤1: 生成或设置水域mask")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
# 生成影像预览图
|
||||||
|
if generate_png and img_path is not None and Path(img_path).exists():
|
||||||
|
preview_path = str(water_mask_dir / "hsi_preview.png")
|
||||||
|
generate_image_preview(
|
||||||
|
img_path=img_path,
|
||||||
|
output_path=preview_path,
|
||||||
|
title="影像预览: RGB波段(基于波长)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- NDWI 方法 ----
|
||||||
|
if use_ndwi:
|
||||||
|
if img_path is None:
|
||||||
|
raise ValueError("当 use_ndwi=True 时,必须提供 img_path 参数")
|
||||||
|
if not Path(img_path).exists():
|
||||||
|
raise ValueError(f"影像文件不存在: {img_path}")
|
||||||
|
|
||||||
|
print(f"使用NDWI方法从影像生成水体掩膜,阈值={ndwi_threshold}...")
|
||||||
|
|
||||||
|
ndwi_output_path = output_path or str(water_mask_dir / "water_mask_from_ndwi.dat")
|
||||||
|
os.makedirs(Path(ndwi_output_path).parent, exist_ok=True)
|
||||||
|
|
||||||
|
if Path(ndwi_output_path).exists():
|
||||||
|
print(f"检测到已存在的NDWI掩膜文件,直接使用: {ndwi_output_path}")
|
||||||
|
notify("skipped", f"水域掩膜已设置: {ndwi_output_path}")
|
||||||
|
return ndwi_output_path
|
||||||
|
|
||||||
|
ndwi(img_path, ndwi_threshold, ndwi_output_path)
|
||||||
|
|
||||||
|
if generate_png:
|
||||||
|
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||||
|
generate_water_mask_overlay(
|
||||||
|
img_path=img_path, mask_path=ndwi_output_path, output_path=str(overlay_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
notify("completed", f"NDWI水体掩膜已生成: {ndwi_output_path}")
|
||||||
|
return ndwi_output_path
|
||||||
|
|
||||||
|
# ---- 必须提供 mask_path ----
|
||||||
|
if mask_path is None:
|
||||||
|
raise ValueError("必须提供 mask_path 参数或设置 use_ndwi=True")
|
||||||
|
if not Path(mask_path).exists():
|
||||||
|
raise ValueError(f"文件不存在: {mask_path}")
|
||||||
|
|
||||||
|
file_ext = Path(mask_path).suffix.lower()
|
||||||
|
|
||||||
|
# ---- SHP 栅格化 ----
|
||||||
|
if file_ext == ".shp":
|
||||||
|
if img_path is None:
|
||||||
|
raise ValueError("当 mask_path 为 shp 格式时,必须提供 img_path 参数")
|
||||||
|
|
||||||
|
print(f"检测到shp格式的水体掩膜,正在转换为dat格式...")
|
||||||
|
|
||||||
|
shp_output_path = output_path or str(water_mask_dir / "water_mask_from_shp.dat")
|
||||||
|
os.makedirs(Path(shp_output_path).parent, exist_ok=True)
|
||||||
|
|
||||||
|
if Path(shp_output_path).exists():
|
||||||
|
print(f"检测到已存在的栅格化掩膜文件,直接使用: {shp_output_path}")
|
||||||
|
notify("skipped", f"水域掩膜已设置: {shp_output_path}")
|
||||||
|
if generate_png:
|
||||||
|
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||||
|
if not overlay_path.exists():
|
||||||
|
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
|
||||||
|
return shp_output_path
|
||||||
|
|
||||||
|
safe_mask_path = os.path.abspath(mask_path).replace("\\", "/")
|
||||||
|
rasterize_shp(safe_mask_path, shp_output_path, img_path)
|
||||||
|
|
||||||
|
if generate_png:
|
||||||
|
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||||
|
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
|
||||||
|
|
||||||
|
notify("completed", f"dat格式水域掩膜已生成: {shp_output_path}")
|
||||||
|
return shp_output_path
|
||||||
|
|
||||||
|
# ---- 栅格格式直接使用 ----
|
||||||
|
print(f"检测到栅格格式的水体掩膜,直接使用: {mask_path}")
|
||||||
|
if generate_png and img_path is not None and Path(img_path).exists():
|
||||||
|
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||||
|
generate_water_mask_overlay(img_path, mask_path, str(overlay_path))
|
||||||
|
|
||||||
|
notify("completed", f"水域掩膜已设置: {mask_path}")
|
||||||
|
return mask_path
|
||||||
File diff suppressed because it is too large
Load Diff
@ -209,7 +209,7 @@ class Step3Panel(QWidget):
|
|||||||
"输出影像:",
|
"输出影像:",
|
||||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||||
)
|
)
|
||||||
self.output_file.line_edit.setPlaceholderText("deglint_image.dat")
|
self.output_file.line_edit.setPlaceholderText("deglint_image.bsq")
|
||||||
layout.addWidget(self.output_file)
|
layout.addWidget(self.output_file)
|
||||||
|
|
||||||
# 启用步骤
|
# 启用步骤
|
||||||
@ -301,7 +301,7 @@ class Step3Panel(QWidget):
|
|||||||
if self.work_dir:
|
if self.work_dir:
|
||||||
output_dir = os.path.join(self.work_dir, "3_deglint")
|
output_dir = os.path.join(self.work_dir, "3_deglint")
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
default_output_path = os.path.join(output_dir, "deglint_image.dat").replace('\\', '/')
|
default_output_path = os.path.join(output_dir, "deglint_image.bsq").replace('\\', '/')
|
||||||
self.output_file.set_path(default_output_path)
|
self.output_file.set_path(default_output_path)
|
||||||
else:
|
else:
|
||||||
self.output_file.set_path("")
|
self.output_file.set_path("")
|
||||||
|
|||||||
@ -17,6 +17,57 @@ from src.gui.components.custom_widgets import FileSelectWidget
|
|||||||
from src.gui.styles import ModernStylesheet
|
from src.gui.styles import ModernStylesheet
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 中文映射表(内部键名 -> 显示文本)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# 预处理方法:内部键 -> 显示文本
|
||||||
|
PREPROC_CHINESE = {
|
||||||
|
'None': '无 (None)',
|
||||||
|
'MMS': '最小-最大归一化 (MMS)',
|
||||||
|
'SS': '标度化 (SS)',
|
||||||
|
'SNV': '标准正态变换 (SNV)',
|
||||||
|
'MA': '移动平均 (MA)',
|
||||||
|
'SG': 'Savitzky-Golay (SG)',
|
||||||
|
'MSC': '多元散射校正 (MSC)',
|
||||||
|
'D1': '一阶导数 (D1)',
|
||||||
|
'D2': '二阶导数 (D2)',
|
||||||
|
'DT': '去趋势 (DT)',
|
||||||
|
'CT': '中心化 (CT)',
|
||||||
|
}
|
||||||
|
|
||||||
|
# 模型类型:内部键 -> 显示文本
|
||||||
|
MODEL_CHINESE = {
|
||||||
|
# 线性模型
|
||||||
|
'LinearRegression': '多元线性回归 (MLR)',
|
||||||
|
'Ridge': '岭回归 (Ridge)',
|
||||||
|
'Lasso': '套索回归 (Lasso)',
|
||||||
|
'ElasticNet': '弹性网络 (ElasticNet)',
|
||||||
|
'PLS': '偏最小二乘 (PLSR)',
|
||||||
|
# 树模型
|
||||||
|
'DecisionTree': '决策树 (CART)',
|
||||||
|
'RF': '随机森林 (RF)',
|
||||||
|
'ExtraTrees': '极端随机树 (ET)',
|
||||||
|
'XGBoost': '极值梯度提升 (XGBoost)',
|
||||||
|
'LightGBM': '轻量梯度提升 (LightGBM)',
|
||||||
|
'CatBoost': '类别梯度提升 (CatBoost)',
|
||||||
|
# 集成学习
|
||||||
|
'GradientBoosting': '梯度提升树 (GBDT)',
|
||||||
|
'AdaBoost': '自适应提升 (AdaBoost)',
|
||||||
|
# 其他模型
|
||||||
|
'SVR': '支持向量回归 (SVR)',
|
||||||
|
'KNN': 'K近邻回归 (KNN)',
|
||||||
|
'MLP': '多层感知机 (BP神经网络)',
|
||||||
|
}
|
||||||
|
|
||||||
|
# 数据划分方法:内部键 -> 显示文本
|
||||||
|
SPLIT_CHINESE = {
|
||||||
|
'spxy': 'SPXY 算法 (考量X-Y空间)',
|
||||||
|
'ks': 'KS 算法 (考量X空间)',
|
||||||
|
'random': '随机划分 (Random)',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Step6Panel(QWidget):
|
class Step6Panel(QWidget):
|
||||||
"""步骤6:机器学习建模"""
|
"""步骤6:机器学习建模"""
|
||||||
def __init__(self, parent=None):
|
def __init__(self, parent=None):
|
||||||
@ -54,7 +105,7 @@ class Step6Panel(QWidget):
|
|||||||
|
|
||||||
# 启用步骤
|
# 启用步骤
|
||||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||||
self.enable_checkbox.setChecked(True)
|
self.enable_checkbox.setChecked(False)
|
||||||
layout.addWidget(self.enable_checkbox)
|
layout.addWidget(self.enable_checkbox)
|
||||||
|
|
||||||
# 独立运行按钮
|
# 独立运行按钮
|
||||||
@ -95,8 +146,8 @@ class Step6Panel(QWidget):
|
|||||||
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
|
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
|
||||||
|
|
||||||
for i, method in enumerate(preproc_methods):
|
for i, method in enumerate(preproc_methods):
|
||||||
checkbox = QCheckBox(method)
|
checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
|
||||||
checkbox.setChecked(True)
|
checkbox.setChecked(False)
|
||||||
self.preproc_checkboxes[method] = checkbox
|
self.preproc_checkboxes[method] = checkbox
|
||||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||||
|
|
||||||
@ -122,10 +173,10 @@ class Step6Panel(QWidget):
|
|||||||
self.model_checkboxes = {}
|
self.model_checkboxes = {}
|
||||||
|
|
||||||
model_groups = [
|
model_groups = [
|
||||||
("线性模型", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
|
("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
|
||||||
("树模型", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
|
("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
|
||||||
("集成学习", ['GradientBoosting', 'AdaBoost']),
|
("【集成学习】", ['GradientBoosting', 'AdaBoost']),
|
||||||
("其他模型", ['SVR', 'KNN', 'MLP'])
|
("【其他模型】", ['SVR', 'KNN', 'MLP'])
|
||||||
]
|
]
|
||||||
|
|
||||||
row = 0
|
row = 0
|
||||||
@ -140,8 +191,8 @@ class Step6Panel(QWidget):
|
|||||||
row += 1
|
row += 1
|
||||||
|
|
||||||
for i, model in enumerate(models):
|
for i, model in enumerate(models):
|
||||||
checkbox = QCheckBox(model)
|
checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
|
||||||
checkbox.setChecked(model in ['SVR', 'RF', 'Ridge', 'Lasso'])
|
checkbox.setChecked(False)
|
||||||
self.model_checkboxes[model] = checkbox
|
self.model_checkboxes[model] = checkbox
|
||||||
model_grid.addWidget(checkbox, row, i % 4)
|
model_grid.addWidget(checkbox, row, i % 4)
|
||||||
if (i + 1) % 4 == 0:
|
if (i + 1) % 4 == 0:
|
||||||
@ -172,8 +223,8 @@ class Step6Panel(QWidget):
|
|||||||
split_methods = ['spxy', 'ks', 'random']
|
split_methods = ['spxy', 'ks', 'random']
|
||||||
|
|
||||||
for i, method in enumerate(split_methods):
|
for i, method in enumerate(split_methods):
|
||||||
checkbox = QCheckBox(method)
|
checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
|
||||||
checkbox.setChecked(True)
|
checkbox.setChecked(False)
|
||||||
self.split_checkboxes[method] = checkbox
|
self.split_checkboxes[method] = checkbox
|
||||||
split_grid.addWidget(checkbox, 0, i)
|
split_grid.addWidget(checkbox, 0, i)
|
||||||
|
|
||||||
|
|||||||
@ -109,7 +109,7 @@ class Step9Panel(QWidget):
|
|||||||
mode_row.addStretch()
|
mode_row.addStretch()
|
||||||
layout.addLayout(mode_row)
|
layout.addLayout(mode_row)
|
||||||
|
|
||||||
# ---------- RadioButton 美化样式(选中状态更醒目) ----------
|
# ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格) ----------
|
||||||
radio_style = """
|
radio_style = """
|
||||||
QRadioButton {
|
QRadioButton {
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
@ -117,21 +117,16 @@ class Step9Panel(QWidget):
|
|||||||
color: #333333;
|
color: #333333;
|
||||||
}
|
}
|
||||||
QRadioButton::indicator {
|
QRadioButton::indicator {
|
||||||
width: 18px;
|
width: 16px;
|
||||||
height: 18px;
|
height: 16px;
|
||||||
border: 2px solid #999999;
|
border: 2px solid #999999;
|
||||||
border-radius: 9px;
|
border-radius: 3px;
|
||||||
background-color: white;
|
background-color: white;
|
||||||
}
|
}
|
||||||
QRadioButton::indicator:checked {
|
QRadioButton::indicator:checked {
|
||||||
border: 2px solid #0078d4;
|
border: 2px solid #0078d4;
|
||||||
background-color: qradialgradient(
|
background-color: #0078d4;
|
||||||
cx:0.5, cy:0.5, radius:0.5,
|
image: none;
|
||||||
fx:0.5, fy:0.5,
|
|
||||||
stop:0 #0078d4,
|
|
||||||
stop:0.6 white,
|
|
||||||
stop:1.0 white
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
QRadioButton::indicator:hover {
|
QRadioButton::indicator:hover {
|
||||||
border: 2px solid #005a9e;
|
border: 2px solid #005a9e;
|
||||||
@ -353,7 +348,7 @@ class Step9Panel(QWidget):
|
|||||||
if not main_window:
|
if not main_window:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(优先)
|
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(最优先)
|
||||||
pred_dir = None
|
pred_dir = None
|
||||||
if hasattr(main_window, 'step8_panel'):
|
if hasattr(main_window, 'step8_panel'):
|
||||||
step8_widget = getattr(main_window.step8_panel, 'output_file', None)
|
step8_widget = getattr(main_window.step8_panel, 'output_file', None)
|
||||||
@ -367,7 +362,10 @@ class Step9Panel(QWidget):
|
|||||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||||
if not os.path.isabs(step8_output):
|
if not os.path.isabs(step8_output):
|
||||||
step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
|
step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
|
||||||
pred_dir = str(Path(step8_output).parent)
|
# 提取父目录后追加 Machine_Learning_Prediction(最底层真实子目录)
|
||||||
|
base_pred_dir = str(Path(step8_output).parent)
|
||||||
|
ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction"
|
||||||
|
pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir
|
||||||
|
|
||||||
# 2. 备选:从 Step8.5 界面读取非经验预测输出目录
|
# 2. 备选:从 Step8.5 界面读取非经验预测输出目录
|
||||||
if not pred_dir and hasattr(main_window, 'step8_5_panel'):
|
if not pred_dir and hasattr(main_window, 'step8_5_panel'):
|
||||||
@ -411,6 +409,14 @@ class Step9Panel(QWidget):
|
|||||||
existing_out = self.output_dir.get_path()
|
existing_out = self.output_dir.get_path()
|
||||||
if not existing_out or not existing_out.strip():
|
if not existing_out or not existing_out.strip():
|
||||||
self.output_dir.set_path(output_dir)
|
self.output_dir.set_path(output_dir)
|
||||||
|
|
||||||
|
# 5. 自动继承步骤1的水域掩膜作为边界文件
|
||||||
|
if self.work_dir:
|
||||||
|
default_mask = Path(self.work_dir) / "1_water_mask" / "water_mask_from_shp.dat"
|
||||||
|
if default_mask.exists():
|
||||||
|
existing_boundary = (self.boundary_file.get_path() or "").strip()
|
||||||
|
if not existing_boundary:
|
||||||
|
self.boundary_file.set_path(str(default_mask))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||||
|
|||||||
@ -1825,6 +1825,11 @@ class WaterQualityGUI(QMainWindow):
|
|||||||
for step_id, step_display in steps:
|
for step_id, step_display in steps:
|
||||||
item = QListWidgetItem(f" └─ {step_display}")
|
item = QListWidgetItem(f" └─ {step_display}")
|
||||||
item.setData(Qt.UserRole, step_id)
|
item.setData(Qt.UserRole, step_id)
|
||||||
|
|
||||||
|
# 隐藏4个冗余回归步骤(树节点)
|
||||||
|
if step_id in ("step6_5", "step6_75", "step8_5", "step8_75"):
|
||||||
|
item.setHidden(True)
|
||||||
|
|
||||||
self.step_name_map[step_display] = step_id
|
self.step_name_map[step_display] = step_id
|
||||||
|
|
||||||
# 设置步骤项的样式
|
# 设置步骤项的样式
|
||||||
@ -1905,9 +1910,11 @@ class WaterQualityGUI(QMainWindow):
|
|||||||
|
|
||||||
self.step6_5_panel = Step6_5Panel()
|
self.step6_5_panel = Step6_5Panel()
|
||||||
self.step_stack.addTab(self.create_scroll_area(self.step6_5_panel), QIcon(self.get_icon_path("6.png")), "回归建模")
|
self.step_stack.addTab(self.create_scroll_area(self.step6_5_panel), QIcon(self.get_icon_path("6.png")), "回归建模")
|
||||||
|
self.step_stack.tabBar().setTabVisible(7, False) # 隐藏回归建模 Tab
|
||||||
|
|
||||||
self.step6_75_panel = Step6_75Panel()
|
self.step6_75_panel = Step6_75Panel()
|
||||||
self.step_stack.addTab(self.create_scroll_area(self.step6_75_panel), QIcon(self.get_icon_path("6.png")), "自定义回归建模")
|
self.step_stack.addTab(self.create_scroll_area(self.step6_75_panel), QIcon(self.get_icon_path("6.png")), "自定义回归建模")
|
||||||
|
self.step_stack.tabBar().setTabVisible(8, False) # 隐藏自定义回归建模 Tab
|
||||||
|
|
||||||
self.step7_panel = Step7Panel()
|
self.step7_panel = Step7Panel()
|
||||||
self.step_stack.addTab(self.create_scroll_area(self.step7_panel), QIcon(self.get_icon_path("7.png")), "采样点布设")
|
self.step_stack.addTab(self.create_scroll_area(self.step7_panel), QIcon(self.get_icon_path("7.png")), "采样点布设")
|
||||||
@ -1917,9 +1924,11 @@ class WaterQualityGUI(QMainWindow):
|
|||||||
|
|
||||||
self.step8_5_panel = Step8_5Panel()
|
self.step8_5_panel = Step8_5Panel()
|
||||||
self.step_stack.addTab(self.create_scroll_area(self.step8_5_panel), QIcon(self.get_icon_path("8.png")), "回归预测")
|
self.step_stack.addTab(self.create_scroll_area(self.step8_5_panel), QIcon(self.get_icon_path("8.png")), "回归预测")
|
||||||
|
self.step_stack.tabBar().setTabVisible(11, False) # 隐藏回归预测 Tab
|
||||||
|
|
||||||
self.step8_75_panel = Step8_75Panel()
|
self.step8_75_panel = Step8_75Panel()
|
||||||
self.step_stack.addTab(self.create_scroll_area(self.step8_75_panel), QIcon(self.get_icon_path("8.png")), "自定义回归预测")
|
self.step_stack.addTab(self.create_scroll_area(self.step8_75_panel), QIcon(self.get_icon_path("8.png")), "自定义回归预测")
|
||||||
|
self.step_stack.tabBar().setTabVisible(12, False) # 隐藏自定义回归预测 Tab
|
||||||
|
|
||||||
self.step9_panel = Step9Panel()
|
self.step9_panel = Step9Panel()
|
||||||
self.step_stack.addTab(self.create_scroll_area(self.step9_panel), QIcon(self.get_icon_path("10.png")), "专题图生成")
|
self.step_stack.addTab(self.create_scroll_area(self.step9_panel), QIcon(self.get_icon_path("10.png")), "专题图生成")
|
||||||
|
|||||||
@ -1003,67 +1003,84 @@ class ReportGenerator:
|
|||||||
Returns:
|
Returns:
|
||||||
保存的文件路径
|
保存的文件路径
|
||||||
"""
|
"""
|
||||||
from modeling_batch import WaterQualityModelingBatch
|
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
|
||||||
|
import joblib
|
||||||
|
|
||||||
modeler = WaterQualityModelingBatch(models_dir)
|
modeler = WaterQualityModelingBatch(models_dir)
|
||||||
|
|
||||||
# 需要先加载训练结果
|
|
||||||
# 这里假设results已经存储在modeler中,或者需要从保存的文件中读取
|
|
||||||
# 由于modeling_batch.py的结构,我们需要另一种方式来获取所有结果
|
|
||||||
|
|
||||||
# 尝试遍历模型目录,查找所有保存的结果
|
|
||||||
models_path = Path(models_dir)
|
models_path = Path(models_dir)
|
||||||
all_results = []
|
all_results = []
|
||||||
|
|
||||||
# 遍历所有目标参数文件夹
|
# 递归扫描 *.joblib 和 *.pkl,兼容 artifacts_dir/target_name/ 的所有子目录层级
|
||||||
for target_folder in models_path.iterdir():
|
model_files = list(models_path.rglob("*.joblib")) + list(models_path.rglob("*.pkl"))
|
||||||
if not target_folder.is_dir():
|
|
||||||
continue
|
for model_file in model_files:
|
||||||
|
# 目标参数取直系父目录名(符合 artifacts_dir/target_name/ 结构)
|
||||||
target_name = target_folder.name
|
target_name = model_file.parent.name
|
||||||
|
stem = model_file.stem
|
||||||
# 查找所有模型文件
|
|
||||||
for model_file in target_folder.rglob("*.pkl"):
|
# 文件名格式:{safe_target}_{preprocess}_{model_name}.joblib
|
||||||
# 从文件名提取信息(假设格式为:{preprocess}_{model}_{split}.pkl)
|
# 使用 split('_', 2) 最多切 3 段,目标 1 段、预处理 1 段、模型 1 段
|
||||||
model_info = {
|
parts = stem.split('_', 2)
|
||||||
'target': target_name,
|
preprocess = parts[1] if len(parts) > 1 else 'Unknown'
|
||||||
'model_file': str(model_file),
|
model_name_str = parts[2] if len(parts) > 2 else stem
|
||||||
'preprocess': 'Unknown',
|
|
||||||
'model': 'Unknown',
|
# 尝试从 joblib/pkl 读取元数据,提取性能指标
|
||||||
'split_method': 'Unknown'
|
metrics = {}
|
||||||
|
try:
|
||||||
|
data = joblib.load(model_file)
|
||||||
|
metadata = data.get('metadata', {})
|
||||||
|
metrics = {
|
||||||
|
'train_r2': metadata.get('train_r2', 'N/A'),
|
||||||
|
'test_r2': metadata.get('test_r2', 'N/A'),
|
||||||
|
'test_rmse': metadata.get('test_rmse', 'N/A'),
|
||||||
|
'train_rmse': metadata.get('train_rmse', 'N/A'),
|
||||||
|
'train_mae': metadata.get('train_mae', 'N/A'),
|
||||||
|
'test_mae': metadata.get('test_mae', 'N/A'),
|
||||||
|
'cv_mean': metadata.get('cv_mean', 'N/A'),
|
||||||
}
|
}
|
||||||
|
except Exception:
|
||||||
# 尝试从文件名解析
|
pass # 加载失败时 metrics 保持为空字典,摘要中该列为 N/A
|
||||||
parts = model_file.stem.split('_')
|
|
||||||
if len(parts) >= 3:
|
all_results.append({
|
||||||
model_info['preprocess'] = parts[0]
|
'target': target_name,
|
||||||
model_info['model'] = parts[1]
|
'model_file': str(model_file),
|
||||||
model_info['split_method'] = parts[2]
|
'preprocess': preprocess,
|
||||||
|
'model': model_name_str,
|
||||||
all_results.append(model_info)
|
**metrics,
|
||||||
|
})
|
||||||
# 如果有训练结果数据,使用实际指标
|
|
||||||
# 否则创建一个基本的摘要
|
|
||||||
summary_data = []
|
summary_data = []
|
||||||
for result in all_results:
|
for result in all_results:
|
||||||
summary_data.append({
|
summary_data.append({
|
||||||
'目标参数': result['target'],
|
'目标参数': result['target'],
|
||||||
'预处理方法': result['preprocess'],
|
'预处理方法': result['preprocess'],
|
||||||
'模型名称': result['model'],
|
'模型名称': result['model'],
|
||||||
'划分方法': result['split_method'],
|
'模型文件': result['model_file'],
|
||||||
'模型文件': result['model_file']
|
'训练集R²': result.get('train_r2', 'N/A'),
|
||||||
|
'测试集R²': result.get('test_r2', 'N/A'),
|
||||||
|
'测试集RMSE': result.get('test_rmse', 'N/A'),
|
||||||
|
'训练集RMSE': result.get('train_rmse', 'N/A'),
|
||||||
|
'训练集MAE': result.get('train_mae', 'N/A'),
|
||||||
|
'测试集MAE': result.get('test_mae', 'N/A'),
|
||||||
|
'CV均值': result.get('cv_mean', 'N/A'),
|
||||||
})
|
})
|
||||||
|
|
||||||
if not summary_data:
|
if not summary_data:
|
||||||
print("警告:未找到模型文件,生成空摘要")
|
print("警告:未找到模型文件,生成空摘要")
|
||||||
summary_data = [{
|
summary_data = [{
|
||||||
'目标参数': 'No Data',
|
'目标参数': 'No Data',
|
||||||
'预处理方法': 'N/A',
|
'预处理方法': 'N/A',
|
||||||
'模型名称': 'N/A',
|
'模型名称': 'N/A',
|
||||||
'划分方法': 'N/A',
|
'模型文件': 'N/A',
|
||||||
'模型文件': 'N/A'
|
'训练集R²': 'N/A',
|
||||||
|
'测试集R²': 'N/A',
|
||||||
|
'测试集RMSE': 'N/A',
|
||||||
|
'训练集RMSE': 'N/A',
|
||||||
|
'训练集MAE': 'N/A',
|
||||||
|
'测试集MAE': 'N/A',
|
||||||
|
'CV均值': 'N/A',
|
||||||
}]
|
}]
|
||||||
|
|
||||||
df_summary = pd.DataFrame(summary_data)
|
df_summary = pd.DataFrame(summary_data)
|
||||||
|
|
||||||
if output_path is None:
|
if output_path is None:
|
||||||
|
|||||||
@ -96,8 +96,14 @@ class BandMathCalculator:
|
|||||||
|
|
||||||
print(f"计算表达式: {calc_expression}")
|
print(f"计算表达式: {calc_expression}")
|
||||||
|
|
||||||
# 安全地计算表达式
|
# 【新增安全防护】引入 numpy 命名空间,让 eval 引擎安全识别 nan 与 inf
|
||||||
result = eval(calc_expression)
|
import numpy as np
|
||||||
|
try:
|
||||||
|
# 即使 calc_expression 含有纯字符 nan,也能被 np.nan 安全接管
|
||||||
|
result = eval(calc_expression, {"__builtins__": None}, {"nan": np.nan, "inf": np.inf, "np": np})
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ 警告:公式计算异常 ({e}),该点赋值为 nan")
|
||||||
|
result = np.nan
|
||||||
|
|
||||||
# 返回结果
|
# 返回结果
|
||||||
if var_name:
|
if var_name:
|
||||||
|
|||||||
Reference in New Issue
Block a user