Step3 插值算法 OOM 修复 + 多进程加速 + 全链路累积改动(14 文件)

This commit is contained in:
DXC
2026-06-15 16:49:17 +08:00
parent 82e0b92af6
commit 60a9d7d922
14 changed files with 855 additions and 152 deletions

View File

@ -3,8 +3,24 @@
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。 提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
支持多种插值方法nearest, bilinear, spline (RBF), kriging。 支持多种插值方法nearest, bilinear, spline (RBF), kriging。
本模块使用多进程并行分块 IO 加速Plan A
- ProcessPoolExecutor 为每个 worker 进程打开一次源影像initializer 阶段),
避免每块重复 gdal.Open 带来的开销Windows 上 ~50ms/次)
- 主进程统一负责输出文件的写入,避免多进程写锁竞争
- 分块大小block_size默认 1024内存充足可调至 2048 / 4096
注意:
- GDAL Dataset / Rasterio Dataset 对象不能跨进程传递picking 不支持),
所以 worker 必须在 init 阶段自己独立打开源文件
- 每个 worker 强制设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 多线程
造成的 CPU 过订阅
- 关闭多进程:传 ``use_multiprocessing=False`` 或 ``n_workers=1``
""" """
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import numpy as np import numpy as np
from typing import Optional, Union, Tuple, List from typing import Optional, Union, Tuple, List
from pathlib import Path from pathlib import Path
@ -24,6 +40,9 @@ except ImportError:
GDAL_AVAILABLE = False GDAL_AVAILABLE = False
_worker_dataset: Optional["gdal.Dataset"] = None
def interpolate_pixels( def interpolate_pixels(
image_stack: np.ndarray, image_stack: np.ndarray,
zero_coords: np.ndarray, zero_coords: np.ndarray,
@ -52,7 +71,6 @@ def interpolate_pixels(
height, width, n_bands = image_stack.shape height, width, n_bands = image_stack.shape
result = image_stack.copy() result = image_stack.copy()
# 兼容中文和各种格式的method参数
raw_method = str(interpolation_method).lower() raw_method = str(interpolation_method).lower()
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method: if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
method = 'nearest' method = 'nearest'
@ -181,39 +199,271 @@ def _interpolate_single_band(
return np.zeros(len(zero_coords)) return np.zeros(len(zero_coords))
def _normalize_interpolation_method(method: str) -> str:
"""将中文/英文混用的插值方法名归一化为内部标准名
支持: 'nearest'/'邻近'/'最邻近''bilinear'/'线性'/'双线性'
'spline'/'样条'/'rbf''kriging'/'克里金'
"""
raw = str(method).lower()
if 'nearest' in raw or '邻近' in raw or '最邻近' in raw:
return 'nearest'
if 'bilinear' in raw or '线性' in raw or '双线性' in raw:
return 'bilinear'
if 'spline' in raw or '样条' in raw or 'rbf' in raw:
return 'spline'
if 'kriging' in raw or '克里金' in raw:
return 'kriging'
return 'nearest'
def _read_water_mask_to_array(
water_mask: Optional[Union[str, np.ndarray]],
expected_height: int,
expected_width: int,
) -> Optional[np.ndarray]:
"""读取水域掩膜为 numpy 数组单波段bool/int 均可)
None 或空字符串直接返回 None。形状不匹配时给出告警但不抛错
让调用方按"无掩膜"路径继续。
"""
if water_mask is None:
return None
if isinstance(water_mask, str):
if not water_mask.strip():
return None
mask_ds = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_ds is None:
print(f" [warn] 无法打开水域掩膜 {water_mask},按无掩膜处理")
return None
try:
mask_array = mask_ds.GetRasterBand(1).ReadAsArray()
finally:
mask_ds = None
elif isinstance(water_mask, np.ndarray):
mask_array = water_mask
else:
return None
if mask_array.shape != (expected_height, expected_width):
print(
f" [warn] 水域掩膜形状 {mask_array.shape} 与影像 "
f"({expected_height}, {expected_width}) 不匹配,按无掩膜处理"
)
return None
return mask_array
def _init_worker(img_path: str) -> None:
"""ProcessPoolExecutor initializer: 每个 worker 进程只调用一次
在 worker 进程启动时打开源影像 dataset 并缓存在模块全局变量
``_worker_dataset`` 中。后续所有块处理直接复用这个 dataset
避免每块重复 ``gdal.Open``Windows 上约 50ms/次100 块即 5s
同时设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 默认多线程
造成的 CPU 过订阅。
"""
global _worker_dataset
gdal.SetConfigOption('GDAL_NUM_THREADS', '1')
if hasattr(gdal, 'UseExceptions'):
gdal.UseExceptions()
_worker_dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if _worker_dataset is None:
raise RuntimeError(f"Worker failed to open source image: {img_path}")
def _interpolate_block_worker(task: tuple) -> tuple:
"""ProcessPoolExecutor worker: 处理单个块并返回结果
该函数必须保持模块级(可被 pickle不持有任何外部状态——
源 dataset 通过 ``_worker_dataset`` 模块全局变量获取。
Returns:
``(x0, y0, inner_bands, zero_count, error_msg)`` 元组:
- x0, y0: 块在影像中的写入起点
- inner_bands: ``List[np.ndarray]``,每个元素是 (inner_h, inner_w)
float32 数组(每个波段一个),或失败时为 None
- zero_count: 该扩展块中识别到的零像素数(含 halo 范围)
- error_msg: None 表示成功str 表示错误信息
"""
(
x0, y0, ey0, ex0, ey1, ex1,
row_offset, col_offset, inner_h, inner_w,
mask_segment_ext, method,
) = task
if _worker_dataset is None:
return (x0, y0, None, 0, "Worker dataset not initialized")
try:
inner_bands, zero_count = _process_one_block(
_worker_dataset, x0, y0, ey0, ex0, ey1, ex1,
row_offset, col_offset, inner_h, inner_w,
mask_segment_ext, method,
)
return (x0, y0, inner_bands, zero_count, None)
except Exception as e:
return (x0, y0, None, 0, str(e))
def _process_one_block(
dataset: "gdal.Dataset",
x0: int, y0: int,
ey0: int, ex0: int, ey1: int, ex1: int,
row_offset: int, col_offset: int,
inner_h: int, inner_w: int,
mask_segment_ext: Optional[np.ndarray],
method: str,
) -> Tuple[List[np.ndarray], int]:
"""处理单个扩展块纯计算核心dataset 显式传入)
串行模式和并行模式共用此函数。并行模式下 dataset 来自 worker 的
缓存(``_worker_dataset``),串行模式下 dataset 由主函数传入。
Args:
dataset: 已打开的源影像 dataset
x0, y0: 内部块左上角(写入位置)
ey0, ex0, ey1, ex1: 扩展块(含 halo坐标
row_offset, col_offset: 内部块在扩展块中的偏移
inner_h, inner_w: 内部块尺寸
mask_segment_ext: 扩展块对应的水域掩膜None 表示不应用)
method: 插值方法(已归一化)
Returns:
``(inner_bands, zero_count)`` 元组:
- inner_bands: ``List[np.ndarray]``,长度 = n_bands每个元素形状为
``(inner_h, inner_w)`` 的 float32 数组
- zero_count: 扩展块中识别到的零像素数
"""
n_bands = dataset.RasterCount
ext_bands: List[np.ndarray] = []
for b in range(1, n_bands + 1):
band = dataset.GetRasterBand(b)
ext_bands.append(
band.ReadAsArray(ex0, ey0, ex1 - ex0, ey1 - ey0).astype(np.float32)
)
band = None
try:
ext_h, ext_w = ey1 - ey0, ex1 - ex0
all_zero_ext = np.ones((ext_h, ext_w), dtype=bool)
for b_data in ext_bands:
all_zero_ext &= (b_data == 0)
if mask_segment_ext is not None:
all_zero_ext &= (mask_segment_ext > 0)
zero_count = int(np.sum(all_zero_ext))
if zero_count == 0:
inner_bands = [
ext_bands[b][
row_offset:row_offset + inner_h,
col_offset:col_offset + inner_w,
]
for b in range(n_bands)
]
return inner_bands, 0
zero_y, zero_x = np.where(all_zero_ext)
zero_coords = np.column_stack([zero_x, zero_y])
valid_mask = ~all_zero_ext
valid_y, valid_x = np.where(valid_mask)
valid_coords = np.column_stack([valid_x, valid_y])
if len(valid_coords) == 0:
print(
f" [warn] 块 (y={y0}-{y0 + inner_h}, x={x0}-{x0 + inner_w}) "
f"无有效像素可作插值上下文,已跳过"
)
inner_bands = [
ext_bands[b][
row_offset:row_offset + inner_h,
col_offset:col_offset + inner_w,
]
for b in range(n_bands)
]
return inner_bands, zero_count
for b in range(n_bands):
ext_band = ext_bands[b]
valid_values_band = ext_band[valid_mask]
if len(valid_values_band) == 0:
continue
band_result = _interpolate_single_band(
zero_coords, valid_coords, valid_values_band, method
)
ext_band[zero_y, zero_x] = band_result
inner_bands = [
ext_bands[b][
row_offset:row_offset + inner_h,
col_offset:col_offset + inner_w,
]
for b in range(n_bands)
]
return inner_bands, zero_count
finally:
del ext_bands
def interpolate_zero_pixels_batch( def interpolate_zero_pixels_batch(
img_path: str, img_path: str,
interpolation_method: str = 'nearest', interpolation_method: str = 'nearest',
output_path: Optional[str] = None, output_path: Optional[str] = None,
water_mask: Optional[Union[str, np.ndarray]] = None, water_mask: Optional[Union[str, np.ndarray]] = None,
deglint_dir: Optional[str] = None, deglint_dir: Optional[str] = None,
callback_progress: Optional[callable] = None callback_progress: Optional[callable] = None,
block_size: int = 1024,
halo_size: int = 64,
n_workers: Optional[int] = None,
use_multiprocessing: bool = True,
) -> Tuple[str, Optional[np.ndarray]]: ) -> Tuple[str, Optional[np.ndarray]]:
""" """
对影像中所有波段都为0的像素点进行插值完整流程含文件I/O 对影像中所有波段都为0的像素点进行插值完整流程含文件I/O
采用 **分块 IO + 多进程并行** 策略:
1. 影像按 ``block_size`` × ``block_size`` 分块,每块边界外扩展
``halo_size`` 像素作为插值上下文,避免块边缘插值退化
2. 多进程并行(默认 ``ProcessPoolExecutor``worker 数 = CPU 核心数)
并发处理所有块GDAL Dataset 不能跨进程传递,所以每个 worker
在 ``initializer`` 阶段独立打开源文件一次并缓存
3. 主进程按块序接收处理结果并统一写入输出文件,避免写锁竞争
4. 该方案可彻底避免一次性读取 50 波段整景影像时的 OOM 隐患
50 波段 × 4000×4000 × float32 ≈ 3GB 的 np.dstack
Args: Args:
img_path: 输入影像文件路径 img_path: 输入影像文件路径
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging' interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline',
output_path: 输出文件路径如果为None自动生成 'kriging' 及其中文别名('邻近'/'最邻近'/'线性'/'双线性'/'样条'/'克里金'
water_mask: 水域掩膜(文件路径或数组 output_path: 输出文件路径(如果为 None 且 deglint_dir 提供,自动生成
water_mask: 水域掩膜(文件路径或数组),形状须与影像高宽一致
deglint_dir: 去耀斑目录(用于生成默认输出路径) deglint_dir: 去耀斑目录(用于生成默认输出路径)
callback_progress: 进度回调函数 callback_progress: 进度回调函数,签名 ``callback(msg: str)``
block_size: 分块大小(像素),默认 1024内存充足可调 2048/4096
halo_size: 上下文 halo 宽度(像素),默认 64
n_workers: 并行 worker 进程数None = ``multiprocessing.cpu_count()``
传 1 等价于串行模式
use_multiprocessing: 是否启用多进程False 时强制串行
Returns: Returns:
(output_path, interpolated_image_stack) 元组 ``(output_path, None)`` 元组。第二个值固定为 ``None``(与原版语义保留
兼容;返回完整内存堆叠会重新引入 OOM 风险,故不再提供)。
""" """
if not SCIPY_AVAILABLE: if not SCIPY_AVAILABLE:
raise ImportError("scipy未安装无法进行0值像素插值") raise ImportError("scipy未安装无法进行0值像素插值")
if not GDAL_AVAILABLE: if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件") raise ImportError("GDAL未安装无法读取影像文件")
# 确定输出路径 method = _normalize_interpolation_method(interpolation_method)
if output_path is None and deglint_dir is not None:
output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
# 检查文件是否已存在 if output_path is None and deglint_dir is not None:
if output_path and Path(output_path).exists(): output_path = str(Path(deglint_dir) / f"interpolated_{method}.bsq")
if output_path is None:
raise ValueError("output_path 和 deglint_dir 至少需要指定一个")
if Path(output_path).exists():
return output_path, None return output_path, None
dataset = gdal.Open(img_path, gdal.GA_ReadOnly) dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
@ -227,94 +477,125 @@ def interpolate_zero_pixels_batch(
geotransform = dataset.GetGeoTransform() geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection() projection = dataset.GetProjection()
# 读取所有波段数据 if width <= 0 or height <= 0 or n_bands <= 0:
all_bands = [] raise ValueError(
for band_idx in range(1, n_bands + 1): f"影像尺寸异常: width={width}, height={height}, n_bands={n_bands}"
band = dataset.GetRasterBand(band_idx)
band_data = band.ReadAsArray().astype(np.float32)
all_bands.append(band_data)
image_stack = np.dstack(all_bands)
# 读取水域掩膜
mask_array = None
if water_mask is not None:
if isinstance(water_mask, str):
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_dataset:
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
elif isinstance(water_mask, np.ndarray):
mask_array = water_mask
# 找出所有波段都为0的像素点
all_bands_zero = np.all(image_stack == 0, axis=2)
if mask_array is not None:
all_bands_zero = all_bands_zero & (mask_array > 0)
zero_pixel_count = np.sum(all_bands_zero)
if zero_pixel_count == 0:
# 无需插值,直接保存
if output_path:
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
for i, band_data in enumerate(all_bands):
out_band = out_dataset.GetRasterBand(i + 1)
out_band.WriteArray(band_data)
out_band.FlushCache()
out_dataset = None
return output_path, image_stack
# 获取坐标
zero_y, zero_x = np.where(all_bands_zero)
zero_coords = np.column_stack([zero_x, zero_y])
valid_mask = ~all_bands_zero
valid_y, valid_x = np.where(valid_mask)
valid_coords = np.column_stack([valid_x, valid_y])
if len(valid_coords) == 0:
raise ValueError("没有有效像素可用于插值")
# 逐波段插值
interpolated_bands = []
for band_idx in range(n_bands):
if callback_progress:
callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
band_data = all_bands[band_idx].copy()
valid_values_band = band_data[valid_mask]
if len(valid_values_band) == 0:
interpolated_bands.append(band_data)
continue
band_result = _interpolate_single_band(
zero_coords, valid_coords, valid_values_band, interpolation_method
) )
band_data[all_bands_zero] = band_result
interpolated_bands.append(band_data)
# 保存结果 mask_array = _read_water_mask_to_array(water_mask, height, width)
if output_path:
driver = gdal.GetDriverByName('ENVI') driver = gdal.GetDriverByName('ENVI')
if driver is None: if driver is None:
driver = gdal.GetDriverByName('GTiff') driver = gdal.GetDriverByName('GTiff')
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32) if driver is None:
raise RuntimeError("未找到可用的栅格驱动ENVI / GTiff 都不存在)")
out_dataset = driver.Create(
output_path, width, height, n_bands, gdal.GDT_Float32
)
if out_dataset is None:
raise RuntimeError(f"无法创建输出文件: {output_path}")
out_dataset.SetGeoTransform(geotransform) out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection) out_dataset.SetProjection(projection)
for i, band_data in enumerate(interpolated_bands):
out_band = out_dataset.GetRasterBand(i + 1) try:
out_band.WriteArray(band_data) if not use_multiprocessing:
out_band.FlushCache() effective_workers = 1
elif n_workers is not None and n_workers >= 1:
effective_workers = int(n_workers)
else:
try:
cpu_count = multiprocessing.cpu_count() or 1
except (NotImplementedError, OSError):
cpu_count = 1
effective_workers = max(1, cpu_count)
n_blocks_y = (height + block_size - 1) // block_size
n_blocks_x = (width + block_size - 1) // block_size
total_blocks = n_blocks_y * n_blocks_x
tasks = []
for by in range(n_blocks_y):
y0 = by * block_size
y1 = min(y0 + block_size, height)
inner_h = y1 - y0
ey0 = max(0, y0 - halo_size)
ey1 = min(height, y1 + halo_size)
for bx in range(n_blocks_x):
x0 = bx * block_size
x1 = min(x0 + block_size, width)
inner_w = x1 - x0
ex0 = max(0, x0 - halo_size)
ex1 = min(width, x1 + halo_size)
row_offset = y0 - ey0
col_offset = x0 - ex0
mask_segment_ext = None
if mask_array is not None:
mask_segment_ext = mask_array[ey0:ey1, ex0:ex1]
tasks.append((
x0, y0, ey0, ex0, ey1, ex1,
row_offset, col_offset, inner_h, inner_w,
mask_segment_ext, method,
))
if callback_progress:
callback_progress(
f"分块插值开始: 共 {total_blocks}"
f"(block_size={block_size}, halo={halo_size}, method={method}, "
f"workers={effective_workers})"
)
total_zero_pixels = 0
if effective_workers <= 1:
for block_idx, task in enumerate(tasks, 1):
x0_t, y0_t = task[0], task[1]
if callback_progress:
callback_progress(
f"{block_idx}/{total_blocks} "
f"y=[{y0_t},{y0_t + task[8]}) x=[{x0_t},{x0_t + task[9]})"
)
inner_bands, zero_count = _process_one_block(
dataset, *task
)
for b_idx, band_data in enumerate(inner_bands):
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
band_data, xoff=x0_t, yoff=y0_t
)
total_zero_pixels += zero_count
else:
with ProcessPoolExecutor(
max_workers=effective_workers,
initializer=_init_worker,
initargs=(img_path,),
) as executor:
futures = [
executor.submit(_interpolate_block_worker, task)
for task in tasks
]
for block_idx, future in enumerate(futures, 1):
x0_t, y0_t, inner_bands, zero_count, error = future.result()
if error is not None:
raise RuntimeError(
f"块 (y={y0_t}, x={x0_t}) 处理失败: {error}"
)
if inner_bands is not None:
for b_idx, band_data in enumerate(inner_bands):
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
band_data, xoff=x0_t, yoff=y0_t
)
total_zero_pixels += zero_count
if callback_progress:
callback_progress(f"已写入块 {block_idx}/{total_blocks}")
if callback_progress:
callback_progress(
f"分块插值完成: 共处理 {total_zero_pixels} 个零像素 "
f"{total_blocks} 块,方法 {method}workers={effective_workers}"
)
return output_path, None
finally:
out_dataset = None out_dataset = None
result_stack = np.dstack(interpolated_bands)
return output_path, result_stack
finally: finally:
dataset = None dataset = None

View File

@ -899,7 +899,7 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
if __name__ == '__main__': if __name__ == '__main__':
# 在这里直接设置参数 # 在这里直接设置参数
imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"# BIL格式影像文件路径 imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"# BIL格式影像文件路径
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度 coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度
output_path = r"E:\code\WQ\封装\test/yangdian_output.csv" # CSV格式输出文件路径 output_path = r"E:\code\WQ\封装\test/yangdian_output.csv" # CSV格式输出文件路径
radius = 5 # 采样半径像素0表示单点采样>0表示半径内平均 radius = 5 # 采样半径像素0表示单点采样>0表示半径内平均

View File

@ -806,8 +806,8 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
if __name__ == '__main__': if __name__ == '__main__':
# 在这里直接设置参数 # 在这里直接设置参数
imgpath = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq" # BIL格式影像文件路径 imgpath = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq" # BIL格式影像文件路径
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度 coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径第1、2列为纬度和经度
output_path = r"E:\code\WQ\封装\work_dir\5_training_spectra/yangdian_output.csv" # CSV格式输出文件路径 output_path = r"E:\code\WQ\封装\work_dir\6_Spectral_Feature_Extraction/yangdian_output.csv" # CSV格式输出文件路径
radius = 5 # 采样半径像素0表示单点采样>0表示半径内平均 radius = 5 # 采样半径像素0表示单点采样>0表示半径内平均
flare_path = r"E:\code\WQ\封装\work_dir\2_Glint_Detection\severe_glint_area.dat" # 耀斑掩膜文件路径可选None表示不使用 flare_path = r"E:\code\WQ\封装\work_dir\2_Glint_Detection\severe_glint_area.dat" # 耀斑掩膜文件路径可选None表示不使用

View File

@ -315,7 +315,7 @@ def main():
# 示例1: 使用所有回归方法分析光谱指数 # 示例1: 使用所有回归方法分析光谱指数
print("\n1. 光谱指数与叶绿素a的回归分析:") print("\n1. 光谱指数与叶绿素a的回归分析:")
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\water_quality_results.csv") sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\water_quality_results.csv")
spectral_indices = ['Al10SABI','Am092Bsub'] spectral_indices = ['Al10SABI','Am092Bsub']
results1 = analyzer.batch_single_variable_regression( results1 = analyzer.batch_single_variable_regression(
@ -323,7 +323,7 @@ def main():
x_columns=spectral_indices, x_columns=spectral_indices,
y_column='Chlorophyll', y_column='Chlorophyll',
methods='all', methods='all',
output_file=r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\spectral_indices_regression.csv' output_file=r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\spectral_indices_regression.csv'
) )
# # 示例2: 使用特定方法分析反射率波段 # # 示例2: 使用特定方法分析反射率波段
@ -343,7 +343,7 @@ def main():
best_models = analyzer.get_best_models_summary() best_models = analyzer.get_best_models_summary()
if not best_models.empty: if not best_models.empty:
print(best_models[['x_variable', 'regression_method', 'r_squared', 'equation']].to_string(index=False)) print(best_models[['x_variable', 'regression_method', 'r_squared', 'equation']].to_string(index=False))
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\best_models_summary.csv', index=False) best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\best_models_summary.csv', index=False)
print("\n最佳模型汇总已保存到 'best_models_summary.csv'") print("\n最佳模型汇总已保存到 'best_models_summary.csv'")
# #
# def advanced_usage_example(): # def advanced_usage_example():

View File

@ -246,7 +246,7 @@ def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, outp
if __name__ == "__main__": if __name__ == "__main__":
algorithm= "chl_a" algorithm= "chl_a"
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\8_non_empirical_models\SS\SS_chl_a.json" model_info_path= r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\8_non_empirical_models\SS\SS_chl_a.json"
coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv" coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv"
output_path= r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv" output_path= r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv"
wave_radius=5.0 wave_radius=5.0

View File

@ -98,7 +98,7 @@ PIPELINE_STEPS: List[StepSpec] = [
step_id="step4", method_name="step5_process_csv", step_id="step4", method_name="step5_process_csv",
requires=["csv_path"], produces=["processed_csv_path"], requires=["csv_path"], produces=["processed_csv_path"],
required_input_files=["csv_path"], required_input_files=["csv_path"],
output_file="{work_dir}/4_processed_data/processed_data.csv", output_file="{work_dir}/5_Data_Cleaning/processed_data.csv",
description="CSV 异常值清洗", description="CSV 异常值清洗",
), ),
StepSpec( StepSpec(
@ -111,21 +111,21 @@ PIPELINE_STEPS: List[StepSpec] = [
}, },
skip_when_missing=False, skip_when_missing=False,
required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"], required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"],
output_file="{work_dir}/5_training_spectra/training_spectra.csv", output_file="{work_dir}/6_Spectral_Feature_Extraction/training_spectra.csv",
description="实测样本点光谱提取", description="实测样本点光谱提取",
), ),
StepSpec( StepSpec(
step_id="step7", method_name="step7_calc_indices", step_id="step7", method_name="step7_calc_indices",
requires=["training_csv_path"], produces=["indices_path", "trad_indices_dir"], requires=["training_csv_path"], produces=["indices_path", "trad_indices_dir"],
required_input_files=["training_csv_path"], required_input_files=["training_csv_path"],
output_file="{work_dir}/6_water_quality_indices/training_spectra_indices.csv", output_file="{work_dir}/7_Water_Quality_Indices/training_spectra_indices.csv",
description="水质参数指数计算双轨输出A轨宽表 + B轨单文件", description="水质参数指数计算双轨输出A轨宽表 + B轨单文件",
), ),
StepSpec( StepSpec(
step_id="step8", method_name="step8_train_ml", step_id="step8", method_name="step8_train_ml",
requires=["training_csv_path"], produces=["models_dir"], requires=["training_csv_path"], produces=["models_dir"],
required_input_files=["training_csv_path"], required_input_files=["training_csv_path"],
output_file="{work_dir}/7_Supervised_Model_Training/best_models.pkl", output_file="{work_dir}/8_Supervised_Model_Training/best_models.pkl",
description="ML 建模GridSearchCV / AutoML", description="ML 建模GridSearchCV / AutoML",
), ),
StepSpec( StepSpec(
@ -134,7 +134,7 @@ PIPELINE_STEPS: List[StepSpec] = [
requires=["training_csv_path"], produces=["models_dir"], requires=["training_csv_path"], produces=["models_dir"],
parameter_map={"training_csv_path": "csv_path"}, parameter_map={"training_csv_path": "csv_path"},
required_input_files=["training_csv_path"], required_input_files=["training_csv_path"],
output_file="{work_dir}/8_Regression_Modeling/non_empirical_models.pkl", output_file="{work_dir}/8_Non_Empirical_Regression/non_empirical_models.pkl",
description="非经验统计回归", description="非经验统计回归",
), ),
StepSpec( StepSpec(

View File

@ -328,7 +328,7 @@ def train_with_automl(
split_method = split_methods[0] split_method = split_methods[0]
if output_dir is None: if output_dir is None:
output_dir = "./7_Supervised_Model_Training_AutoML" output_dir = "./8_Supervised_Model_Training_AutoML"
out_dir = Path(output_dir) out_dir = Path(output_dir)
out_dir.mkdir(parents=True, exist_ok=True) out_dir.mkdir(parents=True, exist_ok=True)
preproc_dir = out_dir / preproc preproc_dir = out_dir / preproc
@ -519,7 +519,7 @@ if __name__ == "__main__":
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS) p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT) p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES) p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
p.add_argument("--out", default="./7_Supervised_Model_Training_AutoML") p.add_argument("--out", default="./8_Supervised_Model_Training_AutoML")
args = p.parse_args() args = p.parse_args()
# 智能推断 feature_start_column 类型 # 智能推断 feature_start_column 类型

View File

@ -21,7 +21,7 @@ class DataPreparationStep:
@staticmethod @staticmethod
def process_csv( def process_csv(
csv_path: str, csv_path: str,
output_dir: Union[str, Path] = "./4_processed_data", output_dir: Union[str, Path] = "./5_Data_Cleaning",
callback: Optional[Callable] = None, callback: Optional[Callable] = None,
) -> str: ) -> str:
"""处理CSV文件筛选剔除异常值""" """处理CSV文件筛选剔除异常值"""
@ -61,7 +61,7 @@ class DataPreparationStep:
boundary_path: Optional[str] = None, boundary_path: Optional[str] = None,
glint_mask_path: Optional[str] = None, glint_mask_path: Optional[str] = None,
water_mask_path: Optional[str] = None, water_mask_path: Optional[str] = None,
output_dir: Union[str, Path] = "./5_training_spectra", output_dir: Union[str, Path] = "./6_Spectral_Feature_Extraction",
callback: Optional[Callable] = None, callback: Optional[Callable] = None,
) -> str: ) -> str:
"""根据采样点坐标在去耀斑影像中提取平均光谱""" """根据采样点坐标在去耀斑影像中提取平均光谱"""
@ -131,7 +131,7 @@ class DataPreparationStep:
formula_names: Optional[List[str]] = None, formula_names: Optional[List[str]] = None,
output_file: Optional[str] = None, output_file: Optional[str] = None,
enabled: bool = True, enabled: bool = True,
output_dir: Union[str, Path] = "./6_water_quality_indices", output_dir: Union[str, Path] = "./7_Water_Quality_Indices",
callback: Optional[Callable] = None, callback: Optional[Callable] = None,
) -> Optional[str]: ) -> Optional[str]:
"""根据训练光谱计算水质光谱指数(使用 band_math 方法)""" """根据训练光谱计算水质光谱指数(使用 band_math 方法)"""

