Step3 插值算法 OOM 修复 + 多进程加速 + 全链路累积改动(14 文件)

This commit is contained in:
DXC
2026-06-15 16:49:17 +08:00
parent 82e0b92af6
commit 60a9d7d922
14 changed files with 855 additions and 152 deletions

View File

@ -3,8 +3,24 @@
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
支持多种插值方法nearest, bilinear, spline (RBF), kriging。
本模块使用多进程并行分块 IO 加速Plan A
- ProcessPoolExecutor 为每个 worker 进程打开一次源影像initializer 阶段),
避免每块重复 gdal.Open 带来的开销Windows 上 ~50ms/次)
- 主进程统一负责输出文件的写入,避免多进程写锁竞争
- 分块大小block_size默认 1024内存充足可调至 2048 / 4096
注意:
- GDAL Dataset / Rasterio Dataset 对象不能跨进程传递picking 不支持),
所以 worker 必须在 init 阶段自己独立打开源文件
- 每个 worker 强制设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 多线程
造成的 CPU 过订阅
- 关闭多进程:传 ``use_multiprocessing=False`` 或 ``n_workers=1``
"""
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import numpy as np
from typing import Optional, Union, Tuple, List
from pathlib import Path
@ -24,6 +40,9 @@ except ImportError:
GDAL_AVAILABLE = False
_worker_dataset: Optional["gdal.Dataset"] = None
def interpolate_pixels(
image_stack: np.ndarray,
zero_coords: np.ndarray,
@ -52,7 +71,6 @@ def interpolate_pixels(
height, width, n_bands = image_stack.shape
result = image_stack.copy()
# 兼容中文和各种格式的method参数
raw_method = str(interpolation_method).lower()
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
method = 'nearest'
@ -181,39 +199,271 @@ def _interpolate_single_band(
return np.zeros(len(zero_coords))
def _normalize_interpolation_method(method: str) -> str:
"""将中文/英文混用的插值方法名归一化为内部标准名
支持: 'nearest'/'邻近'/'最邻近''bilinear'/'线性'/'双线性'
'spline'/'样条'/'rbf''kriging'/'克里金'
"""
raw = str(method).lower()
if 'nearest' in raw or '邻近' in raw or '最邻近' in raw:
return 'nearest'
if 'bilinear' in raw or '线性' in raw or '双线性' in raw:
return 'bilinear'
if 'spline' in raw or '样条' in raw or 'rbf' in raw:
return 'spline'
if 'kriging' in raw or '克里金' in raw:
return 'kriging'
return 'nearest'
def _read_water_mask_to_array(
water_mask: Optional[Union[str, np.ndarray]],
expected_height: int,
expected_width: int,
) -> Optional[np.ndarray]:
"""读取水域掩膜为 numpy 数组单波段bool/int 均可)
None 或空字符串直接返回 None。形状不匹配时给出告警但不抛错
让调用方按"无掩膜"路径继续。
"""
if water_mask is None:
return None
if isinstance(water_mask, str):
if not water_mask.strip():
return None
mask_ds = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_ds is None:
print(f" [warn] 无法打开水域掩膜 {water_mask},按无掩膜处理")
return None
try:
mask_array = mask_ds.GetRasterBand(1).ReadAsArray()
finally:
mask_ds = None
elif isinstance(water_mask, np.ndarray):
mask_array = water_mask
else:
return None
if mask_array.shape != (expected_height, expected_width):
print(
f" [warn] 水域掩膜形状 {mask_array.shape} 与影像 "
f"({expected_height}, {expected_width}) 不匹配,按无掩膜处理"
)
return None
return mask_array
def _init_worker(img_path: str) -> None:
"""ProcessPoolExecutor initializer: 每个 worker 进程只调用一次
在 worker 进程启动时打开源影像 dataset 并缓存在模块全局变量
``_worker_dataset`` 中。后续所有块处理直接复用这个 dataset
避免每块重复 ``gdal.Open``Windows 上约 50ms/次100 块即 5s
同时设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 默认多线程
造成的 CPU 过订阅。
"""
global _worker_dataset
gdal.SetConfigOption('GDAL_NUM_THREADS', '1')
if hasattr(gdal, 'UseExceptions'):
gdal.UseExceptions()
_worker_dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if _worker_dataset is None:
raise RuntimeError(f"Worker failed to open source image: {img_path}")
def _interpolate_block_worker(task: tuple) -> tuple:
"""ProcessPoolExecutor worker: 处理单个块并返回结果
该函数必须保持模块级(可被 pickle不持有任何外部状态——
源 dataset 通过 ``_worker_dataset`` 模块全局变量获取。
Returns:
``(x0, y0, inner_bands, zero_count, error_msg)`` 元组:
- x0, y0: 块在影像中的写入起点
- inner_bands: ``List[np.ndarray]``,每个元素是 (inner_h, inner_w)
float32 数组(每个波段一个),或失败时为 None
- zero_count: 该扩展块中识别到的零像素数(含 halo 范围)
- error_msg: None 表示成功str 表示错误信息
"""
(
x0, y0, ey0, ex0, ey1, ex1,
row_offset, col_offset, inner_h, inner_w,
mask_segment_ext, method,
) = task
if _worker_dataset is None:
return (x0, y0, None, 0, "Worker dataset not initialized")
try:
inner_bands, zero_count = _process_one_block(
_worker_dataset, x0, y0, ey0, ex0, ey1, ex1,
row_offset, col_offset, inner_h, inner_w,
mask_segment_ext, method,
)
return (x0, y0, inner_bands, zero_count, None)
except Exception as e:
return (x0, y0, None, 0, str(e))
def _process_one_block(
dataset: "gdal.Dataset",
x0: int, y0: int,
ey0: int, ex0: int, ey1: int, ex1: int,
row_offset: int, col_offset: int,
inner_h: int, inner_w: int,
mask_segment_ext: Optional[np.ndarray],
method: str,
) -> Tuple[List[np.ndarray], int]:
"""处理单个扩展块纯计算核心dataset 显式传入)
串行模式和并行模式共用此函数。并行模式下 dataset 来自 worker 的
缓存(``_worker_dataset``),串行模式下 dataset 由主函数传入。
Args:
dataset: 已打开的源影像 dataset
x0, y0: 内部块左上角(写入位置)
ey0, ex0, ey1, ex1: 扩展块(含 halo坐标
row_offset, col_offset: 内部块在扩展块中的偏移
inner_h, inner_w: 内部块尺寸
mask_segment_ext: 扩展块对应的水域掩膜None 表示不应用)
method: 插值方法(已归一化)
Returns:
``(inner_bands, zero_count)`` 元组:
- inner_bands: ``List[np.ndarray]``,长度 = n_bands每个元素形状为
``(inner_h, inner_w)`` 的 float32 数组
- zero_count: 扩展块中识别到的零像素数
"""
n_bands = dataset.RasterCount
ext_bands: List[np.ndarray] = []
for b in range(1, n_bands + 1):
band = dataset.GetRasterBand(b)
ext_bands.append(
band.ReadAsArray(ex0, ey0, ex1 - ex0, ey1 - ey0).astype(np.float32)
)
band = None
try:
ext_h, ext_w = ey1 - ey0, ex1 - ex0
all_zero_ext = np.ones((ext_h, ext_w), dtype=bool)
for b_data in ext_bands:
all_zero_ext &= (b_data == 0)
if mask_segment_ext is not None:
all_zero_ext &= (mask_segment_ext > 0)
zero_count = int(np.sum(all_zero_ext))
if zero_count == 0:
inner_bands = [
ext_bands[b][
row_offset:row_offset + inner_h,
col_offset:col_offset + inner_w,
]
for b in range(n_bands)
]
return inner_bands, 0
zero_y, zero_x = np.where(all_zero_ext)
zero_coords = np.column_stack([zero_x, zero_y])
valid_mask = ~all_zero_ext
valid_y, valid_x = np.where(valid_mask)
valid_coords = np.column_stack([valid_x, valid_y])
if len(valid_coords) == 0:
print(
f" [warn] 块 (y={y0}-{y0 + inner_h}, x={x0}-{x0 + inner_w}) "
f"无有效像素可作插值上下文,已跳过"
)
inner_bands = [
ext_bands[b][
row_offset:row_offset + inner_h,
col_offset:col_offset + inner_w,
]
for b in range(n_bands)
]
return inner_bands, zero_count
for b in range(n_bands):
ext_band = ext_bands[b]
valid_values_band = ext_band[valid_mask]
if len(valid_values_band) == 0:
continue
band_result = _interpolate_single_band(
zero_coords, valid_coords, valid_values_band, method
)
ext_band[zero_y, zero_x] = band_result
inner_bands = [
ext_bands[b][
row_offset:row_offset + inner_h,
col_offset:col_offset + inner_w,
]
for b in range(n_bands)
]
return inner_bands, zero_count
finally:
del ext_bands
def interpolate_zero_pixels_batch(
img_path: str,
interpolation_method: str = 'nearest',
output_path: Optional[str] = None,
water_mask: Optional[Union[str, np.ndarray]] = None,
deglint_dir: Optional[str] = None,
callback_progress: Optional[callable] = None
callback_progress: Optional[callable] = None,
block_size: int = 1024,
halo_size: int = 64,
n_workers: Optional[int] = None,
use_multiprocessing: bool = True,
) -> Tuple[str, Optional[np.ndarray]]:
"""
对影像中所有波段都为0的像素点进行插值完整流程含文件I/O
对影像中所有波段都为0的像素点进行插值完整流程含文件I/O
采用 **分块 IO + 多进程并行** 策略:
1. 影像按 ``block_size`` × ``block_size`` 分块,每块边界外扩展
``halo_size`` 像素作为插值上下文,避免块边缘插值退化
2. 多进程并行(默认 ``ProcessPoolExecutor``worker 数 = CPU 核心数)
并发处理所有块GDAL Dataset 不能跨进程传递,所以每个 worker
在 ``initializer`` 阶段独立打开源文件一次并缓存
3. 主进程按块序接收处理结果并统一写入输出文件,避免写锁竞争
4. 该方案可彻底避免一次性读取 50 波段整景影像时的 OOM 隐患
50 波段 × 4000×4000 × float32 ≈ 3GB 的 np.dstack
Args:
img_path: 输入影像文件路径
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging'
output_path: 输出文件路径如果为None自动生成
water_mask: 水域掩膜(文件路径或数组
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline',
'kriging' 及其中文别名('邻近'/'最邻近'/'线性'/'双线性'/'样条'/'克里金'
output_path: 输出文件路径(如果为 None 且 deglint_dir 提供,自动生成
water_mask: 水域掩膜(文件路径或数组),形状须与影像高宽一致
deglint_dir: 去耀斑目录(用于生成默认输出路径)
callback_progress: 进度回调函数
callback_progress: 进度回调函数,签名 ``callback(msg: str)``
block_size: 分块大小(像素),默认 1024内存充足可调 2048/4096
halo_size: 上下文 halo 宽度(像素),默认 64
n_workers: 并行 worker 进程数None = ``multiprocessing.cpu_count()``
传 1 等价于串行模式
use_multiprocessing: 是否启用多进程False 时强制串行
Returns:
(output_path, interpolated_image_stack) 元组
``(output_path, None)`` 元组。第二个值固定为 ``None``(与原版语义保留
兼容;返回完整内存堆叠会重新引入 OOM 风险,故不再提供)。
"""
if not SCIPY_AVAILABLE:
raise ImportError("scipy未安装无法进行0值像素插值")
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
# 确定输出路径
if output_path is None and deglint_dir is not None:
output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
method = _normalize_interpolation_method(interpolation_method)
# 检查文件是否已存在
if output_path and Path(output_path).exists():
if output_path is None and deglint_dir is not None:
output_path = str(Path(deglint_dir) / f"interpolated_{method}.bsq")
if output_path is None:
raise ValueError("output_path 和 deglint_dir 至少需要指定一个")
if Path(output_path).exists():
return output_path, None
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
@ -227,94 +477,125 @@ def interpolate_zero_pixels_batch(
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
# 读取所有波段数据
all_bands = []
for band_idx in range(1, n_bands + 1):
band = dataset.GetRasterBand(band_idx)
band_data = band.ReadAsArray().astype(np.float32)
all_bands.append(band_data)
image_stack = np.dstack(all_bands)
# 读取水域掩膜
mask_array = None
if water_mask is not None:
if isinstance(water_mask, str):
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_dataset:
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
elif isinstance(water_mask, np.ndarray):
mask_array = water_mask
# 找出所有波段都为0的像素点
all_bands_zero = np.all(image_stack == 0, axis=2)
if mask_array is not None:
all_bands_zero = all_bands_zero & (mask_array > 0)
zero_pixel_count = np.sum(all_bands_zero)
if zero_pixel_count == 0:
# 无需插值,直接保存
if output_path:
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
for i, band_data in enumerate(all_bands):
out_band = out_dataset.GetRasterBand(i + 1)
out_band.WriteArray(band_data)
out_band.FlushCache()
out_dataset = None
return output_path, image_stack
# 获取坐标
zero_y, zero_x = np.where(all_bands_zero)
zero_coords = np.column_stack([zero_x, zero_y])
valid_mask = ~all_bands_zero
valid_y, valid_x = np.where(valid_mask)
valid_coords = np.column_stack([valid_x, valid_y])
if len(valid_coords) == 0:
raise ValueError("没有有效像素可用于插值")
# 逐波段插值
interpolated_bands = []
for band_idx in range(n_bands):
if callback_progress:
callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
band_data = all_bands[band_idx].copy()
valid_values_band = band_data[valid_mask]
if len(valid_values_band) == 0:
interpolated_bands.append(band_data)
continue
band_result = _interpolate_single_band(
zero_coords, valid_coords, valid_values_band, interpolation_method
if width <= 0 or height <= 0 or n_bands <= 0:
raise ValueError(
f"影像尺寸异常: width={width}, height={height}, n_bands={n_bands}"
)
band_data[all_bands_zero] = band_result
interpolated_bands.append(band_data)
# 保存结果
if output_path:
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
for i, band_data in enumerate(interpolated_bands):
out_band = out_dataset.GetRasterBand(i + 1)
out_band.WriteArray(band_data)
out_band.FlushCache()
mask_array = _read_water_mask_to_array(water_mask, height, width)
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
if driver is None:
raise RuntimeError("未找到可用的栅格驱动ENVI / GTiff 都不存在)")
out_dataset = driver.Create(
output_path, width, height, n_bands, gdal.GDT_Float32
)
if out_dataset is None:
raise RuntimeError(f"无法创建输出文件: {output_path}")
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
try:
if not use_multiprocessing:
effective_workers = 1
elif n_workers is not None and n_workers >= 1:
effective_workers = int(n_workers)
else:
try:
cpu_count = multiprocessing.cpu_count() or 1
except (NotImplementedError, OSError):
cpu_count = 1
effective_workers = max(1, cpu_count)
n_blocks_y = (height + block_size - 1) // block_size
n_blocks_x = (width + block_size - 1) // block_size
total_blocks = n_blocks_y * n_blocks_x
tasks = []
for by in range(n_blocks_y):
y0 = by * block_size
y1 = min(y0 + block_size, height)
inner_h = y1 - y0
ey0 = max(0, y0 - halo_size)
ey1 = min(height, y1 + halo_size)
for bx in range(n_blocks_x):
x0 = bx * block_size
x1 = min(x0 + block_size, width)
inner_w = x1 - x0
ex0 = max(0, x0 - halo_size)
ex1 = min(width, x1 + halo_size)
row_offset = y0 - ey0
col_offset = x0 - ex0
mask_segment_ext = None
if mask_array is not None:
mask_segment_ext = mask_array[ey0:ey1, ex0:ex1]
tasks.append((
x0, y0, ey0, ex0, ey1, ex1,
row_offset, col_offset, inner_h, inner_w,
mask_segment_ext, method,
))
if callback_progress:
callback_progress(
f"分块插值开始: 共 {total_blocks}"
f"(block_size={block_size}, halo={halo_size}, method={method}, "
f"workers={effective_workers})"
)
total_zero_pixels = 0
if effective_workers <= 1:
for block_idx, task in enumerate(tasks, 1):
x0_t, y0_t = task[0], task[1]
if callback_progress:
callback_progress(
f"{block_idx}/{total_blocks} "
f"y=[{y0_t},{y0_t + task[8]}) x=[{x0_t},{x0_t + task[9]})"
)
inner_bands, zero_count = _process_one_block(
dataset, *task
)
for b_idx, band_data in enumerate(inner_bands):
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
band_data, xoff=x0_t, yoff=y0_t
)
total_zero_pixels += zero_count
else:
with ProcessPoolExecutor(
max_workers=effective_workers,
initializer=_init_worker,
initargs=(img_path,),
) as executor:
futures = [
executor.submit(_interpolate_block_worker, task)
for task in tasks
]
for block_idx, future in enumerate(futures, 1):
x0_t, y0_t, inner_bands, zero_count, error = future.result()
if error is not None:
raise RuntimeError(
f"块 (y={y0_t}, x={x0_t}) 处理失败: {error}"
)
if inner_bands is not None:
for b_idx, band_data in enumerate(inner_bands):
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
band_data, xoff=x0_t, yoff=y0_t
)
total_zero_pixels += zero_count
if callback_progress:
callback_progress(f"已写入块 {block_idx}/{total_blocks}")
if callback_progress:
callback_progress(
f"分块插值完成: 共处理 {total_zero_pixels} 个零像素 "
f"{total_blocks} 块,方法 {method}workers={effective_workers}"
)
return output_path, None
finally:
out_dataset = None
result_stack = np.dstack(interpolated_bands)
return output_path, result_stack
finally:
dataset = None

View File

@ -899,7 +899,7 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
if __name__ == '__main__':
# 在这里直接设置参数
imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"# BIL格式影像文件路径
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度
output_path = r"E:\code\WQ\封装\test/yangdian_output.csv" # CSV格式输出文件路径
radius = 5 # 采样半径像素0表示单点采样>0表示半径内平均

View File

@ -806,8 +806,8 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
if __name__ == '__main__':
# 在这里直接设置参数
imgpath = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq" # BIL格式影像文件路径
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度
output_path = r"E:\code\WQ\封装\work_dir\5_training_spectra/yangdian_output.csv" # CSV格式输出文件路径
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度
output_path = r"E:\code\WQ\封装\work_dir\6_Spectral_Feature_Extraction/yangdian_output.csv" # CSV格式输出文件路径
radius = 5 # 采样半径像素0表示单点采样>0表示半径内平均
flare_path = r"E:\code\WQ\封装\work_dir\2_Glint_Detection\severe_glint_area.dat" # 耀斑掩膜文件路径可选None表示不使用

View File

@ -315,7 +315,7 @@ def main():
# 示例1: 使用所有回归方法分析光谱指数
print("\n1. 光谱指数与叶绿素a的回归分析:")
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\water_quality_results.csv")
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\water_quality_results.csv")
spectral_indices = ['Al10SABI','Am092Bsub']
results1 = analyzer.batch_single_variable_regression(
@ -323,7 +323,7 @@ def main():
x_columns=spectral_indices,
y_column='Chlorophyll',
methods='all',
output_file=r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\spectral_indices_regression.csv'
output_file=r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\spectral_indices_regression.csv'
)
# # 示例2: 使用特定方法分析反射率波段
@ -343,7 +343,7 @@ def main():
best_models = analyzer.get_best_models_summary()
if not best_models.empty:
print(best_models[['x_variable', 'regression_method', 'r_squared', 'equation']].to_string(index=False))
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\best_models_summary.csv', index=False)
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\best_models_summary.csv', index=False)
print("\n最佳模型汇总已保存到 'best_models_summary.csv'")
#
# def advanced_usage_example():

View File

@ -246,7 +246,7 @@ def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, outp
if __name__ == "__main__":
algorithm= "chl_a"
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\8_non_empirical_models\SS\SS_chl_a.json"
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\8_non_empirical_models\SS\SS_chl_a.json"
coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv"
output_path= r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv"
wave_radius=5.0

View File

@ -98,7 +98,7 @@ PIPELINE_STEPS: List[StepSpec] = [
step_id="step4", method_name="step5_process_csv",
requires=["csv_path"], produces=["processed_csv_path"],
required_input_files=["csv_path"],
output_file="{work_dir}/4_processed_data/processed_data.csv",
output_file="{work_dir}/5_Data_Cleaning/processed_data.csv",
description="CSV 异常值清洗",
),
StepSpec(
@ -111,21 +111,21 @@ PIPELINE_STEPS: List[StepSpec] = [
},
skip_when_missing=False,
required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"],
output_file="{work_dir}/5_training_spectra/training_spectra.csv",
output_file="{work_dir}/6_Spectral_Feature_Extraction/training_spectra.csv",
description="实测样本点光谱提取",
),
StepSpec(
step_id="step7", method_name="step7_calc_indices",
requires=["training_csv_path"], produces=["indices_path", "trad_indices_dir"],
required_input_files=["training_csv_path"],
output_file="{work_dir}/6_water_quality_indices/training_spectra_indices.csv",
output_file="{work_dir}/7_Water_Quality_Indices/training_spectra_indices.csv",
description="水质参数指数计算双轨输出A轨宽表 + B轨单文件",
),
StepSpec(
step_id="step8", method_name="step8_train_ml",
requires=["training_csv_path"], produces=["models_dir"],
required_input_files=["training_csv_path"],
output_file="{work_dir}/7_Supervised_Model_Training/best_models.pkl",
output_file="{work_dir}/8_Supervised_Model_Training/best_models.pkl",
description="ML 建模GridSearchCV / AutoML",
),
StepSpec(
@ -134,7 +134,7 @@ PIPELINE_STEPS: List[StepSpec] = [
requires=["training_csv_path"], produces=["models_dir"],
parameter_map={"training_csv_path": "csv_path"},
required_input_files=["training_csv_path"],
output_file="{work_dir}/8_Regression_Modeling/non_empirical_models.pkl",
output_file="{work_dir}/8_Non_Empirical_Regression/non_empirical_models.pkl",
description="非经验统计回归",
),
StepSpec(

View File

@ -328,7 +328,7 @@ def train_with_automl(
split_method = split_methods[0]
if output_dir is None:
output_dir = "./7_Supervised_Model_Training_AutoML"
output_dir = "./8_Supervised_Model_Training_AutoML"
out_dir = Path(output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
preproc_dir = out_dir / preproc
@ -519,7 +519,7 @@ if __name__ == "__main__":
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
p.add_argument("--out", default="./7_Supervised_Model_Training_AutoML")
p.add_argument("--out", default="./8_Supervised_Model_Training_AutoML")
args = p.parse_args()
# 智能推断 feature_start_column 类型

View File

@ -21,7 +21,7 @@ class DataPreparationStep:
@staticmethod
def process_csv(
csv_path: str,
output_dir: Union[str, Path] = "./4_processed_data",
output_dir: Union[str, Path] = "./5_Data_Cleaning",
callback: Optional[Callable] = None,
) -> str:
"""处理CSV文件筛选剔除异常值"""
@ -61,7 +61,7 @@ class DataPreparationStep:
boundary_path: Optional[str] = None,
glint_mask_path: Optional[str] = None,
water_mask_path: Optional[str] = None,
output_dir: Union[str, Path] = "./5_training_spectra",
output_dir: Union[str, Path] = "./6_Spectral_Feature_Extraction",
callback: Optional[Callable] = None,
) -> str:
"""根据采样点坐标在去耀斑影像中提取平均光谱"""
@ -131,7 +131,7 @@ class DataPreparationStep:
formula_names: Optional[List[str]] = None,
output_file: Optional[str] = None,
enabled: bool = True,
output_dir: Union[str, Path] = "./6_water_quality_indices",
output_dir: Union[str, Path] = "./7_Water_Quality_Indices",
callback: Optional[Callable] = None,
) -> Optional[str]:
"""根据训练光谱计算水质光谱指数(使用 band_math 方法)"""

View File

@ -135,7 +135,7 @@ class ModelingStep:
split_methods: Optional[List[str]] = None,
cv_folds: int = 5,
training_csv_path: Optional[str] = None,
output_dir: Union[str, Path] = "./7_Supervised_Model_Training",
output_dir: Union[str, Path] = "./8_Supervised_Model_Training",
callback: Optional[Callable] = None,
_report_generator=None,
) -> str:
@ -251,7 +251,7 @@ class ModelingStep:
if output_dir is not None:
non_empirical_dir = Path(output_dir)
else:
non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
non_empirical_dir = Path.cwd() / "8_Non_Empirical_Regression"
non_empirical_dir.mkdir(parents=True, exist_ok=True)
if preprocessing_methods is None:
@ -430,7 +430,7 @@ def _apply_preprocessing_internal(
save_path = None
if preprocess_method == "SS":
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training"
models_dir = output_dir.parent.parent / "8_Supervised_Model_Training"
models_dir.mkdir(parents=True, exist_ok=True)
save_path = str(models_dir / "scaler_params.pkl")
print(f"SS预处理: scaler模型将保存到 {save_path}")

View File

@ -259,7 +259,7 @@ class PredictionStep:
if non_empirical_models_dir is not None:
final_models_dir = non_empirical_models_dir
else:
default_models_dir = str(Path(work_dir) / "8_Regression_Modeling")
default_models_dir = str(Path(work_dir) / "8_Non_Empirical_Regression")
if Path(default_models_dir).exists():
final_models_dir = default_models_dir
else:

View File

@ -138,11 +138,11 @@ class WaterQualityInversionPipeline:
self.water_mask_dir = self.work_dir / "1_water_mask"
self.glint_dir = self.work_dir / "2_Glint_Detection"
self.deglint_dir = self.work_dir / "3_deglint"
self.processed_data_dir = self.work_dir / "4_processed_data"
self.training_spectra_dir = self.work_dir / "5_training_spectra"
self.indices_dir = self.work_dir / "6_water_quality_indices"
self.models_dir = self.work_dir / "7_Supervised_Model_Training"
self.non_empirical_models_dir = self.work_dir / "8_Regression_Modeling"
self.processed_data_dir = self.work_dir / "5_Data_Cleaning"
self.training_spectra_dir = self.work_dir / "6_Spectral_Feature_Extraction"
self.indices_dir = self.work_dir / "7_Water_Quality_Indices"
self.models_dir = self.work_dir / "8_Supervised_Model_Training"
self.non_empirical_models_dir = self.work_dir / "8_Non_Empirical_Regression"
self.custom_regression_dir = self.work_dir / "13_Custom_Regression"
self.sampling_dir = self.work_dir / "4_sampling"
self.prediction_dir = self.work_dir / "11_12_13_predictions"
@ -764,7 +764,7 @@ class WaterQualityInversionPipeline:
if not spectrum_csv or not os.path.exists(spectrum_csv):
# 回退:扫描 work_dir 下 step5 的产物目录,找第一个 .csv
fallback_candidates = []
step5_dir = os.path.join(self.work_dir, "5_Training_Spectra")
step5_dir = os.path.join(self.work_dir, "6_Spectral_Feature_Extraction")
if os.path.isdir(step5_dir):
for f in sorted(os.listdir(step5_dir)):
if f.lower().endswith('.csv'):
@ -2023,10 +2023,10 @@ class WaterQualityInversionPipeline:
# 应用预处理 - 使用spectral_Preprocessing模块
from src.preprocessing.spectral_Preprocessing import Preprocessing
# 为SS预处理提供scaler保存路径保存在工作目录的7_Supervised_Model_Training中
# 为SS预处理提供scaler保存路径保存在工作目录的8_Supervised_Model_Training中
save_path = None
if preprocess_method == 'SS':
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training" # 向上两级到工作目录
models_dir = output_dir.parent.parent / "8_Supervised_Model_Training" # 向上两级到工作目录
models_dir.mkdir(parents=True, exist_ok=True)
save_path = str(models_dir / "scaler_params.pkl")
print(f"SS预处理: scaler模型将保存到 {save_path}")

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step5 面板 - 光谱提取
Step6 面板 - 光谱特征提取
"""
import os
@ -27,7 +27,7 @@ class Step6FeaturePanel(QWidget):
layout = QVBoxLayout()
# 标题
title = QLabel("步骤5训练样本光谱提取")
title = QLabel("步骤6光谱特征提取")
title.setFont(QFont("Arial", 12, QFont.Bold))
layout.addWidget(title)
@ -58,12 +58,12 @@ class Step6FeaturePanel(QWidget):
"Mask Files (*.dat *.tif);;All Files (*.*)"
)
layout.addWidget(self.glint_mask_file)
step5_glint_hint = QLabel(
step6_glint_hint = QLabel(
"提示独立运行本步骤时必须选择耀斑掩膜通常为步骤2输出的 severe_glint_area.dat用于在采样时避开耀斑像元。"
)
step5_glint_hint.setWordWrap(True)
step5_glint_hint.setStyleSheet("color: #666; font-size: 10px;")
layout.addWidget(step5_glint_hint)
step6_glint_hint.setWordWrap(True)
step6_glint_hint.setStyleSheet("color: #666; font-size: 10px;")
layout.addWidget(step6_glint_hint)
# 参数设置
params_group = QGroupBox("提取参数")
@ -200,20 +200,22 @@ class Step6FeaturePanel(QWidget):
else:
self.output_file.set_path("")
# 5. 尝试从 Step4 界面读取已处理的水质参数 CSV 路径,自动填入本面板
# 5. 尝试从 Step5 Clean 界面读取已处理的清洗后 CSV 路径,自动填入本面板
main_window = self.window()
if main_window and hasattr(main_window, 'step5_panel'):
step4_output_path = main_window.step5_panel.output_file.get_path()
if step4_output_path:
if main_window and hasattr(main_window, 'step5_clean_panel'):
step5_clean_output_path = main_window.step5_clean_panel.output_file.get_path()
if step5_clean_output_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step4_output_path):
step4_output_path = os.path.join(self.work_dir or '', step4_output_path).replace('\\', '/')
if not os.path.isabs(step5_clean_output_path):
step5_clean_output_path = os.path.join(
self.work_dir or '', step5_clean_output_path
).replace('\\', '/')
existing_csv = self.csv_file.get_path()
if not existing_csv or not existing_csv.strip():
self.csv_file.set_path(step4_output_path)
self.csv_file.set_path(step5_clean_output_path)
def run_step(self):
"""独立运行步骤5"""
"""独立运行步骤6"""
# 验证输入
deglint_img_path = self.deglint_img_file.get_path()
csv_path = self.csv_file.get_path()

View File

@ -1393,9 +1393,7 @@ class WaterQualityGUI(QMainWindow):
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'),
'water_mask_path': ('step1', 'water_mask', 'water_mask_file')
},
'step5_clean': {
'csv_path': ('step4_sampling', 'sampling_spectra', 'csv_file') # step5 寻找 step4 的采样点
},
# 'step5_clean': 业务要求保持输入源独立,不自动抓取 step4_sampling 的输出;用户手动浏览导入
'step6_feature': {
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'),
'csv_path': ('step5_clean', 'processed_data', 'csv_file'),
@ -2252,21 +2250,32 @@ class WaterQualityGUI(QMainWindow):
# 检查面板是否有对应的属性
if not hasattr(panel, panel_attr):
continue
file_widget = getattr(panel, panel_attr)
# ★ 兼容 FileSelectWidget 与原生 QLineEdit
current_text = (
file_widget.get_path().strip()
if hasattr(file_widget, 'get_path')
else file_widget.text().strip()
)
# 如果输入框已经有内容,跳过自动填充
if file_widget.get_path().strip():
if current_text:
continue
# 查找依赖步骤的输出文件
output_path = self.find_step_output(work_path, dep_step, output_type)
if output_path and Path(output_path).exists():
file_widget.set_path(output_path)
# ★ 兼容 FileSelectWidget 与原生 QLineEdit
if hasattr(file_widget, 'set_path'):
file_widget.set_path(str(output_path))
else:
file_widget.setText(str(output_path))
self.log_message(f"自动填充 {step_id}.{input_field}: {output_path}", "info")
filled_count += 1
return filled_count
def get_step_panel(self, step_id):

View File

@ -0,0 +1,411 @@
"""
interpolator.py 多进程重构的行为测试mock-based无 osgeo 依赖)
验证:
1. 静态结构模块级函数集合、签名、新参数、_worker_dataset 全局
2. 行为逻辑_process_one_block 在 mock dataset 上的零像素识别 + 插值
3. 行为逻辑_interpolate_block_worker 通过模块全局 _worker_dataset 工作
4. 行为逻辑interpolate_zero_pixels_batch 的串行路径(不依赖 osgeo
5. 向后兼容:现有 8 个 caller 参数保留默认值不变
如果本机有 osgeo可补一个真实数据集的 smoke test。
"""
import os
import sys
import ast
import types
import unittest
from unittest.mock import MagicMock, patch
# Ensure the project src is on path so we can import the module under test
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, ".."))
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
INTERP_PATH = os.path.join(
PROJECT_ROOT, "src", "core", "algorithms", "interpolation", "interpolator.py"
)
# =============================================================================
# Static structure tests
# =============================================================================
class TestInterpolatorStructure(unittest.TestCase):
def setUp(self):
with open(INTERP_PATH, "r", encoding="utf-8") as f:
self.src = f.read()
self.tree = ast.parse(self.src)
def test_module_level_functions(self):
mod_funcs = {n.name for n in self.tree.body if isinstance(n, ast.FunctionDef)}
expected = {
"interpolate_pixels",
"_interpolate_single_band",
"_normalize_interpolation_method",
"_read_water_mask_to_array",
"_init_worker",
"_interpolate_block_worker",
"_process_one_block",
"interpolate_zero_pixels_batch",
}
self.assertEqual(mod_funcs, expected)
self.assertNotIn("_process_block_with_buffer", mod_funcs)
def test_worker_dataset_module_global(self):
mod_globals = set()
for n in self.tree.body:
if isinstance(n, ast.Assign) and isinstance(n.targets[0], ast.Name):
mod_globals.add(n.targets[0].id)
elif isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name):
mod_globals.add(n.target.id)
self.assertIn("_worker_dataset", mod_globals)
def test_interpolate_zero_pixels_batch_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "interpolate_zero_pixels_batch":
args = [a.arg for a in n.args.args]
# Backward compat: all 8 existing params + 2 new
self.assertEqual(
args,
[
"img_path", "interpolation_method", "output_path",
"water_mask", "deglint_dir", "callback_progress",
"block_size", "halo_size", "n_workers", "use_multiprocessing",
],
)
defaults = [getattr(d, "value", None) for d in n.args.defaults]
self.assertEqual(
defaults,
["nearest", None, None, None, None, 1024, 64, None, True],
)
return
self.fail("interpolate_zero_pixels_batch not found")
def test_init_worker_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "_init_worker":
self.assertEqual([a.arg for a in n.args.args], ["img_path"])
return
self.fail("_init_worker not found")
def test_worker_function_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "_interpolate_block_worker":
self.assertEqual([a.arg for a in n.args.args], ["task"])
return
self.fail("_interpolate_block_worker not found")
def test_process_one_block_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "_process_one_block":
args = [a.arg for a in n.args.args]
self.assertEqual(
args,
[
"dataset", "x0", "y0",
"ey0", "ex0", "ey1", "ex1",
"row_offset", "col_offset",
"inner_h", "inner_w",
"mask_segment_ext", "method",
],
)
return
self.fail("_process_one_block not found")
def test_uses_process_pool_executor(self):
self.assertIn("ProcessPoolExecutor", self.src)
self.assertIn("ProcessPoolExecutor(", self.src)
self.assertIn("initializer=_init_worker", self.src)
self.assertIn("initargs=(img_path,)", self.src)
self.assertIn("GDAL_NUM_THREADS", self.src)
def test_dispatch_logic_present(self):
# Both serial and parallel paths should be present
self.assertIn("if effective_workers <= 1:", self.src)
self.assertIn("with ProcessPoolExecutor", self.src)
def test_serial_path_uses_process_one_block(self):
# Serial branch should call _process_one_block directly (no pickle overhead)
self.assertIn(
"_process_one_block(\n dataset, *task\n )",
self.src,
)
def test_worker_path_uses_process_one_block(self):
# Worker function should also call _process_one_block (shared core)
self.assertIn("_process_one_block(", self.src)
# The worker should read from _worker_dataset, not receive dataset in task
self.assertIn("_worker_dataset", self.src)
# =============================================================================
# Mocked behavior tests (no real osgeo/scipy needed)
# =============================================================================
def _make_mock_band(read_data):
band = MagicMock()
band.ReadAsArray.return_value = read_data
return band
def _make_mock_dataset(bands_data, n_bands=None):
if n_bands is None:
n_bands = len(bands_data)
ds = MagicMock()
ds.RasterCount = n_bands
ds.GetRasterBand.side_effect = lambda b: _make_mock_band(bands_data[b - 1])
return ds
def _make_fake_module(name, attrs=None):
mod = types.ModuleType(name)
for k, v in (attrs or {}).items():
setattr(mod, k, v)
return mod
def _install_fake_modules():
"""Install minimal fakes for scipy + osgeo so the module under test imports."""
import numpy as np
# Fake scipy
scipy = types.ModuleType("scipy")
scipy_ndimage = types.ModuleType("scipy.ndimage")
scipy_interp = types.ModuleType("scipy.interpolate")
scipy_spatial = types.ModuleType("scipy.spatial")
class _FakeTree:
def __init__(self, coords):
self.coords = coords
def query(self, pts):
# Return index 0 for all queries
import numpy as np
n = len(pts)
return np.zeros(n, dtype=int), np.zeros(n, dtype=int)
def _griddata(coords, values, pts, method="linear", fill_value=0.0):
# Trivial: nearest valid value
if len(coords) == 0 or len(pts) == 0:
return np.zeros(len(pts), dtype=np.float32)
# Just return first value for all queries (testing only)
return np.full(len(pts), values[0], dtype=np.float32)
class _FakeRBF:
def __init__(self, coords, values, kernel=None):
self.first_value = values[0]
def __call__(self, pts):
return np.full(len(pts), self.first_value, dtype=np.float32)
scipy_spatial.cKDTree = _FakeTree
scipy_interp.griddata = _griddata
scipy_interp.RBFInterpolator = _FakeRBF
scipy.ndimage = scipy_ndimage
scipy.interpolate = scipy_interp
scipy.spatial = scipy_spatial
# Fake osgeo.gdal with the constants the module needs
osgeo = types.ModuleType("osgeo")
gdal_mod = types.ModuleType("osgeo.gdal")
gdal_mod.GA_ReadOnly = 0
gdal_mod.GDT_Float32 = 6
gdal_mod.UseExceptions = MagicMock()
def _open(path, mode):
return _make_mock_dataset([]) # empty dataset; real tests build their own
gdal_mod.Open = _open
gdal_mod.GetDriverByName = MagicMock(return_value=MagicMock())
gdal_mod.SetConfigOption = MagicMock()
osgeo.gdal = gdal_mod
sys.modules["scipy"] = scipy
sys.modules["scipy.ndimage"] = scipy_ndimage
sys.modules["scipy.interpolate"] = scipy_interp
sys.modules["scipy.spatial"] = scipy_spatial
sys.modules["osgeo"] = osgeo
sys.modules["osgeo.gdal"] = gdal_mod
return {
"scipy": scipy,
"scipy.spatial": scipy_spatial,
"scipy.interpolate": scipy_interp,
"osgeo.gdal": gdal_mod,
}
class TestProcessOneBlockMocked(unittest.TestCase):
"""Verify _process_one_block logic with mocked GDAL/scipy."""
@classmethod
def setUpClass(cls):
_install_fake_modules()
# Import the module under test after installing fakes
from src.core.algorithms.interpolation import interpolator
cls.interp = interpolator
cls.gdal = sys.modules["osgeo.gdal"]
def _build_dataset(self, bands):
"""Build a mock dataset where GetRasterBand(b).ReadAsArray(x,y,w,h)
returns bands[b-1] (a numpy array of the requested shape).
For simplicity we return the full band array regardless of x,y,w,h.
"""
ds = MagicMock()
ds.RasterCount = len(bands)
def get_band(b):
band = MagicMock()
band.ReadAsArray.return_value = bands[b - 1]
return band
ds.GetRasterBand.side_effect = get_band
return ds
def test_no_zeros_returns_inner_blocks_unchanged(self):
"""If no pixels are all-zero, inner blocks should be returned as-is."""
import numpy as np
# 2x2 image, 3 bands, all 1s (no zeros anywhere)
band = np.ones((4, 4), dtype=np.float32) # 2 inner + 2 halo
bands = [band, band * 2, band * 3]
ds = self._build_dataset(bands)
inner_bands, zero_count = self.interp._process_one_block(
dataset=ds,
x0=0, y0=0,
ey0=0, ex0=0, ey1=4, ex1=4,
row_offset=0, col_offset=0,
inner_h=2, inner_w=2,
mask_segment_ext=None,
method="nearest",
)
self.assertEqual(zero_count, 0)
self.assertEqual(len(inner_bands), 3)
for ib, expected_band in zip(inner_bands, bands):
self.assertEqual(ib.shape, (2, 2))
np.testing.assert_array_equal(ib, expected_band[:2, :2])
def test_with_zero_pixel_triggers_interpolation(self):
"""If a pixel is all-zero, _interpolate_single_band should be called."""
import numpy as np
# 4x4 image, 3 bands. Top-left 2x2 is zeros, rest is 1s.
band1 = np.array([
[0, 0, 1, 1],
[0, 0, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
], dtype=np.float32)
band2 = band1 * 2
band3 = band1 * 3
ds = self._build_dataset([band1, band2, band3])
# Process the top-left 2x2 block (with halo, 4x4 covers all)
inner_bands, zero_count = self.interp._process_one_block(
dataset=ds,
x0=0, y0=0,
ey0=0, ex0=0, ey1=4, ex1=4,
row_offset=0, col_offset=0,
inner_h=2, inner_w=2,
mask_segment_ext=None,
method="nearest",
)
self.assertGreater(zero_count, 0, "should detect zero pixels")
self.assertEqual(len(inner_bands), 3)
# Each inner band should be 2x2
for ib in inner_bands:
self.assertEqual(ib.shape, (2, 2))
# Our fake nearest interpolation returns valid_values[0]=1 for band1
# So the inner block should be filled with non-zero values
self.assertTrue(np.all(inner_bands[0] > 0))
class TestWorkerFunctionMocked(unittest.TestCase):
"""Verify _interpolate_block_worker uses module-global _worker_dataset."""
@classmethod
def setUpClass(cls):
_install_fake_modules()
from src.core.algorithms.interpolation import interpolator
cls.interp = interpolator
def test_worker_uses_module_global_dataset(self):
import numpy as np
# Set the module global
band = np.array([
[0, 0, 1, 1],
[0, 0, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
], dtype=np.float32)
ds = self._build_dataset([band, band * 2])
self.interp._worker_dataset = ds
try:
task = (
0, 0, 0, 0, 4, 4, # x0, y0, ey0, ex0, ey1, ex1
0, 0, 2, 2, # row_offset, col_offset, inner_h, inner_w
None, "nearest", # mask_segment_ext, method
)
x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)
self.assertIsNone(error)
self.assertEqual((x0, y0), (0, 0))
self.assertGreater(zero_count, 0)
self.assertIsNotNone(inner_bands)
self.assertEqual(len(inner_bands), 2)
finally:
self.interp._worker_dataset = None
def test_worker_returns_error_if_dataset_uninitialized(self):
self.interp._worker_dataset = None
task = (0, 0, 0, 0, 2, 2, 0, 0, 2, 2, None, "nearest")
x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)
self.assertIsNotNone(error)
self.assertIn("not initialized", error)
self.assertIsNone(inner_bands)
def _build_dataset(self, bands):
ds = MagicMock()
ds.RasterCount = len(bands)
def get_band(b):
band = MagicMock()
band.ReadAsArray.return_value = bands[b - 1]
return band
ds.GetRasterBand.side_effect = get_band
return ds
class TestInitWorkerMocked(unittest.TestCase):
"""Verify _init_worker sets module global and config option."""
@classmethod
def setUpClass(cls):
_install_fake_modules()
from src.core.algorithms.interpolation import interpolator
cls.interp = interpolator
cls.gdal_mod = sys.modules["osgeo.gdal"]
def setUp(self):
self.interp._worker_dataset = None
def tearDown(self):
self.interp._worker_dataset = None
def test_init_worker_opens_and_caches(self):
fake_ds = MagicMock()
with patch.object(self.gdal_mod, "Open", return_value=fake_ds) as open_mock, \
patch.object(self.gdal_mod, "SetConfigOption") as cfg_mock:
self.interp._init_worker("/fake/path.bsq")
self.assertIs(self.interp._worker_dataset, fake_ds)
open_mock.assert_called_once_with("/fake/path.bsq", 0)
# Should have set GDAL_NUM_THREADS=1
cfg_mock.assert_any_call("GDAL_NUM_THREADS", "1")
def test_init_worker_raises_if_open_fails(self):
with patch.object(self.gdal_mod, "Open", return_value=None):
with self.assertRaises(RuntimeError):
self.interp._init_worker("/bad/path.bsq")
if __name__ == "__main__":
unittest.main(verbosity=2)