Step3 插值算法 OOM 修复 + 多进程加速 + 全链路累积改动(14 文件)

This commit is contained in:
DXC
2026-06-15 16:49:17 +08:00
parent 82e0b92af6
commit 60a9d7d922
14 changed files with 855 additions and 152 deletions

View File

@ -3,8 +3,24 @@
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
支持多种插值方法:nearest、bilinear、spline (RBF)、kriging。
本模块使用多进程并行 + 分块 IO 加速(Plan A):
- ProcessPoolExecutor 为每个 worker 进程打开一次源影像(initializer 阶段),
避免每块重复 gdal.Open 带来的开销(Windows 上 ~50ms/次)
- 主进程统一负责输出文件的写入,避免多进程写锁竞争
- 分块大小(block_size)默认 1024,内存充足可调至 2048 / 4096
注意:
- GDAL Dataset / Rasterio Dataset 对象不能跨进程传递(不支持 pickle),
所以 worker 必须在 init 阶段自己独立打开源文件
- 每个 worker 强制设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 多线程
造成的 CPU 过订阅
- 关闭多进程:传 ``use_multiprocessing=False`` 或 ``n_workers=1``
"""
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import numpy as np
from typing import Optional, Union, Tuple, List
from pathlib import Path
@ -24,6 +40,9 @@ except ImportError:
GDAL_AVAILABLE = False
_worker_dataset: Optional["gdal.Dataset"] = None
def interpolate_pixels(
image_stack: np.ndarray,
zero_coords: np.ndarray,
@ -52,7 +71,6 @@ def interpolate_pixels(
height, width, n_bands = image_stack.shape
result = image_stack.copy()
# 兼容中文和各种格式的method参数
raw_method = str(interpolation_method).lower()
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
method = 'nearest'
@ -181,39 +199,271 @@ def _interpolate_single_band(
return np.zeros(len(zero_coords))
def _normalize_interpolation_method(method: str) -> str:
    """Map a free-form interpolation method name to its canonical identifier.

    The input is converted with ``str()`` and lower-cased, then scanned for
    known English or Chinese keywords by substring match. Keyword groups are
    tried in a fixed priority order (nearest, bilinear, spline, kriging) and
    the first group with a hit wins.

    Args:
        method: Method name as supplied by the caller, e.g. ``'Nearest'``,
            ``'双线性'``, ``'RBF'`` or ``'克里金'``.

    Returns:
        One of ``'nearest'``, ``'bilinear'``, ``'spline'`` or ``'kriging'``.
        Input that matches no keyword falls back to ``'nearest'``.
    """
    # Order matters: earlier rows take precedence when several keywords match.
    keyword_table = (
        ('nearest', ('nearest', '邻近', '最邻近')),
        ('bilinear', ('bilinear', '线性', '双线性')),
        ('spline', ('spline', '样条', 'rbf')),
        ('kriging', ('kriging', '克里金')),
    )
    lowered = str(method).lower()
    for canonical, keywords in keyword_table:
        if any(keyword in lowered for keyword in keywords):
            return canonical
    return 'nearest'
def _read_water_mask_to_array(
    water_mask: Optional[Union[str, Path, np.ndarray]],
    expected_height: int,
    expected_width: int,
) -> Optional[np.ndarray]:
    """Load the water mask as a 2-D numpy array, or return ``None``.

    This helper never raises for an unusable mask: every failure mode prints
    a warning and returns ``None`` so the caller can continue on its
    "no mask" path.

    Args:
        water_mask: ``None``, a path to a raster readable by GDAL (``str`` or
            ``pathlib.Path``), or an already loaded array. For a path, only
            the first band is read.
        expected_height: Required number of rows (image height).
        expected_width: Required number of columns (image width).

    Returns:
        The mask array when its shape equals
        ``(expected_height, expected_width)``; otherwise ``None``. ``None`` is
        also returned for ``None`` input, a blank path, an unsupported
        argument type, or a file that cannot be opened or read.
    """
    if water_mask is None:
        return None

    if isinstance(water_mask, np.ndarray):
        mask_array = water_mask
    elif isinstance(water_mask, (str, Path)):
        mask_path = str(water_mask)
        if not mask_path.strip():
            return None
        # With gdal.UseExceptions() active, a failed open raises RuntimeError
        # instead of returning None; both cases count as "mask unavailable".
        try:
            mask_ds = gdal.Open(mask_path, gdal.GA_ReadOnly)
        except RuntimeError:
            mask_ds = None
        if mask_ds is None:
            print(f" [warn] 无法打开水域掩膜 {water_mask},按无掩膜处理")
            return None
        try:
            mask_array = mask_ds.GetRasterBand(1).ReadAsArray()
        except RuntimeError:
            mask_array = None
        finally:
            # Dropping the last reference closes the GDAL dataset.
            mask_ds = None
        if mask_array is None:
            print(f" [warn] 无法读取水域掩膜 {water_mask} 的第 1 波段,按无掩膜处理")
            return None
    else:
        return None

    if mask_array.shape != (expected_height, expected_width):
        print(
            f" [warn] 水域掩膜形状 {mask_array.shape} 与影像 "
            f"({expected_height}, {expected_width}) 不匹配,按无掩膜处理"
        )
        return None
    return mask_array
def _init_worker(img_path: str) -> None:
    """Initializer run once in each ``ProcessPoolExecutor`` worker process.

    Opens the source image read-only and stores the handle in the
    module-level ``_worker_dataset`` so that block tasks running in this
    process reuse it instead of re-opening the file for every block (the
    original author's note puts one ``gdal.Open`` at about 50 ms on Windows).

    ``GDAL_NUM_THREADS`` is forced to ``'1'`` so that several worker
    processes combined with GDAL's own threading do not oversubscribe the CPU.

    Args:
        img_path: Path of the source raster to open.

    Raises:
        RuntimeError: If ``gdal.Open`` returns ``None`` for ``img_path``.
    """
    global _worker_dataset

    gdal.SetConfigOption('GDAL_NUM_THREADS', '1')
    if hasattr(gdal, 'UseExceptions'):
        gdal.UseExceptions()

    opened = gdal.Open(img_path, gdal.GA_ReadOnly)
    _worker_dataset = opened
    if opened is None:
        raise RuntimeError(f"Worker failed to open source image: {img_path}")
def _interpolate_block_worker(task: tuple) -> tuple:
    """Process one block inside a ``ProcessPoolExecutor`` worker.

    The function must stay at module level so it can be pickled. It keeps no
    state of its own; the source dataset comes from the module-level
    ``_worker_dataset`` populated by ``_init_worker``.

    Args:
        task: 12-tuple ``(x0, y0, ey0, ex0, ey1, ex1, row_offset,
            col_offset, inner_h, inner_w, mask_segment_ext, method)``. It is
            forwarded positionally to ``_process_one_block`` after the
            dataset argument, the same way the serial path calls it.

    Returns:
        ``(x0, y0, inner_bands, zero_count, error_msg)``:

        - ``x0``, ``y0``: write origin of the block in the image.
        - ``inner_bands``: list with one ``(inner_h, inner_w)`` float32 array
          per band, or ``None`` on failure.
        - ``zero_count``: zero pixels found in the extended block (halo
          included); ``0`` on failure.
        - ``error_msg``: ``None`` on success, otherwise
          ``"<ExceptionType>: <message>"``.
    """
    x0, y0 = task[0], task[1]

    if _worker_dataset is None:
        return (x0, y0, None, 0, "Worker dataset not initialized")

    # Deliberately broad: any failure is reported back to the parent as data
    # so the parent can decide how to abort. The exception type is included
    # because str(exc) alone is empty for argument-less exceptions.
    try:
        inner_bands, zero_count = _process_one_block(_worker_dataset, *task)
    except Exception as exc:
        return (x0, y0, None, 0, f"{type(exc).__name__}: {exc}")
    return (x0, y0, inner_bands, zero_count, None)
def _crop_inner_bands(
    ext_bands: List[np.ndarray],
    row_offset: int, col_offset: int,
    inner_h: int, inner_w: int,
) -> List[np.ndarray]:
    """Slice the inner (non-halo) window out of every extended band array."""
    rows = slice(row_offset, row_offset + inner_h)
    cols = slice(col_offset, col_offset + inner_w)
    return [band_data[rows, cols] for band_data in ext_bands]


def _process_one_block(
    dataset: "gdal.Dataset",
    x0: int, y0: int,
    ey0: int, ex0: int, ey1: int, ex1: int,
    row_offset: int, col_offset: int,
    inner_h: int, inner_w: int,
    mask_segment_ext: Optional[np.ndarray],
    method: str,
) -> Tuple[List[np.ndarray], int]:
    """Interpolate the all-zero pixels of one extended block.

    Shared by the serial and the parallel code paths; the dataset is always
    passed in explicitly.

    A pixel is a target when every band is exactly 0 at that position and,
    if a mask is given, the mask is > 0 there. All remaining pixels of the
    extended block are used as known samples.

    Args:
        dataset: Open source dataset; must provide ``RasterCount`` and
            ``GetRasterBand(i).ReadAsArray(xoff, yoff, xsize, ysize)``.
        x0, y0: Top-left corner of the inner block in image coordinates
            (used here only for the warning message).
        ey0, ex0, ey1, ex1: Bounds of the extended block (inner + halo),
            end-exclusive.
        row_offset, col_offset: Position of the inner block inside the
            extended block.
        inner_h, inner_w: Size of the inner block.
        mask_segment_ext: Mask window matching the extended block, or
            ``None`` for no mask.
        method: Interpolation method name passed to
            ``_interpolate_single_band``.

    Returns:
        ``(inner_bands, zero_count)``: one float32 ``(inner_h, inner_w)``
        array per band (views into the extended arrays), and the number of
        target pixels found in the extended block, halo included.

    Raises:
        RuntimeError: If a band window cannot be read.
    """
    n_bands = dataset.RasterCount
    ext_h, ext_w = ey1 - ey0, ex1 - ex0

    ext_bands: List[np.ndarray] = []
    for band_index in range(1, n_bands + 1):
        raw = dataset.GetRasterBand(band_index).ReadAsArray(
            ex0, ey0, ext_w, ext_h
        )
        if raw is None:
            # Fail with a clear message rather than an AttributeError on None.
            raise RuntimeError(
                f"Failed to read band {band_index} window "
                f"(x={ex0}, y={ey0}, w={ext_w}, h={ext_h})"
            )
        ext_bands.append(raw.astype(np.float32))

    all_zero_ext = np.ones((ext_h, ext_w), dtype=bool)
    for band_data in ext_bands:
        all_zero_ext &= (band_data == 0)
    if mask_segment_ext is not None:
        all_zero_ext &= (mask_segment_ext > 0)

    zero_count = int(np.sum(all_zero_ext))
    if zero_count == 0:
        return (
            _crop_inner_bands(
                ext_bands, row_offset, col_offset, inner_h, inner_w
            ),
            0,
        )

    if zero_count == all_zero_ext.size:
        # Every pixel is a target, so there is nothing to interpolate from;
        # the block is returned unchanged.
        print(
            f" [warn] 块 (y={y0}-{y0 + inner_h}, x={x0}-{x0 + inner_w}) "
            f"无有效像素可作插值上下文,已跳过"
        )
        return (
            _crop_inner_bands(
                ext_bands, row_offset, col_offset, inner_h, inner_w
            ),
            zero_count,
        )

    # Coordinates are (x, y) pairs local to the extended block.
    zero_y, zero_x = np.where(all_zero_ext)
    zero_coords = np.column_stack([zero_x, zero_y])
    # NOTE(review): with a mask, all-zero pixels outside the mask end up in
    # valid_mask and are used as samples with value 0 — confirm intended.
    valid_mask = ~all_zero_ext
    valid_y, valid_x = np.where(valid_mask)
    valid_coords = np.column_stack([valid_x, valid_y])

    # valid_mask is shared by all bands and is known to be non-empty here,
    # so every band has sample values.
    for ext_band in ext_bands:
        ext_band[zero_y, zero_x] = _interpolate_single_band(
            zero_coords, valid_coords, ext_band[valid_mask], method
        )

    return (
        _crop_inner_bands(ext_bands, row_offset, col_offset, inner_h, inner_w),
        zero_count,
    )
def interpolate_zero_pixels_batch(
img_path: str,
interpolation_method: str = 'nearest',
output_path: Optional[str] = None,
water_mask: Optional[Union[str, np.ndarray]] = None,
deglint_dir: Optional[str] = None,
callback_progress: Optional[callable] = None
callback_progress: Optional[callable] = None,
block_size: int = 1024,
halo_size: int = 64,
n_workers: Optional[int] = None,
use_multiprocessing: bool = True,
) -> Tuple[str, Optional[np.ndarray]]:
"""
对影像中所有波段都为0的像素点进行插值完整流程含文件I/O
对影像中所有波段都为0的像素点进行插值完整流程含文件I/O
采用 **分块 IO + 多进程并行** 策略:
1. 影像按 ``block_size`` × ``block_size`` 分块,每块边界外扩展
``halo_size`` 像素作为插值上下文,避免块边缘插值退化
2. 多进程并行(默认 ``ProcessPoolExecutor``worker 数 = CPU 核心数)
并发处理所有块GDAL Dataset 不能跨进程传递,所以每个 worker
在 ``initializer`` 阶段独立打开源文件一次并缓存
3. 主进程按块序接收处理结果并统一写入输出文件,避免写锁竞争
4. 该方案可彻底避免一次性读取 50 波段整景影像时的 OOM 隐患
50 波段 × 4000×4000 × float32 ≈ 3GB 的 np.dstack
Args:
img_path: 输入影像文件路径
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging'
output_path: 输出文件路径如果为None自动生成
water_mask: 水域掩膜(文件路径或数组
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline',
'kriging' 及其中文别名('邻近'/'最邻近'/'线性'/'双线性'/'样条'/'克里金'
output_path: 输出文件路径(如果为 None 且 deglint_dir 提供,自动生成
water_mask: 水域掩膜(文件路径或数组),形状须与影像高宽一致
deglint_dir: 去耀斑目录(用于生成默认输出路径)
callback_progress: 进度回调函数
callback_progress: 进度回调函数,签名 ``callback(msg: str)``
block_size: 分块大小(像素),默认 1024内存充足可调 2048/4096
halo_size: 上下文 halo 宽度(像素),默认 64
n_workers: 并行 worker 进程数None = ``multiprocessing.cpu_count()``
传 1 等价于串行模式
use_multiprocessing: 是否启用多进程False 时强制串行
Returns:
(output_path, interpolated_image_stack) 元组
``(output_path, None)`` 元组。第二个值固定为 ``None``(与原版语义保留
兼容;返回完整内存堆叠会重新引入 OOM 风险,故不再提供)。
"""
if not SCIPY_AVAILABLE:
raise ImportError("scipy未安装无法进行0值像素插值")
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
# 确定输出路径
if output_path is None and deglint_dir is not None:
output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
method = _normalize_interpolation_method(interpolation_method)
# 检查文件是否已存在
if output_path and Path(output_path).exists():
if output_path is None and deglint_dir is not None:
output_path = str(Path(deglint_dir) / f"interpolated_{method}.bsq")
if output_path is None:
raise ValueError("output_path 和 deglint_dir 至少需要指定一个")
if Path(output_path).exists():
return output_path, None
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
@ -227,94 +477,125 @@ def interpolate_zero_pixels_batch(
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
# 读取所有波段数据
all_bands = []
for band_idx in range(1, n_bands + 1):
band = dataset.GetRasterBand(band_idx)
band_data = band.ReadAsArray().astype(np.float32)
all_bands.append(band_data)
image_stack = np.dstack(all_bands)
# 读取水域掩膜
mask_array = None
if water_mask is not None:
if isinstance(water_mask, str):
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_dataset:
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
elif isinstance(water_mask, np.ndarray):
mask_array = water_mask
# 找出所有波段都为0的像素点
all_bands_zero = np.all(image_stack == 0, axis=2)
if mask_array is not None:
all_bands_zero = all_bands_zero & (mask_array > 0)
zero_pixel_count = np.sum(all_bands_zero)
if zero_pixel_count == 0:
# 无需插值,直接保存
if output_path:
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
for i, band_data in enumerate(all_bands):
out_band = out_dataset.GetRasterBand(i + 1)
out_band.WriteArray(band_data)
out_band.FlushCache()
out_dataset = None
return output_path, image_stack
# 获取坐标
zero_y, zero_x = np.where(all_bands_zero)
zero_coords = np.column_stack([zero_x, zero_y])
valid_mask = ~all_bands_zero
valid_y, valid_x = np.where(valid_mask)
valid_coords = np.column_stack([valid_x, valid_y])
if len(valid_coords) == 0:
raise ValueError("没有有效像素可用于插值")
# 逐波段插值
interpolated_bands = []
for band_idx in range(n_bands):
if callback_progress:
callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
band_data = all_bands[band_idx].copy()
valid_values_band = band_data[valid_mask]
if len(valid_values_band) == 0:
interpolated_bands.append(band_data)
continue
band_result = _interpolate_single_band(
zero_coords, valid_coords, valid_values_band, interpolation_method
if width <= 0 or height <= 0 or n_bands <= 0:
raise ValueError(
f"影像尺寸异常: width={width}, height={height}, n_bands={n_bands}"
)
band_data[all_bands_zero] = band_result
interpolated_bands.append(band_data)
# 保存结果
if output_path:
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
for i, band_data in enumerate(interpolated_bands):
out_band = out_dataset.GetRasterBand(i + 1)
out_band.WriteArray(band_data)
out_band.FlushCache()
mask_array = _read_water_mask_to_array(water_mask, height, width)
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
if driver is None:
raise RuntimeError("未找到可用的栅格驱动ENVI / GTiff 都不存在)")
out_dataset = driver.Create(
output_path, width, height, n_bands, gdal.GDT_Float32
)
if out_dataset is None:
raise RuntimeError(f"无法创建输出文件: {output_path}")
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
try:
if not use_multiprocessing:
effective_workers = 1
elif n_workers is not None and n_workers >= 1:
effective_workers = int(n_workers)
else:
try:
cpu_count = multiprocessing.cpu_count() or 1
except (NotImplementedError, OSError):
cpu_count = 1
effective_workers = max(1, cpu_count)
n_blocks_y = (height + block_size - 1) // block_size
n_blocks_x = (width + block_size - 1) // block_size
total_blocks = n_blocks_y * n_blocks_x
tasks = []
for by in range(n_blocks_y):
y0 = by * block_size
y1 = min(y0 + block_size, height)
inner_h = y1 - y0
ey0 = max(0, y0 - halo_size)
ey1 = min(height, y1 + halo_size)
for bx in range(n_blocks_x):
x0 = bx * block_size
x1 = min(x0 + block_size, width)
inner_w = x1 - x0
ex0 = max(0, x0 - halo_size)
ex1 = min(width, x1 + halo_size)
row_offset = y0 - ey0
col_offset = x0 - ex0
mask_segment_ext = None
if mask_array is not None:
mask_segment_ext = mask_array[ey0:ey1, ex0:ex1]
tasks.append((
x0, y0, ey0, ex0, ey1, ex1,
row_offset, col_offset, inner_h, inner_w,
mask_segment_ext, method,
))
if callback_progress:
callback_progress(
f"分块插值开始: 共 {total_blocks}"
f"(block_size={block_size}, halo={halo_size}, method={method}, "
f"workers={effective_workers})"
)
total_zero_pixels = 0
if effective_workers <= 1:
for block_idx, task in enumerate(tasks, 1):
x0_t, y0_t = task[0], task[1]
if callback_progress:
callback_progress(
f"{block_idx}/{total_blocks} "
f"y=[{y0_t},{y0_t + task[8]}) x=[{x0_t},{x0_t + task[9]})"
)
inner_bands, zero_count = _process_one_block(
dataset, *task
)
for b_idx, band_data in enumerate(inner_bands):
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
band_data, xoff=x0_t, yoff=y0_t
)
total_zero_pixels += zero_count
else:
with ProcessPoolExecutor(
max_workers=effective_workers,
initializer=_init_worker,
initargs=(img_path,),
) as executor:
futures = [
executor.submit(_interpolate_block_worker, task)
for task in tasks
]
for block_idx, future in enumerate(futures, 1):
x0_t, y0_t, inner_bands, zero_count, error = future.result()
if error is not None:
raise RuntimeError(
f"块 (y={y0_t}, x={x0_t}) 处理失败: {error}"
)
if inner_bands is not None:
for b_idx, band_data in enumerate(inner_bands):
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
band_data, xoff=x0_t, yoff=y0_t
)
total_zero_pixels += zero_count
if callback_progress:
callback_progress(f"已写入块 {block_idx}/{total_blocks}")
if callback_progress:
callback_progress(
f"分块插值完成: 共处理 {total_zero_pixels} 个零像素 "
f"{total_blocks} 块,方法 {method}workers={effective_workers}"
)
return output_path, None
finally:
out_dataset = None
result_stack = np.dstack(interpolated_bands)
return output_path, result_stack
finally:
dataset = None

View File

@ -899,7 +899,7 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
if __name__ == '__main__':
# 在这里直接设置参数
imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"# BIL格式影像文件路径
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度
output_path = r"E:\code\WQ\封装\test/yangdian_output.csv" # CSV格式输出文件路径
radius = 5 # 采样半径像素0表示单点采样>0表示半径内平均

View File

@ -806,8 +806,8 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
if __name__ == '__main__':
# 在这里直接设置参数
imgpath = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq" # BIL格式影像文件路径
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度
output_path = r"E:\code\WQ\封装\work_dir\5_training_spectra/yangdian_output.csv" # CSV格式输出文件路径
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度
output_path = r"E:\code\WQ\封装\work_dir\6_Spectral_Feature_Extraction/yangdian_output.csv" # CSV格式输出文件路径
radius = 5 # 采样半径像素0表示单点采样>0表示半径内平均
flare_path = r"E:\code\WQ\封装\work_dir\2_Glint_Detection\severe_glint_area.dat" # 耀斑掩膜文件路径可选None表示不使用

View File

@ -315,7 +315,7 @@ def main():
# 示例1: 使用所有回归方法分析光谱指数
print("\n1. 光谱指数与叶绿素a的回归分析:")
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\water_quality_results.csv")
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\water_quality_results.csv")
spectral_indices = ['Al10SABI','Am092Bsub']
results1 = analyzer.batch_single_variable_regression(
@ -323,7 +323,7 @@ def main():
x_columns=spectral_indices,
y_column='Chlorophyll',
methods='all',
output_file=r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\spectral_indices_regression.csv'
output_file=r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\spectral_indices_regression.csv'
)
# # 示例2: 使用特定方法分析反射率波段
@ -343,7 +343,7 @@ def main():
best_models = analyzer.get_best_models_summary()
if not best_models.empty:
print(best_models[['x_variable', 'regression_method', 'r_squared', 'equation']].to_string(index=False))
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\best_models_summary.csv', index=False)
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\best_models_summary.csv', index=False)
print("\n最佳模型汇总已保存到 'best_models_summary.csv'")
#
# def advanced_usage_example():

View File

@ -246,7 +246,7 @@ def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, outp
if __name__ == "__main__":
algorithm= "chl_a"
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\8_non_empirical_models\SS\SS_chl_a.json"
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\8_non_empirical_models\SS\SS_chl_a.json"
coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv"
output_path= r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv"
wave_radius=5.0

View File

@ -98,7 +98,7 @@ PIPELINE_STEPS: List[StepSpec] = [
step_id="step4", method_name="step5_process_csv",
requires=["csv_path"], produces=["processed_csv_path"],
required_input_files=["csv_path"],
output_file="{work_dir}/4_processed_data/processed_data.csv",
output_file="{work_dir}/5_Data_Cleaning/processed_data.csv",
description="CSV 异常值清洗",
),
StepSpec(
@ -111,21 +111,21 @@ PIPELINE_STEPS: List[StepSpec] = [
},
skip_when_missing=False,
required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"],
output_file="{work_dir}/5_training_spectra/training_spectra.csv",
output_file="{work_dir}/6_Spectral_Feature_Extraction/training_spectra.csv",
description="实测样本点光谱提取",
),
StepSpec(
step_id="step7", method_name="step7_calc_indices",
requires=["training_csv_path"], produces=["indices_path", "trad_indices_dir"],
required_input_files=["training_csv_path"],
output_file="{work_dir}/6_water_quality_indices/training_spectra_indices.csv",
output_file="{work_dir}/7_Water_Quality_Indices/training_spectra_indices.csv",
description="水质参数指数计算双轨输出A轨宽表 + B轨单文件",
),
StepSpec(
step_id="step8", method_name="step8_train_ml",
requires=["training_csv_path"], produces=["models_dir"],
required_input_files=["training_csv_path"],
output_file="{work_dir}/7_Supervised_Model_Training/best_models.pkl",
output_file="{work_dir}/8_Supervised_Model_Training/best_models.pkl",
description="ML 建模GridSearchCV / AutoML",
),
StepSpec(
@ -134,7 +134,7 @@ PIPELINE_STEPS: List[StepSpec] = [
requires=["training_csv_path"], produces=["models_dir"],
parameter_map={"training_csv_path": "csv_path"},
required_input_files=["training_csv_path"],
output_file="{work_dir}/8_Regression_Modeling/non_empirical_models.pkl",
output_file="{work_dir}/8_Non_Empirical_Regression/non_empirical_models.pkl",
description="非经验统计回归",
),
StepSpec(

View File

@ -328,7 +328,7 @@ def train_with_automl(
split_method = split_methods[0]
if output_dir is None:
output_dir = "./7_Supervised_Model_Training_AutoML"
output_dir = "./8_Supervised_Model_Training_AutoML"
out_dir = Path(output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
preproc_dir = out_dir / preproc
@ -519,7 +519,7 @@ if __name__ == "__main__":
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
p.add_argument("--out", default="./7_Supervised_Model_Training_AutoML")
p.add_argument("--out", default="./8_Supervised_Model_Training_AutoML")
args = p.parse_args()
# 智能推断 feature_start_column 类型

View File

@ -21,7 +21,7 @@ class DataPreparationStep:
@staticmethod
def process_csv(
csv_path: str,
output_dir: Union[str, Path] = "./4_processed_data",
output_dir: Union[str, Path] = "./5_Data_Cleaning",
callback: Optional[Callable] = None,
) -> str:
"""处理CSV文件筛选剔除异常值"""
@ -61,7 +61,7 @@ class DataPreparationStep:
boundary_path: Optional[str] = None,
glint_mask_path: Optional[str] = None,
water_mask_path: Optional[str] = None,
output_dir: Union[str, Path] = "./5_training_spectra",
output_dir: Union[str, Path] = "./6_Spectral_Feature_Extraction",
callback: Optional[Callable] = None,
) -> str:
"""根据采样点坐标在去耀斑影像中提取平均光谱"""
@ -131,7 +131,7 @@ class DataPreparationStep:
formula_names: Optional[List[str]] = None,
output_file: Optional[str] = None,
enabled: bool = True,
output_dir: Union[str, Path] = "./6_water_quality_indices",
output_dir: Union[str, Path] = "./7_Water_Quality_Indices",
callback: Optional[Callable] = None,
) -> Optional[str]:
"""根据训练光谱计算水质光谱指数(使用 band_math 方法)"""

View File

@ -135,7 +135,7 @@ class ModelingStep:
split_methods: Optional[List[str]] = None,
cv_folds: int = 5,
training_csv_path: Optional[str] = None,
output_dir: Union[str, Path] = "./7_Supervised_Model_Training",
output_dir: Union[str, Path] = "./8_Supervised_Model_Training",
callback: Optional[Callable] = None,
_report_generator=None,
) -> str:
@ -251,7 +251,7 @@ class ModelingStep:
if output_dir is not None:
non_empirical_dir = Path(output_dir)
else:
non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
non_empirical_dir = Path.cwd() / "8_Non_Empirical_Regression"
non_empirical_dir.mkdir(parents=True, exist_ok=True)
if preprocessing_methods is None:
@ -430,7 +430,7 @@ def _apply_preprocessing_internal(
save_path = None
if preprocess_method == "SS":
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training"
models_dir = output_dir.parent.parent / "8_Supervised_Model_Training"
models_dir.mkdir(parents=True, exist_ok=True)
save_path = str(models_dir / "scaler_params.pkl")
print(f"SS预处理: scaler模型将保存到 {save_path}")

View File

@ -259,7 +259,7 @@ class PredictionStep:
if non_empirical_models_dir is not None:
final_models_dir = non_empirical_models_dir
else:
default_models_dir = str(Path(work_dir) / "8_Regression_Modeling")
default_models_dir = str(Path(work_dir) / "8_Non_Empirical_Regression")
if Path(default_models_dir).exists():
final_models_dir = default_models_dir
else:

View File

@ -138,11 +138,11 @@ class WaterQualityInversionPipeline:
self.water_mask_dir = self.work_dir / "1_water_mask"
self.glint_dir = self.work_dir / "2_Glint_Detection"
self.deglint_dir = self.work_dir / "3_deglint"
self.processed_data_dir = self.work_dir / "4_processed_data"
self.training_spectra_dir = self.work_dir / "5_training_spectra"
self.indices_dir = self.work_dir / "6_water_quality_indices"
self.models_dir = self.work_dir / "7_Supervised_Model_Training"
self.non_empirical_models_dir = self.work_dir / "8_Regression_Modeling"
self.processed_data_dir = self.work_dir / "5_Data_Cleaning"
self.training_spectra_dir = self.work_dir / "6_Spectral_Feature_Extraction"
self.indices_dir = self.work_dir / "7_Water_Quality_Indices"
self.models_dir = self.work_dir / "8_Supervised_Model_Training"
self.non_empirical_models_dir = self.work_dir / "8_Non_Empirical_Regression"
self.custom_regression_dir = self.work_dir / "13_Custom_Regression"
self.sampling_dir = self.work_dir / "4_sampling"
self.prediction_dir = self.work_dir / "11_12_13_predictions"
@ -764,7 +764,7 @@ class WaterQualityInversionPipeline:
if not spectrum_csv or not os.path.exists(spectrum_csv):
# 回退:扫描 work_dir 下 step5 的产物目录,找第一个 .csv
fallback_candidates = []
step5_dir = os.path.join(self.work_dir, "5_Training_Spectra")
step5_dir = os.path.join(self.work_dir, "6_Spectral_Feature_Extraction")
if os.path.isdir(step5_dir):
for f in sorted(os.listdir(step5_dir)):
if f.lower().endswith('.csv'):
@ -2023,10 +2023,10 @@ class WaterQualityInversionPipeline:
# 应用预处理 - 使用spectral_Preprocessing模块
from src.preprocessing.spectral_Preprocessing import Preprocessing
# 为SS预处理提供scaler保存路径保存在工作目录的7_Supervised_Model_Training中
# 为SS预处理提供scaler保存路径保存在工作目录的8_Supervised_Model_Training中
save_path = None
if preprocess_method == 'SS':
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training" # 向上两级到工作目录
models_dir = output_dir.parent.parent / "8_Supervised_Model_Training" # 向上两级到工作目录
models_dir.mkdir(parents=True, exist_ok=True)
save_path = str(models_dir / "scaler_params.pkl")
print(f"SS预处理: scaler模型将保存到 {save_path}")