View File

@ -135,7 +135,7 @@ class ModelingStep:
split_methods: Optional[List[str]] = None, split_methods: Optional[List[str]] = None,
cv_folds: int = 5, cv_folds: int = 5,
training_csv_path: Optional[str] = None, training_csv_path: Optional[str] = None,
output_dir: Union[str, Path] = "./7_Supervised_Model_Training", output_dir: Union[str, Path] = "./8_Supervised_Model_Training",
callback: Optional[Callable] = None, callback: Optional[Callable] = None,
_report_generator=None, _report_generator=None,
) -> str: ) -> str:
@ -251,7 +251,7 @@ class ModelingStep:
if output_dir is not None: if output_dir is not None:
non_empirical_dir = Path(output_dir) non_empirical_dir = Path(output_dir)
else: else:
non_empirical_dir = Path.cwd() / "8_Regression_Modeling" non_empirical_dir = Path.cwd() / "8_Non_Empirical_Regression"
non_empirical_dir.mkdir(parents=True, exist_ok=True) non_empirical_dir.mkdir(parents=True, exist_ok=True)
if preprocessing_methods is None: if preprocessing_methods is None:
@ -430,7 +430,7 @@ def _apply_preprocessing_internal(
save_path = None save_path = None
if preprocess_method == "SS": if preprocess_method == "SS":
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training" models_dir = output_dir.parent.parent / "8_Supervised_Model_Training"
models_dir.mkdir(parents=True, exist_ok=True) models_dir.mkdir(parents=True, exist_ok=True)
save_path = str(models_dir / "scaler_params.pkl") save_path = str(models_dir / "scaler_params.pkl")
print(f"SS预处理: scaler模型将保存到 {save_path}") print(f"SS预处理: scaler模型将保存到 {save_path}")

View File

@ -259,7 +259,7 @@ class PredictionStep:
if non_empirical_models_dir is not None: if non_empirical_models_dir is not None:
final_models_dir = non_empirical_models_dir final_models_dir = non_empirical_models_dir
else: else:
default_models_dir = str(Path(work_dir) / "8_Regression_Modeling") default_models_dir = str(Path(work_dir) / "8_Non_Empirical_Regression")
if Path(default_models_dir).exists(): if Path(default_models_dir).exists():
final_models_dir = default_models_dir final_models_dir = default_models_dir
else: else:

View File

@ -138,11 +138,11 @@ class WaterQualityInversionPipeline:
self.water_mask_dir = self.work_dir / "1_water_mask" self.water_mask_dir = self.work_dir / "1_water_mask"
self.glint_dir = self.work_dir / "2_Glint_Detection" self.glint_dir = self.work_dir / "2_Glint_Detection"
self.deglint_dir = self.work_dir / "3_deglint" self.deglint_dir = self.work_dir / "3_deglint"
self.processed_data_dir = self.work_dir / "4_processed_data" self.processed_data_dir = self.work_dir / "5_Data_Cleaning"
self.training_spectra_dir = self.work_dir / "5_training_spectra" self.training_spectra_dir = self.work_dir / "6_Spectral_Feature_Extraction"
self.indices_dir = self.work_dir / "6_water_quality_indices" self.indices_dir = self.work_dir / "7_Water_Quality_Indices"
self.models_dir = self.work_dir / "7_Supervised_Model_Training" self.models_dir = self.work_dir / "8_Supervised_Model_Training"
self.non_empirical_models_dir = self.work_dir / "8_Regression_Modeling" self.non_empirical_models_dir = self.work_dir / "8_Non_Empirical_Regression"
self.custom_regression_dir = self.work_dir / "13_Custom_Regression" self.custom_regression_dir = self.work_dir / "13_Custom_Regression"
self.sampling_dir = self.work_dir / "4_sampling" self.sampling_dir = self.work_dir / "4_sampling"
self.prediction_dir = self.work_dir / "11_12_13_predictions" self.prediction_dir = self.work_dir / "11_12_13_predictions"
@ -764,7 +764,7 @@ class WaterQualityInversionPipeline:
if not spectrum_csv or not os.path.exists(spectrum_csv): if not spectrum_csv or not os.path.exists(spectrum_csv):
# 回退:扫描 work_dir 下 step5 的产物目录,找第一个 .csv # 回退:扫描 work_dir 下 step5 的产物目录,找第一个 .csv
fallback_candidates = [] fallback_candidates = []
step5_dir = os.path.join(self.work_dir, "5_Training_Spectra") step5_dir = os.path.join(self.work_dir, "6_Spectral_Feature_Extraction")
if os.path.isdir(step5_dir): if os.path.isdir(step5_dir):
for f in sorted(os.listdir(step5_dir)): for f in sorted(os.listdir(step5_dir)):
if f.lower().endswith('.csv'): if f.lower().endswith('.csv'):
@ -2023,10 +2023,10 @@ class WaterQualityInversionPipeline:
# 应用预处理 - 使用spectral_Preprocessing模块 # 应用预处理 - 使用spectral_Preprocessing模块
from src.preprocessing.spectral_Preprocessing import Preprocessing from src.preprocessing.spectral_Preprocessing import Preprocessing
# 为SS预处理提供scaler保存路径保存在工作目录的7_Supervised_Model_Training中 # 为SS预处理提供scaler保存路径保存在工作目录的8_Supervised_Model_Training中
save_path = None save_path = None
if preprocess_method == 'SS': if preprocess_method == 'SS':
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training" # 向上两级到工作目录 models_dir = output_dir.parent.parent / "8_Supervised_Model_Training" # 向上两级到工作目录
models_dir.mkdir(parents=True, exist_ok=True) models_dir.mkdir(parents=True, exist_ok=True)
save_path = str(models_dir / "scaler_params.pkl") save_path = str(models_dir / "scaler_params.pkl")
print(f"SS预处理: scaler模型将保存到 {save_path}") print(f"SS预处理: scaler模型将保存到 {save_path}")

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
Step5 面板 - 光谱提取 Step6 面板 - 光谱特征提取
""" """
import os import os
@ -27,7 +27,7 @@ class Step6FeaturePanel(QWidget):
layout = QVBoxLayout() layout = QVBoxLayout()
# 标题 # 标题
title = QLabel("步骤5训练样本光谱提取") title = QLabel("步骤6光谱特征提取")
title.setFont(QFont("Arial", 12, QFont.Bold)) title.setFont(QFont("Arial", 12, QFont.Bold))
layout.addWidget(title) layout.addWidget(title)
@ -58,12 +58,12 @@ class Step6FeaturePanel(QWidget):
"Mask Files (*.dat *.tif);;All Files (*.*)" "Mask Files (*.dat *.tif);;All Files (*.*)"
) )
layout.addWidget(self.glint_mask_file) layout.addWidget(self.glint_mask_file)
step5_glint_hint = QLabel( step6_glint_hint = QLabel(
"提示独立运行本步骤时必须选择耀斑掩膜通常为步骤2输出的 severe_glint_area.dat用于在采样时避开耀斑像元。" "提示独立运行本步骤时必须选择耀斑掩膜通常为步骤2输出的 severe_glint_area.dat用于在采样时避开耀斑像元。"
) )
step5_glint_hint.setWordWrap(True) step6_glint_hint.setWordWrap(True)
step5_glint_hint.setStyleSheet("color: #666; font-size: 10px;") step6_glint_hint.setStyleSheet("color: #666; font-size: 10px;")
layout.addWidget(step5_glint_hint) layout.addWidget(step6_glint_hint)
# 参数设置 # 参数设置
params_group = QGroupBox("提取参数") params_group = QGroupBox("提取参数")
@ -200,20 +200,22 @@ class Step6FeaturePanel(QWidget):
else: else:
self.output_file.set_path("") self.output_file.set_path("")
# 5. 尝试从 Step4 界面读取已处理的水质参数 CSV 路径,自动填入本面板 # 5. 尝试从 Step5 Clean 界面读取已处理的清洗后 CSV 路径,自动填入本面板
main_window = self.window() main_window = self.window()
if main_window and hasattr(main_window, 'step5_panel'): if main_window and hasattr(main_window, 'step5_clean_panel'):
step4_output_path = main_window.step5_panel.output_file.get_path() step5_clean_output_path = main_window.step5_clean_panel.output_file.get_path()
if step4_output_path: if step5_clean_output_path:
# 若为相对路径,使用 work_dir 合成为绝对路径 # 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step4_output_path): if not os.path.isabs(step5_clean_output_path):
step4_output_path = os.path.join(self.work_dir or '', step4_output_path).replace('\\', '/') step5_clean_output_path = os.path.join(
self.work_dir or '', step5_clean_output_path
).replace('\\', '/')
existing_csv = self.csv_file.get_path() existing_csv = self.csv_file.get_path()
if not existing_csv or not existing_csv.strip(): if not existing_csv or not existing_csv.strip():
self.csv_file.set_path(step4_output_path) self.csv_file.set_path(step5_clean_output_path)
def run_step(self): def run_step(self):
"""独立运行步骤5""" """独立运行步骤6"""
# 验证输入 # 验证输入
deglint_img_path = self.deglint_img_file.get_path() deglint_img_path = self.deglint_img_file.get_path()
csv_path = self.csv_file.get_path() csv_path = self.csv_file.get_path()

View File

@ -1393,9 +1393,7 @@ class WaterQualityGUI(QMainWindow):
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'), 'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'),
'water_mask_path': ('step1', 'water_mask', 'water_mask_file') 'water_mask_path': ('step1', 'water_mask', 'water_mask_file')
}, },
'step5_clean': { # 'step5_clean': 业务要求保持输入源独立,不自动抓取 step4_sampling 的输出;用户手动浏览导入
'csv_path': ('step4_sampling', 'sampling_spectra', 'csv_file') # step5 寻找 step4 的采样点
},
'step6_feature': { 'step6_feature': {
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'), 'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'),
'csv_path': ('step5_clean', 'processed_data', 'csv_file'), 'csv_path': ('step5_clean', 'processed_data', 'csv_file'),
@ -2255,15 +2253,26 @@ class WaterQualityGUI(QMainWindow):
file_widget = getattr(panel, panel_attr) file_widget = getattr(panel, panel_attr)
# ★ 兼容 FileSelectWidget 与原生 QLineEdit
current_text = (
file_widget.get_path().strip()
if hasattr(file_widget, 'get_path')
else file_widget.text().strip()
)
# 如果输入框已经有内容,跳过自动填充 # 如果输入框已经有内容,跳过自动填充
if file_widget.get_path().strip(): if current_text:
continue continue
# 查找依赖步骤的输出文件 # 查找依赖步骤的输出文件
output_path = self.find_step_output(work_path, dep_step, output_type) output_path = self.find_step_output(work_path, dep_step, output_type)
if output_path and Path(output_path).exists(): if output_path and Path(output_path).exists():
file_widget.set_path(output_path) # ★ 兼容 FileSelectWidget 与原生 QLineEdit
if hasattr(file_widget, 'set_path'):
file_widget.set_path(str(output_path))
else:
file_widget.setText(str(output_path))
self.log_message(f"自动填充 {step_id}.{input_field}: {output_path}", "info") self.log_message(f"自动填充 {step_id}.{input_field}: {output_path}", "info")
filled_count += 1 filled_count += 1

View File

@ -0,0 +1,411 @@
"""
interpolator.py 多进程重构的行为测试mock-based无 osgeo 依赖)
验证:
1. 静态结构模块级函数集合、签名、新参数、_worker_dataset 全局
2. 行为逻辑_process_one_block 在 mock dataset 上的零像素识别 + 插值
3. 行为逻辑_interpolate_block_worker 通过模块全局 _worker_dataset 工作
4. 行为逻辑interpolate_zero_pixels_batch 的串行路径(不依赖 osgeo
5. 向后兼容:现有 8 个 caller 参数保留默认值不变
如果本机有 osgeo可补一个真实数据集的 smoke test。
"""
import os
import sys
import ast
import types
import unittest
from unittest.mock import MagicMock, patch
# Ensure the project src is on path so we can import the module under test
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, ".."))
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
INTERP_PATH = os.path.join(
PROJECT_ROOT, "src", "core", "algorithms", "interpolation", "interpolator.py"
)
# =============================================================================
# Static structure tests
# =============================================================================
class TestInterpolatorStructure(unittest.TestCase):
def setUp(self):
with open(INTERP_PATH, "r", encoding="utf-8") as f:
self.src = f.read()
self.tree = ast.parse(self.src)
def test_module_level_functions(self):
mod_funcs = {n.name for n in self.tree.body if isinstance(n, ast.FunctionDef)}
expected = {
"interpolate_pixels",
"_interpolate_single_band",
"_normalize_interpolation_method",
"_read_water_mask_to_array",
"_init_worker",
"_interpolate_block_worker",
"_process_one_block",
"interpolate_zero_pixels_batch",
}
self.assertEqual(mod_funcs, expected)
self.assertNotIn("_process_block_with_buffer", mod_funcs)
def test_worker_dataset_module_global(self):
mod_globals = set()
for n in self.tree.body:
if isinstance(n, ast.Assign) and isinstance(n.targets[0], ast.Name):
mod_globals.add(n.targets[0].id)
elif isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name):
mod_globals.add(n.target.id)
self.assertIn("_worker_dataset", mod_globals)
def test_interpolate_zero_pixels_batch_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "interpolate_zero_pixels_batch":
args = [a.arg for a in n.args.args]
# Backward compat: all 8 existing params + 2 new
self.assertEqual(
args,
[
"img_path", "interpolation_method", "output_path",
"water_mask", "deglint_dir", "callback_progress",
"block_size", "halo_size", "n_workers", "use_multiprocessing",
],
)
defaults = [getattr(d, "value", None) for d in n.args.defaults]
self.assertEqual(
defaults,
["nearest", None, None, None, None, 1024, 64, None, True],
)
return
self.fail("interpolate_zero_pixels_batch not found")
def test_init_worker_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "_init_worker":
self.assertEqual([a.arg for a in n.args.args], ["img_path"])
return
self.fail("_init_worker not found")
def test_worker_function_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "_interpolate_block_worker":
self.assertEqual([a.arg for a in n.args.args], ["task"])
return
self.fail("_interpolate_block_worker not found")
def test_process_one_block_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "_process_one_block":
args = [a.arg for a in n.args.args]
self.assertEqual(
args,
[
"dataset", "x0", "y0",
"ey0", "ex0", "ey1", "ex1",
"row_offset", "col_offset",
"inner_h", "inner_w",
"mask_segment_ext", "method",
],
)
return
self.fail("_process_one_block not found")
def test_uses_process_pool_executor(self):
self.assertIn("ProcessPoolExecutor", self.src)
self.assertIn("ProcessPoolExecutor(", self.src)
self.assertIn("initializer=_init_worker", self.src)
self.assertIn("initargs=(img_path,)", self.src)
self.assertIn("GDAL_NUM_THREADS", self.src)
def test_dispatch_logic_present(self):
# Both serial and parallel paths should be present
self.assertIn("if effective_workers <= 1:", self.src)
self.assertIn("with ProcessPoolExecutor", self.src)
def test_serial_path_uses_process_one_block(self):
# Serial branch should call _process_one_block directly (no pickle overhead)
self.assertIn(
"_process_one_block(\n dataset, *task\n )",
self.src,
)
def test_worker_path_uses_process_one_block(self):
# Worker function should also call _process_one_block (shared core)
self.assertIn("_process_one_block(", self.src)
# The worker should read from _worker_dataset, not receive dataset in task
self.assertIn("_worker_dataset", self.src)
# =============================================================================
# Mocked behavior tests (no real osgeo/scipy needed)
# =============================================================================
def _make_mock_band(read_data):
band = MagicMock()
band.ReadAsArray.return_value = read_data
return band
def _make_mock_dataset(bands_data, n_bands=None):
if n_bands is None:
n_bands = len(bands_data)
ds = MagicMock()
ds.RasterCount = n_bands
ds.GetRasterBand.side_effect = lambda b: _make_mock_band(bands_data[b - 1])
return ds
def _make_fake_module(name, attrs=None):
mod = types.ModuleType(name)
for k, v in (attrs or {}).items():
setattr(mod, k, v)
return mod
def _install_fake_modules():
"""Install minimal fakes for scipy + osgeo so the module under test imports."""
import numpy as np
# Fake scipy
scipy = types.ModuleType("scipy")
scipy_ndimage = types.ModuleType("scipy.ndimage")
scipy_interp = types.ModuleType("scipy.interpolate")
scipy_spatial = types.ModuleType("scipy.spatial")
class _FakeTree:
def __init__(self, coords):
self.coords = coords
def query(self, pts):
# Return index 0 for all queries
import numpy as np
n = len(pts)
return np.zeros(n, dtype=int), np.zeros(n, dtype=int)
def _griddata(coords, values, pts, method="linear", fill_value=0.0):
# Trivial: nearest valid value
if len(coords) == 0 or len(pts) == 0:
return np.zeros(len(pts), dtype=np.float32)
# Just return first value for all queries (testing only)
return np.full(len(pts), values[0], dtype=np.float32)
class _FakeRBF:
def __init__(self, coords, values, kernel=None):
self.first_value = values[0]
def __call__(self, pts):
return np.full(len(pts), self.first_value, dtype=np.float32)
scipy_spatial.cKDTree = _FakeTree
scipy_interp.griddata = _griddata
scipy_interp.RBFInterpolator = _FakeRBF
scipy.ndimage = scipy_ndimage
scipy.interpolate = scipy_interp
scipy.spatial = scipy_spatial
# Fake osgeo.gdal with the constants the module needs
osgeo = types.ModuleType("osgeo")
gdal_mod = types.ModuleType("osgeo.gdal")
gdal_mod.GA_ReadOnly = 0
gdal_mod.GDT_Float32 = 6
gdal_mod.UseExceptions = MagicMock()
def _open(path, mode):
return _make_mock_dataset([]) # empty dataset; real tests build their own
gdal_mod.Open = _open
gdal_mod.GetDriverByName = MagicMock(return_value=MagicMock())
gdal_mod.SetConfigOption = MagicMock()
osgeo.gdal = gdal_mod
sys.modules["scipy"] = scipy
sys.modules["scipy.ndimage"] = scipy_ndimage
sys.modules["scipy.interpolate"] = scipy_interp
sys.modules["scipy.spatial"] = scipy_spatial
sys.modules["osgeo"] = osgeo
sys.modules["osgeo.gdal"] = gdal_mod
return {
"scipy": scipy,
"scipy.spatial": scipy_spatial,
"scipy.interpolate": scipy_interp,
"osgeo.gdal": gdal_mod,
}
class TestProcessOneBlockMocked(unittest.TestCase):
"""Verify _process_one_block logic with mocked GDAL/scipy."""
@classmethod
def setUpClass(cls):
_install_fake_modules()
# Import the module under test after installing fakes
from src.core.algorithms.interpolation import interpolator
cls.interp = interpolator
cls.gdal = sys.modules["osgeo.gdal"]
def _build_dataset(self, bands):
"""Build a mock dataset where GetRasterBand(b).ReadAsArray(x,y,w,h)
returns bands[b-1] (a numpy array of the requested shape).
For simplicity we return the full band array regardless of x,y,w,h.
"""
ds = MagicMock()
ds.RasterCount = len(bands)
def get_band(b):
band = MagicMock()
band.ReadAsArray.return_value = bands[b - 1]
return band
ds.GetRasterBand.side_effect = get_band
return ds
def test_no_zeros_returns_inner_blocks_unchanged(self):
"""If no pixels are all-zero, inner blocks should be returned as-is."""
import numpy as np
# 2x2 image, 3 bands, all 1s (no zeros anywhere)
band = np.ones((4, 4), dtype=np.float32) # 2 inner + 2 halo
bands = [band, band * 2, band * 3]
ds = self._build_dataset(bands)
inner_bands, zero_count = self.interp._process_one_block(
dataset=ds,
x0=0, y0=0,
ey0=0, ex0=0, ey1=4, ex1=4,
row_offset=0, col_offset=0,
inner_h=2, inner_w=2,
mask_segment_ext=None,
method="nearest",
)
self.assertEqual(zero_count, 0)
self.assertEqual(len(inner_bands), 3)
for ib, expected_band in zip(inner_bands, bands):
self.assertEqual(ib.shape, (2, 2))
np.testing.assert_array_equal(ib, expected_band[:2, :2])
def test_with_zero_pixel_triggers_interpolation(self):
"""If a pixel is all-zero, _interpolate_single_band should be called."""
import numpy as np
# 4x4 image, 3 bands. Top-left 2x2 is zeros, rest is 1s.
band1 = np.array([
[0, 0, 1, 1],
[0, 0, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
], dtype=np.float32)
band2 = band1 * 2
band3 = band1 * 3
ds = self._build_dataset([band1, band2, band3])
# Process the top-left 2x2 block (with halo, 4x4 covers all)
inner_bands, zero_count = self.interp._process_one_block(
dataset=ds,
x0=0, y0=0,
ey0=0, ex0=0, ey1=4, ex1=4,
row_offset=0, col_offset=0,
inner_h=2, inner_w=2,
mask_segment_ext=None,
method="nearest",
)
self.assertGreater(zero_count, 0, "should detect zero pixels")
self.assertEqual(len(inner_bands), 3)
# Each inner band should be 2x2
for ib in inner_bands:
self.assertEqual(ib.shape, (2, 2))
# Our fake nearest interpolation returns valid_values[0]=1 for band1
# So the inner block should be filled with non-zero values
self.assertTrue(np.all(inner_bands[0] > 0))
class TestWorkerFunctionMocked(unittest.TestCase):
"""Verify _interpolate_block_worker uses module-global _worker_dataset."""
@classmethod
def setUpClass(cls):
_install_fake_modules()
from src.core.algorithms.interpolation import interpolator
cls.interp = interpolator
def test_worker_uses_module_global_dataset(self):
import numpy as np
# Set the module global
band = np.array([
[0, 0, 1, 1],
[0, 0, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
], dtype=np.float32)
ds = self._build_dataset([band, band * 2])
self.interp._worker_dataset = ds
try:
task = (
0, 0, 0, 0, 4, 4, # x0, y0, ey0, ex0, ey1, ex1
0, 0, 2, 2, # row_offset, col_offset, inner_h, inner_w
None, "nearest", # mask_segment_ext, method
)
x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)
self.assertIsNone(error)
self.assertEqual((x0, y0), (0, 0))
self.assertGreater(zero_count, 0)
self.assertIsNotNone(inner_bands)
self.assertEqual(len(inner_bands), 2)
finally:
self.interp._worker_dataset = None
def test_worker_returns_error_if_dataset_uninitialized(self):
self.interp._worker_dataset = None
task = (0, 0, 0, 0, 2, 2, 0, 0, 2, 2, None, "nearest")
x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)
self.assertIsNotNone(error)
self.assertIn("not initialized", error)
self.assertIsNone(inner_bands)
def _build_dataset(self, bands):
ds = MagicMock()
ds.RasterCount = len(bands)
def get_band(b):
band = MagicMock()
band.ReadAsArray.return_value = bands[b - 1]
return band
ds.GetRasterBand.side_effect = get_band
return ds
class TestInitWorkerMocked(unittest.TestCase):
"""Verify _init_worker sets module global and config option."""
@classmethod
def setUpClass(cls):
_install_fake_modules()
from src.core.algorithms.interpolation import interpolator
cls.interp = interpolator
cls.gdal_mod = sys.modules["osgeo.gdal"]
def setUp(self):
self.interp._worker_dataset = None
def tearDown(self):
self.interp._worker_dataset = None
def test_init_worker_opens_and_caches(self):
fake_ds = MagicMock()
with patch.object(self.gdal_mod, "Open", return_value=fake_ds) as open_mock, \
patch.object(self.gdal_mod, "SetConfigOption") as cfg_mock:
self.interp._init_worker("/fake/path.bsq")
self.assertIs(self.interp._worker_dataset, fake_ds)
open_mock.assert_called_once_with("/fake/path.bsq", 0)
# Should have set GDAL_NUM_THREADS=1
cfg_mock.assert_any_call("GDAL_NUM_THREADS", "1")
def test_init_worker_raises_if_open_fails(self):
with patch.object(self.gdal_mod, "Open", return_value=None):
with self.assertRaises(RuntimeError):
self.interp._init_worker("/bad/path.bsq")
if __name__ == "__main__":
unittest.main(verbosity=2)