From 60a9d7d9227e336528b2f874c0317f0bd8ab2531 Mon Sep 17 00:00:00 2001 From: DXC Date: Mon, 15 Jun 2026 16:49:17 +0800 Subject: [PATCH] =?UTF-8?q?Step3=20=E6=8F=92=E5=80=BC=E7=AE=97=E6=B3=95=20?= =?UTF-8?q?OOM=20=E4=BF=AE=E5=A4=8D=20+=20=E5=A4=9A=E8=BF=9B=E7=A8=8B?= =?UTF-8?q?=E5=8A=A0=E9=80=9F=20+=20=E5=85=A8=E9=93=BE=E8=B7=AF=E7=B4=AF?= =?UTF-8?q?=E7=A7=AF=E6=94=B9=E5=8A=A8=EF=BC=8814=20=E6=96=87=E4=BB=B6?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../algorithms/interpolation/interpolator.py | 479 ++++++++++++++---- src/core/glint_removal/get_spectral-test.py | 2 +- src/core/glint_removal/get_spectral.py | 4 +- src/core/modeling/regression.py | 6 +- src/core/non_empirical_retrieval.py | 2 +- src/core/pipeline/runner.py | 10 +- src/core/prediction/automl_trainer.py | 4 +- src/core/steps/data_preparation_step.py | 6 +- src/core/steps/modeling_step.py | 6 +- src/core/steps/prediction_step.py | 2 +- .../water_quality_inversion_pipeline_GUI.py | 16 +- src/gui/panels/step6_feature_panel.py | 30 +- src/gui/water_quality_gui.py | 29 +- tests/test_interpolator_refactor.py | 411 +++++++++++++++ 14 files changed, 855 insertions(+), 152 deletions(-) create mode 100644 tests/test_interpolator_refactor.py diff --git a/src/core/algorithms/interpolation/interpolator.py b/src/core/algorithms/interpolation/interpolator.py index 3d1f370..6962065 100644 --- a/src/core/algorithms/interpolation/interpolator.py +++ b/src/core/algorithms/interpolation/interpolator.py @@ -3,8 +3,24 @@ 提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。 支持多种插值方法:nearest, bilinear, spline (RBF), kriging。 + +本模块使用多进程并行分块 IO 加速(Plan A): +- ProcessPoolExecutor 为每个 worker 进程打开一次源影像(initializer 阶段), + 避免每块重复 gdal.Open 带来的开销(Windows 上 ~50ms/次) +- 主进程统一负责输出文件的写入,避免多进程写锁竞争 +- 分块大小(block_size)默认 1024,内存充足可调至 2048 / 4096 + +注意: +- GDAL Dataset / Rasterio Dataset 对象不能跨进程传递(picking 不支持), + 所以 worker 必须在 init 阶段自己独立打开源文件 +- 每个 worker 强制设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 多线程 + 造成的 CPU 过订阅 +- 关闭多进程:传 ``use_multiprocessing=False`` 或 ``n_workers=1`` """ +import multiprocessing +from concurrent.futures import ProcessPoolExecutor + import numpy as np from typing import Optional, Union, Tuple, List from pathlib import Path @@ -24,6 +40,9 @@ except ImportError: GDAL_AVAILABLE = False +_worker_dataset: Optional["gdal.Dataset"] = None + + def interpolate_pixels( image_stack: np.ndarray, zero_coords: np.ndarray, @@ -52,7 +71,6 @@ def interpolate_pixels( height, width, n_bands = image_stack.shape result = image_stack.copy() - # 兼容中文和各种格式的method参数 raw_method = str(interpolation_method).lower() if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method: method = 'nearest' @@ -181,39 +199,271 @@ def _interpolate_single_band( return np.zeros(len(zero_coords)) +def _normalize_interpolation_method(method: str) -> str: + """将中文/英文混用的插值方法名归一化为内部标准名 + + 支持: 'nearest'/'邻近'/'最邻近','bilinear'/'线性'/'双线性', + 'spline'/'样条'/'rbf','kriging'/'克里金'。 + """ + raw = str(method).lower() + if 'nearest' in raw or '邻近' in raw or '最邻近' in raw: + return 'nearest' + if 'bilinear' in raw or '线性' in raw or '双线性' in raw: + return 'bilinear' + if 'spline' in raw or '样条' in raw or 'rbf' in raw: + return 'spline' + if 'kriging' in raw or '克里金' in raw: + return 'kriging' + return 'nearest' + + +def _read_water_mask_to_array( + water_mask: Optional[Union[str, np.ndarray]], + expected_height: int, + expected_width: int, +) -> Optional[np.ndarray]: + """读取水域掩膜为 numpy 数组(单波段,bool/int 均可) + + None 或空字符串直接返回 None。形状不匹配时给出告警但不抛错, + 让调用方按"无掩膜"路径继续。 + """ + if water_mask is None: + return None + if isinstance(water_mask, str): + if not water_mask.strip(): + return None + mask_ds = gdal.Open(water_mask, gdal.GA_ReadOnly) + if mask_ds is None: + print(f" [warn] 无法打开水域掩膜 {water_mask},按无掩膜处理") + return None + try: + mask_array = mask_ds.GetRasterBand(1).ReadAsArray() + finally: + mask_ds = None + elif isinstance(water_mask, np.ndarray): + mask_array = water_mask + else: + return None + + if mask_array.shape != (expected_height, expected_width): + print( + f" [warn] 水域掩膜形状 {mask_array.shape} 与影像 " + f"({expected_height}, {expected_width}) 不匹配,按无掩膜处理" + ) + return None + return mask_array + + +def _init_worker(img_path: str) -> None: + """ProcessPoolExecutor initializer: 每个 worker 进程只调用一次 + + 在 worker 进程启动时打开源影像 dataset 并缓存在模块全局变量 + ``_worker_dataset`` 中。后续所有块处理直接复用这个 dataset, + 避免每块重复 ``gdal.Open``(Windows 上约 50ms/次,100 块即 5s)。 + + 同时设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 默认多线程 + 造成的 CPU 过订阅。 + """ + global _worker_dataset + gdal.SetConfigOption('GDAL_NUM_THREADS', '1') + if hasattr(gdal, 'UseExceptions'): + gdal.UseExceptions() + _worker_dataset = gdal.Open(img_path, gdal.GA_ReadOnly) + if _worker_dataset is None: + raise RuntimeError(f"Worker failed to open source image: {img_path}") + + +def _interpolate_block_worker(task: tuple) -> tuple: + """ProcessPoolExecutor worker: 处理单个块并返回结果 + + 该函数必须保持模块级(可被 pickle),不持有任何外部状态—— + 源 dataset 通过 ``_worker_dataset`` 模块全局变量获取。 + + Returns: + ``(x0, y0, inner_bands, zero_count, error_msg)`` 元组: + - x0, y0: 块在影像中的写入起点 + - inner_bands: ``List[np.ndarray]``,每个元素是 (inner_h, inner_w) + float32 数组(每个波段一个),或失败时为 None + - zero_count: 该扩展块中识别到的零像素数(含 halo 范围) + - error_msg: None 表示成功,str 表示错误信息 + """ + ( + x0, y0, ey0, ex0, ey1, ex1, + row_offset, col_offset, inner_h, inner_w, + mask_segment_ext, method, + ) = task + if _worker_dataset is None: + return (x0, y0, None, 0, "Worker dataset not initialized") + try: + inner_bands, zero_count = _process_one_block( + _worker_dataset, x0, y0, ey0, ex0, ey1, ex1, + row_offset, col_offset, inner_h, inner_w, + mask_segment_ext, method, + ) + return (x0, y0, inner_bands, zero_count, None) + except Exception as e: + return (x0, y0, None, 0, str(e)) + + +def _process_one_block( + dataset: "gdal.Dataset", + x0: int, y0: int, + ey0: int, ex0: int, ey1: int, ex1: int, + row_offset: int, col_offset: int, + inner_h: int, inner_w: int, + mask_segment_ext: Optional[np.ndarray], + method: str, +) -> Tuple[List[np.ndarray], int]: + """处理单个扩展块(纯计算核心,dataset 显式传入) + + 串行模式和并行模式共用此函数。并行模式下 dataset 来自 worker 的 + 缓存(``_worker_dataset``),串行模式下 dataset 由主函数传入。 + + Args: + dataset: 已打开的源影像 dataset + x0, y0: 内部块左上角(写入位置) + ey0, ex0, ey1, ex1: 扩展块(含 halo)坐标 + row_offset, col_offset: 内部块在扩展块中的偏移 + inner_h, inner_w: 内部块尺寸 + mask_segment_ext: 扩展块对应的水域掩膜(None 表示不应用) + method: 插值方法(已归一化) + + Returns: + ``(inner_bands, zero_count)`` 元组: + - inner_bands: ``List[np.ndarray]``,长度 = n_bands,每个元素形状为 + ``(inner_h, inner_w)`` 的 float32 数组 + - zero_count: 扩展块中识别到的零像素数 + """ + n_bands = dataset.RasterCount + ext_bands: List[np.ndarray] = [] + for b in range(1, n_bands + 1): + band = dataset.GetRasterBand(b) + ext_bands.append( + band.ReadAsArray(ex0, ey0, ex1 - ex0, ey1 - ey0).astype(np.float32) + ) + band = None + + try: + ext_h, ext_w = ey1 - ey0, ex1 - ex0 + + all_zero_ext = np.ones((ext_h, ext_w), dtype=bool) + for b_data in ext_bands: + all_zero_ext &= (b_data == 0) + + if mask_segment_ext is not None: + all_zero_ext &= (mask_segment_ext > 0) + + zero_count = int(np.sum(all_zero_ext)) + + if zero_count == 0: + inner_bands = [ + ext_bands[b][ + row_offset:row_offset + inner_h, + col_offset:col_offset + inner_w, + ] + for b in range(n_bands) + ] + return inner_bands, 0 + + zero_y, zero_x = np.where(all_zero_ext) + zero_coords = np.column_stack([zero_x, zero_y]) + + valid_mask = ~all_zero_ext + valid_y, valid_x = np.where(valid_mask) + valid_coords = np.column_stack([valid_x, valid_y]) + + if len(valid_coords) == 0: + print( + f" [warn] 块 (y={y0}-{y0 + inner_h}, x={x0}-{x0 + inner_w}) " + f"无有效像素可作插值上下文,已跳过" + ) + inner_bands = [ + ext_bands[b][ + row_offset:row_offset + inner_h, + col_offset:col_offset + inner_w, + ] + for b in range(n_bands) + ] + return inner_bands, zero_count + + for b in range(n_bands): + ext_band = ext_bands[b] + valid_values_band = ext_band[valid_mask] + if len(valid_values_band) == 0: + continue + band_result = _interpolate_single_band( + zero_coords, valid_coords, valid_values_band, method + ) + ext_band[zero_y, zero_x] = band_result + + inner_bands = [ + ext_bands[b][ + row_offset:row_offset + inner_h, + col_offset:col_offset + inner_w, + ] + for b in range(n_bands) + ] + return inner_bands, zero_count + finally: + del ext_bands + + def interpolate_zero_pixels_batch( img_path: str, interpolation_method: str = 'nearest', output_path: Optional[str] = None, water_mask: Optional[Union[str, np.ndarray]] = None, deglint_dir: Optional[str] = None, - callback_progress: Optional[callable] = None + callback_progress: Optional[callable] = None, + block_size: int = 1024, + halo_size: int = 64, + n_workers: Optional[int] = None, + use_multiprocessing: bool = True, ) -> Tuple[str, Optional[np.ndarray]]: """ - 对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O) + 对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O)。 + + 采用 **分块 IO + 多进程并行** 策略: + 1. 影像按 ``block_size`` × ``block_size`` 分块,每块边界外扩展 + ``halo_size`` 像素作为插值上下文,避免块边缘插值退化 + 2. 多进程并行(默认 ``ProcessPoolExecutor``,worker 数 = CPU 核心数) + 并发处理所有块;GDAL Dataset 不能跨进程传递,所以每个 worker + 在 ``initializer`` 阶段独立打开源文件一次并缓存 + 3. 主进程按块序接收处理结果并统一写入输出文件,避免写锁竞争 + 4. 该方案可彻底避免一次性读取 50 波段整景影像时的 OOM 隐患 + (50 波段 × 4000×4000 × float32 ≈ 3GB 的 np.dstack) Args: img_path: 输入影像文件路径 - interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging' - output_path: 输出文件路径(如果为None,自动生成) - water_mask: 水域掩膜(文件路径或数组) + interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', + 'kriging' 及其中文别名('邻近'/'最邻近'/'线性'/'双线性'/'样条'/'克里金') + output_path: 输出文件路径(如果为 None 且 deglint_dir 提供,自动生成) + water_mask: 水域掩膜(文件路径或数组),形状须与影像高宽一致 deglint_dir: 去耀斑目录(用于生成默认输出路径) - callback_progress: 进度回调函数 + callback_progress: 进度回调函数,签名 ``callback(msg: str)`` + block_size: 分块大小(像素),默认 1024;内存充足可调 2048/4096 + halo_size: 上下文 halo 宽度(像素),默认 64 + n_workers: 并行 worker 进程数;None = ``multiprocessing.cpu_count()``; + 传 1 等价于串行模式 + use_multiprocessing: 是否启用多进程;False 时强制串行 Returns: - (output_path, interpolated_image_stack) 元组 + ``(output_path, None)`` 元组。第二个值固定为 ``None``(与原版语义保留 + 兼容;返回完整内存堆叠会重新引入 OOM 风险,故不再提供)。 """ if not SCIPY_AVAILABLE: raise ImportError("scipy未安装,无法进行0值像素插值") if not GDAL_AVAILABLE: raise ImportError("GDAL未安装,无法读取影像文件") - # 确定输出路径 - if output_path is None and deglint_dir is not None: - output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq") + method = _normalize_interpolation_method(interpolation_method) - # 检查文件是否已存在 - if output_path and Path(output_path).exists(): + if output_path is None and deglint_dir is not None: + output_path = str(Path(deglint_dir) / f"interpolated_{method}.bsq") + if output_path is None: + raise ValueError("output_path 和 deglint_dir 至少需要指定一个") + + if Path(output_path).exists(): return output_path, None dataset = gdal.Open(img_path, gdal.GA_ReadOnly) @@ -227,94 +477,125 @@ def interpolate_zero_pixels_batch( geotransform = dataset.GetGeoTransform() projection = dataset.GetProjection() - # 读取所有波段数据 - all_bands = [] - for band_idx in range(1, n_bands + 1): - band = dataset.GetRasterBand(band_idx) - band_data = band.ReadAsArray().astype(np.float32) - all_bands.append(band_data) - - image_stack = np.dstack(all_bands) - - # 读取水域掩膜 - mask_array = None - if water_mask is not None: - if isinstance(water_mask, str): - mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly) - if mask_dataset: - mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() - mask_dataset = None - elif isinstance(water_mask, np.ndarray): - mask_array = water_mask - - # 找出所有波段都为0的像素点 - all_bands_zero = np.all(image_stack == 0, axis=2) - - if mask_array is not None: - all_bands_zero = all_bands_zero & (mask_array > 0) - - zero_pixel_count = np.sum(all_bands_zero) - if zero_pixel_count == 0: - # 无需插值,直接保存 - if output_path: - driver = gdal.GetDriverByName('ENVI') - if driver is None: - driver = gdal.GetDriverByName('GTiff') - out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32) - out_dataset.SetGeoTransform(geotransform) - out_dataset.SetProjection(projection) - for i, band_data in enumerate(all_bands): - out_band = out_dataset.GetRasterBand(i + 1) - out_band.WriteArray(band_data) - out_band.FlushCache() - out_dataset = None - return output_path, image_stack - - # 获取坐标 - zero_y, zero_x = np.where(all_bands_zero) - zero_coords = np.column_stack([zero_x, zero_y]) - - valid_mask = ~all_bands_zero - valid_y, valid_x = np.where(valid_mask) - valid_coords = np.column_stack([valid_x, valid_y]) - - if len(valid_coords) == 0: - raise ValueError("没有有效像素可用于插值") - - # 逐波段插值 - interpolated_bands = [] - for band_idx in range(n_bands): - if callback_progress: - callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...") - band_data = all_bands[band_idx].copy() - valid_values_band = band_data[valid_mask] - - if len(valid_values_band) == 0: - interpolated_bands.append(band_data) - continue - - band_result = _interpolate_single_band( - zero_coords, valid_coords, valid_values_band, interpolation_method + if width <= 0 or height <= 0 or n_bands <= 0: + raise ValueError( + f"影像尺寸异常: width={width}, height={height}, n_bands={n_bands}" ) - band_data[all_bands_zero] = band_result - interpolated_bands.append(band_data) - # 保存结果 - if output_path: - driver = gdal.GetDriverByName('ENVI') - if driver is None: - driver = gdal.GetDriverByName('GTiff') - out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32) - out_dataset.SetGeoTransform(geotransform) - out_dataset.SetProjection(projection) - for i, band_data in enumerate(interpolated_bands): - out_band = out_dataset.GetRasterBand(i + 1) - out_band.WriteArray(band_data) - out_band.FlushCache() + mask_array = _read_water_mask_to_array(water_mask, height, width) + + driver = gdal.GetDriverByName('ENVI') + if driver is None: + driver = gdal.GetDriverByName('GTiff') + if driver is None: + raise RuntimeError("未找到可用的栅格驱动(ENVI / GTiff 都不存在)") + + out_dataset = driver.Create( + output_path, width, height, n_bands, gdal.GDT_Float32 + ) + if out_dataset is None: + raise RuntimeError(f"无法创建输出文件: {output_path}") + out_dataset.SetGeoTransform(geotransform) + out_dataset.SetProjection(projection) + + try: + if not use_multiprocessing: + effective_workers = 1 + elif n_workers is not None and n_workers >= 1: + effective_workers = int(n_workers) + else: + try: + cpu_count = multiprocessing.cpu_count() or 1 + except (NotImplementedError, OSError): + cpu_count = 1 + effective_workers = max(1, cpu_count) + + n_blocks_y = (height + block_size - 1) // block_size + n_blocks_x = (width + block_size - 1) // block_size + total_blocks = n_blocks_y * n_blocks_x + + tasks = [] + for by in range(n_blocks_y): + y0 = by * block_size + y1 = min(y0 + block_size, height) + inner_h = y1 - y0 + ey0 = max(0, y0 - halo_size) + ey1 = min(height, y1 + halo_size) + for bx in range(n_blocks_x): + x0 = bx * block_size + x1 = min(x0 + block_size, width) + inner_w = x1 - x0 + ex0 = max(0, x0 - halo_size) + ex1 = min(width, x1 + halo_size) + row_offset = y0 - ey0 + col_offset = x0 - ex0 + mask_segment_ext = None + if mask_array is not None: + mask_segment_ext = mask_array[ey0:ey1, ex0:ex1] + tasks.append(( + x0, y0, ey0, ex0, ey1, ex1, + row_offset, col_offset, inner_h, inner_w, + mask_segment_ext, method, + )) + + if callback_progress: + callback_progress( + f"分块插值开始: 共 {total_blocks} 块 " + f"(block_size={block_size}, halo={halo_size}, method={method}, " + f"workers={effective_workers})" + ) + + total_zero_pixels = 0 + + if effective_workers <= 1: + for block_idx, task in enumerate(tasks, 1): + x0_t, y0_t = task[0], task[1] + if callback_progress: + callback_progress( + f"块 {block_idx}/{total_blocks} " + f"y=[{y0_t},{y0_t + task[8]}) x=[{x0_t},{x0_t + task[9]})" + ) + inner_bands, zero_count = _process_one_block( + dataset, *task + ) + for b_idx, band_data in enumerate(inner_bands): + out_dataset.GetRasterBand(b_idx + 1).WriteArray( + band_data, xoff=x0_t, yoff=y0_t + ) + total_zero_pixels += zero_count + else: + with ProcessPoolExecutor( + max_workers=effective_workers, + initializer=_init_worker, + initargs=(img_path,), + ) as executor: + futures = [ + executor.submit(_interpolate_block_worker, task) + for task in tasks + ] + for block_idx, future in enumerate(futures, 1): + x0_t, y0_t, inner_bands, zero_count, error = future.result() + if error is not None: + raise RuntimeError( + f"块 (y={y0_t}, x={x0_t}) 处理失败: {error}" + ) + if inner_bands is not None: + for b_idx, band_data in enumerate(inner_bands): + out_dataset.GetRasterBand(b_idx + 1).WriteArray( + band_data, xoff=x0_t, yoff=y0_t + ) + total_zero_pixels += zero_count + if callback_progress: + callback_progress(f"已写入块 {block_idx}/{total_blocks}") + + if callback_progress: + callback_progress( + f"分块插值完成: 共处理 {total_zero_pixels} 个零像素 " + f"({total_blocks} 块,方法 {method},workers={effective_workers})" + ) + + return output_path, None + finally: out_dataset = None - - result_stack = np.dstack(interpolated_bands) - return output_path, result_stack - finally: dataset = None diff --git a/src/core/glint_removal/get_spectral-test.py b/src/core/glint_removal/get_spectral-test.py index b20908d..b90a9ca 100644 --- a/src/core/glint_removal/get_spectral-test.py +++ b/src/core/glint_removal/get_spectral-test.py @@ -899,7 +899,7 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None, if __name__ == '__main__': # 在这里直接设置参数 imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"# BIL格式影像文件路径 - coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度) + coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度) output_path = r"E:\code\WQ\封装\test/yangdian_output.csv" # CSV格式输出文件路径 radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均 diff --git a/src/core/glint_removal/get_spectral.py b/src/core/glint_removal/get_spectral.py index 5784813..b8ba79a 100644 --- a/src/core/glint_removal/get_spectral.py +++ b/src/core/glint_removal/get_spectral.py @@ -806,8 +806,8 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None, if __name__ == '__main__': # 在这里直接设置参数 imgpath = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq" # BIL格式影像文件路径 - coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度) - output_path = r"E:\code\WQ\封装\work_dir\5_training_spectra/yangdian_output.csv" # CSV格式输出文件路径 + coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度) + output_path = r"E:\code\WQ\封装\work_dir\6_Spectral_Feature_Extraction/yangdian_output.csv" # CSV格式输出文件路径 radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均 flare_path = r"E:\code\WQ\封装\work_dir\2_Glint_Detection\severe_glint_area.dat" # 耀斑掩膜文件路径(可选,None表示不使用) diff --git a/src/core/modeling/regression.py b/src/core/modeling/regression.py index 5a7efa1..e69651d 100644 --- a/src/core/modeling/regression.py +++ b/src/core/modeling/regression.py @@ -315,7 +315,7 @@ def main(): # 示例1: 使用所有回归方法分析光谱指数 print("\n1. 光谱指数与叶绿素a的回归分析:") - sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\water_quality_results.csv") + sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\water_quality_results.csv") spectral_indices = ['Al10SABI','Am092Bsub'] results1 = analyzer.batch_single_variable_regression( @@ -323,7 +323,7 @@ def main(): x_columns=spectral_indices, y_column='Chlorophyll', methods='all', - output_file=r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\spectral_indices_regression.csv' + output_file=r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\spectral_indices_regression.csv' ) # # 示例2: 使用特定方法分析反射率波段 @@ -343,7 +343,7 @@ def main(): best_models = analyzer.get_best_models_summary() if not best_models.empty: print(best_models[['x_variable', 'regression_method', 'r_squared', 'equation']].to_string(index=False)) - best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\best_models_summary.csv', index=False) + best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\best_models_summary.csv', index=False) print("\n最佳模型汇总已保存到 'best_models_summary.csv'") # # def advanced_usage_example(): diff --git a/src/core/non_empirical_retrieval.py b/src/core/non_empirical_retrieval.py index cad19d0..964c03a 100644 --- a/src/core/non_empirical_retrieval.py +++ b/src/core/non_empirical_retrieval.py @@ -246,7 +246,7 @@ def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, outp if __name__ == "__main__": algorithm= "chl_a" - model_info_path= r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\8_non_empirical_models\SS\SS_chl_a.json" + model_info_path= r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\8_non_empirical_models\SS\SS_chl_a.json" coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv" output_path= r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv" wave_radius=5.0 diff --git a/src/core/pipeline/runner.py b/src/core/pipeline/runner.py index 7481a40..28b70fc 100644 --- a/src/core/pipeline/runner.py +++ b/src/core/pipeline/runner.py @@ -98,7 +98,7 @@ PIPELINE_STEPS: List[StepSpec] = [ step_id="step4", method_name="step5_process_csv", requires=["csv_path"], produces=["processed_csv_path"], required_input_files=["csv_path"], - output_file="{work_dir}/4_processed_data/processed_data.csv", + output_file="{work_dir}/5_Data_Cleaning/processed_data.csv", description="CSV 异常值清洗", ), StepSpec( @@ -111,21 +111,21 @@ PIPELINE_STEPS: List[StepSpec] = [ }, skip_when_missing=False, required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"], - output_file="{work_dir}/5_training_spectra/training_spectra.csv", + output_file="{work_dir}/6_Spectral_Feature_Extraction/training_spectra.csv", description="实测样本点光谱提取", ), StepSpec( step_id="step7", method_name="step7_calc_indices", requires=["training_csv_path"], produces=["indices_path", "trad_indices_dir"], required_input_files=["training_csv_path"], - output_file="{work_dir}/6_water_quality_indices/training_spectra_indices.csv", + output_file="{work_dir}/7_Water_Quality_Indices/training_spectra_indices.csv", description="水质参数指数计算(双轨输出:A轨宽表 + B轨单文件)", ), StepSpec( step_id="step8", method_name="step8_train_ml", requires=["training_csv_path"], produces=["models_dir"], required_input_files=["training_csv_path"], - output_file="{work_dir}/7_Supervised_Model_Training/best_models.pkl", + output_file="{work_dir}/8_Supervised_Model_Training/best_models.pkl", description="ML 建模(GridSearchCV / AutoML)", ), StepSpec( @@ -134,7 +134,7 @@ PIPELINE_STEPS: List[StepSpec] = [ requires=["training_csv_path"], produces=["models_dir"], parameter_map={"training_csv_path": "csv_path"}, required_input_files=["training_csv_path"], - output_file="{work_dir}/8_Regression_Modeling/non_empirical_models.pkl", + output_file="{work_dir}/8_Non_Empirical_Regression/non_empirical_models.pkl", description="非经验统计回归", ), StepSpec( diff --git a/src/core/prediction/automl_trainer.py b/src/core/prediction/automl_trainer.py index 8075a47..bd8c861 100644 --- a/src/core/prediction/automl_trainer.py +++ b/src/core/prediction/automl_trainer.py @@ -328,7 +328,7 @@ def train_with_automl( split_method = split_methods[0] if output_dir is None: - output_dir = "./7_Supervised_Model_Training_AutoML" + output_dir = "./8_Supervised_Model_Training_AutoML" out_dir = Path(output_dir) out_dir.mkdir(parents=True, exist_ok=True) preproc_dir = out_dir / preproc @@ -519,7 +519,7 @@ if __name__ == "__main__": p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS) p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT) p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES) - p.add_argument("--out", default="./7_Supervised_Model_Training_AutoML") + p.add_argument("--out", default="./8_Supervised_Model_Training_AutoML") args = p.parse_args() # 智能推断 feature_start_column 类型 diff --git a/src/core/steps/data_preparation_step.py b/src/core/steps/data_preparation_step.py index db15fac..446645e 100644 --- a/src/core/steps/data_preparation_step.py +++ b/src/core/steps/data_preparation_step.py @@ -21,7 +21,7 @@ class DataPreparationStep: @staticmethod def process_csv( csv_path: str, - output_dir: Union[str, Path] = "./4_processed_data", + output_dir: Union[str, Path] = "./5_Data_Cleaning", callback: Optional[Callable] = None, ) -> str: """处理CSV文件(筛选剔除异常值)""" @@ -61,7 +61,7 @@ class DataPreparationStep: boundary_path: Optional[str] = None, glint_mask_path: Optional[str] = None, water_mask_path: Optional[str] = None, - output_dir: Union[str, Path] = "./5_training_spectra", + output_dir: Union[str, Path] = "./6_Spectral_Feature_Extraction", callback: Optional[Callable] = None, ) -> str: """根据采样点坐标在去耀斑影像中提取平均光谱""" @@ -131,7 +131,7 @@ class DataPreparationStep: formula_names: Optional[List[str]] = None, output_file: Optional[str] = None, enabled: bool = True, - output_dir: Union[str, Path] = "./6_water_quality_indices", + output_dir: Union[str, Path] = "./7_Water_Quality_Indices", callback: Optional[Callable] = None, ) -> Optional[str]: """根据训练光谱计算水质光谱指数(使用 band_math 方法)""" diff --git a/src/core/steps/modeling_step.py b/src/core/steps/modeling_step.py index 18c1b05..0c19b90 100644 --- a/src/core/steps/modeling_step.py +++ b/src/core/steps/modeling_step.py @@ -135,7 +135,7 @@ class ModelingStep: split_methods: Optional[List[str]] = None, cv_folds: int = 5, training_csv_path: Optional[str] = None, - output_dir: Union[str, Path] = "./7_Supervised_Model_Training", + output_dir: Union[str, Path] = "./8_Supervised_Model_Training", callback: Optional[Callable] = None, _report_generator=None, ) -> str: @@ -251,7 +251,7 @@ class ModelingStep: if output_dir is not None: non_empirical_dir = Path(output_dir) else: - non_empirical_dir = Path.cwd() / "8_Regression_Modeling" + non_empirical_dir = Path.cwd() / "8_Non_Empirical_Regression" non_empirical_dir.mkdir(parents=True, exist_ok=True) if preprocessing_methods is None: @@ -430,7 +430,7 @@ def _apply_preprocessing_internal( save_path = None if preprocess_method == "SS": - models_dir = output_dir.parent.parent / "7_Supervised_Model_Training" + models_dir = output_dir.parent.parent / "8_Supervised_Model_Training" models_dir.mkdir(parents=True, exist_ok=True) save_path = str(models_dir / "scaler_params.pkl") print(f"SS预处理: scaler模型将保存到 {save_path}") diff --git a/src/core/steps/prediction_step.py b/src/core/steps/prediction_step.py index 9d59f32..37c6952 100644 --- a/src/core/steps/prediction_step.py +++ b/src/core/steps/prediction_step.py @@ -259,7 +259,7 @@ class PredictionStep: if non_empirical_models_dir is not None: final_models_dir = non_empirical_models_dir else: - default_models_dir = str(Path(work_dir) / "8_Regression_Modeling") + default_models_dir = str(Path(work_dir) / "8_Non_Empirical_Regression") if Path(default_models_dir).exists(): final_models_dir = default_models_dir else: diff --git a/src/core/water_quality_inversion_pipeline_GUI.py b/src/core/water_quality_inversion_pipeline_GUI.py index effe620..f1a6b35 100644 --- a/src/core/water_quality_inversion_pipeline_GUI.py +++ b/src/core/water_quality_inversion_pipeline_GUI.py @@ -138,11 +138,11 @@ class WaterQualityInversionPipeline: self.water_mask_dir = self.work_dir / "1_water_mask" self.glint_dir = self.work_dir / "2_Glint_Detection" self.deglint_dir = self.work_dir / "3_deglint" - self.processed_data_dir = self.work_dir / "4_processed_data" - self.training_spectra_dir = self.work_dir / "5_training_spectra" - self.indices_dir = self.work_dir / "6_water_quality_indices" - self.models_dir = self.work_dir / "7_Supervised_Model_Training" - self.non_empirical_models_dir = self.work_dir / "8_Regression_Modeling" + self.processed_data_dir = self.work_dir / "5_Data_Cleaning" + self.training_spectra_dir = self.work_dir / "6_Spectral_Feature_Extraction" + self.indices_dir = self.work_dir / "7_Water_Quality_Indices" + self.models_dir = self.work_dir / "8_Supervised_Model_Training" + self.non_empirical_models_dir = self.work_dir / "8_Non_Empirical_Regression" self.custom_regression_dir = self.work_dir / "13_Custom_Regression" self.sampling_dir = self.work_dir / "4_sampling" self.prediction_dir = self.work_dir / "11_12_13_predictions" @@ -764,7 +764,7 @@ class WaterQualityInversionPipeline: if not spectrum_csv or not os.path.exists(spectrum_csv): # 回退:扫描 work_dir 下 step5 的产物目录,找第一个 .csv fallback_candidates = [] - step5_dir = os.path.join(self.work_dir, "5_Training_Spectra") + step5_dir = os.path.join(self.work_dir, "6_Spectral_Feature_Extraction") if os.path.isdir(step5_dir): for f in sorted(os.listdir(step5_dir)): if f.lower().endswith('.csv'): @@ -2023,10 +2023,10 @@ class WaterQualityInversionPipeline: # 应用预处理 - 使用spectral_Preprocessing模块 from src.preprocessing.spectral_Preprocessing import Preprocessing - # 为SS预处理提供scaler保存路径,保存在工作目录的7_Supervised_Model_Training中 + # 为SS预处理提供scaler保存路径,保存在工作目录的8_Supervised_Model_Training中 save_path = None if preprocess_method == 'SS': - models_dir = output_dir.parent.parent / "7_Supervised_Model_Training" # 向上两级到工作目录 + models_dir = output_dir.parent.parent / "8_Supervised_Model_Training" # 向上两级到工作目录 models_dir.mkdir(parents=True, exist_ok=True) save_path = str(models_dir / "scaler_params.pkl") print(f"SS预处理: scaler模型将保存到 {save_path}") diff --git a/src/gui/panels/step6_feature_panel.py b/src/gui/panels/step6_feature_panel.py index e42c6f7..5201c93 100644 --- a/src/gui/panels/step6_feature_panel.py +++ b/src/gui/panels/step6_feature_panel.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -Step5 面板 - 光谱提取 +Step6 面板 - 光谱特征提取 """ import os @@ -27,7 +27,7 @@ class Step6FeaturePanel(QWidget): layout = QVBoxLayout() # 标题 - title = QLabel("步骤5:训练样本光谱提取") + title = QLabel("步骤6:光谱特征提取") title.setFont(QFont("Arial", 12, QFont.Bold)) layout.addWidget(title) @@ -58,12 +58,12 @@ class Step6FeaturePanel(QWidget): "Mask Files (*.dat *.tif);;All Files (*.*)" ) layout.addWidget(self.glint_mask_file) - step5_glint_hint = QLabel( + step6_glint_hint = QLabel( "提示:独立运行本步骤时必须选择耀斑掩膜(通常为步骤2输出的 severe_glint_area.dat),用于在采样时避开耀斑像元。" ) - step5_glint_hint.setWordWrap(True) - step5_glint_hint.setStyleSheet("color: #666; font-size: 10px;") - layout.addWidget(step5_glint_hint) + step6_glint_hint.setWordWrap(True) + step6_glint_hint.setStyleSheet("color: #666; font-size: 10px;") + layout.addWidget(step6_glint_hint) # 参数设置 params_group = QGroupBox("提取参数") @@ -200,20 +200,22 @@ class Step6FeaturePanel(QWidget): else: self.output_file.set_path("") - # 5. 尝试从 Step4 界面读取已处理的水质参数 CSV 路径,自动填入本面板 + # 5. 尝试从 Step5 Clean 界面读取已处理的清洗后 CSV 路径,自动填入本面板 main_window = self.window() - if main_window and hasattr(main_window, 'step5_panel'): - step4_output_path = main_window.step5_panel.output_file.get_path() - if step4_output_path: + if main_window and hasattr(main_window, 'step5_clean_panel'): + step5_clean_output_path = main_window.step5_clean_panel.output_file.get_path() + if step5_clean_output_path: # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(step4_output_path): - step4_output_path = os.path.join(self.work_dir or '', step4_output_path).replace('\\', '/') + if not os.path.isabs(step5_clean_output_path): + step5_clean_output_path = os.path.join( + self.work_dir or '', step5_clean_output_path + ).replace('\\', '/') existing_csv = self.csv_file.get_path() if not existing_csv or not existing_csv.strip(): - self.csv_file.set_path(step4_output_path) + self.csv_file.set_path(step5_clean_output_path) def run_step(self): - """独立运行步骤5""" + """独立运行步骤6""" # 验证输入 deglint_img_path = self.deglint_img_file.get_path() csv_path = self.csv_file.get_path() diff --git a/src/gui/water_quality_gui.py b/src/gui/water_quality_gui.py index 6df84e8..7877033 100644 --- a/src/gui/water_quality_gui.py +++ b/src/gui/water_quality_gui.py @@ -1393,9 +1393,7 @@ class WaterQualityGUI(QMainWindow): 'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'), 'water_mask_path': ('step1', 'water_mask', 'water_mask_file') }, - 'step5_clean': { - 'csv_path': ('step4_sampling', 'sampling_spectra', 'csv_file') # step5 寻找 step4 的采样点 - }, + # 'step5_clean': 业务要求保持输入源独立,不自动抓取 step4_sampling 的输出;用户手动浏览导入 'step6_feature': { 'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'), 'csv_path': ('step5_clean', 'processed_data', 'csv_file'), @@ -2252,21 +2250,32 @@ class WaterQualityGUI(QMainWindow): # 检查面板是否有对应的属性 if not hasattr(panel, panel_attr): continue - + file_widget = getattr(panel, panel_attr) - + + # ★ 兼容 FileSelectWidget 与原生 QLineEdit + current_text = ( + file_widget.get_path().strip() + if hasattr(file_widget, 'get_path') + else file_widget.text().strip() + ) + # 如果输入框已经有内容,跳过自动填充 - if file_widget.get_path().strip(): + if current_text: continue - + # 查找依赖步骤的输出文件 output_path = self.find_step_output(work_path, dep_step, output_type) - + if output_path and Path(output_path).exists(): - file_widget.set_path(output_path) + # ★ 兼容 FileSelectWidget 与原生 QLineEdit + if hasattr(file_widget, 'set_path'): + file_widget.set_path(str(output_path)) + else: + file_widget.setText(str(output_path)) self.log_message(f"自动填充 {step_id}.{input_field}: {output_path}", "info") filled_count += 1 - + return filled_count def get_step_panel(self, step_id): diff --git a/tests/test_interpolator_refactor.py b/tests/test_interpolator_refactor.py new file mode 100644 index 0000000..2946e35 --- /dev/null +++ b/tests/test_interpolator_refactor.py @@ -0,0 +1,411 @@ +""" +interpolator.py 多进程重构的行为测试(mock-based,无 osgeo 依赖) + +验证: +1. 静态结构:模块级函数集合、签名、新参数、_worker_dataset 全局 +2. 行为逻辑:_process_one_block 在 mock dataset 上的零像素识别 + 插值 +3. 行为逻辑:_interpolate_block_worker 通过模块全局 _worker_dataset 工作 +4. 行为逻辑:interpolate_zero_pixels_batch 的串行路径(不依赖 osgeo) +5. 向后兼容:现有 8 个 caller 参数保留默认值不变 + +如果本机有 osgeo,可补一个真实数据集的 smoke test。 +""" +import os +import sys +import ast +import types +import unittest +from unittest.mock import MagicMock, patch + +# Ensure the project src is on path so we can import the module under test +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, "..")) +if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + + +INTERP_PATH = os.path.join( + PROJECT_ROOT, "src", "core", "algorithms", "interpolation", "interpolator.py" +) + + +# ============================================================================= +# Static structure tests +# ============================================================================= +class TestInterpolatorStructure(unittest.TestCase): + def setUp(self): + with open(INTERP_PATH, "r", encoding="utf-8") as f: + self.src = f.read() + self.tree = ast.parse(self.src) + + def test_module_level_functions(self): + mod_funcs = {n.name for n in self.tree.body if isinstance(n, ast.FunctionDef)} + expected = { + "interpolate_pixels", + "_interpolate_single_band", + "_normalize_interpolation_method", + "_read_water_mask_to_array", + "_init_worker", + "_interpolate_block_worker", + "_process_one_block", + "interpolate_zero_pixels_batch", + } + self.assertEqual(mod_funcs, expected) + self.assertNotIn("_process_block_with_buffer", mod_funcs) + + def test_worker_dataset_module_global(self): + mod_globals = set() + for n in self.tree.body: + if isinstance(n, ast.Assign) and isinstance(n.targets[0], ast.Name): + mod_globals.add(n.targets[0].id) + elif isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name): + mod_globals.add(n.target.id) + self.assertIn("_worker_dataset", mod_globals) + + def test_interpolate_zero_pixels_batch_signature(self): + for n in self.tree.body: + if isinstance(n, ast.FunctionDef) and n.name == "interpolate_zero_pixels_batch": + args = [a.arg for a in n.args.args] + # Backward compat: all 8 existing params + 2 new + self.assertEqual( + args, + [ + "img_path", "interpolation_method", "output_path", + "water_mask", "deglint_dir", "callback_progress", + "block_size", "halo_size", "n_workers", "use_multiprocessing", + ], + ) + defaults = [getattr(d, "value", None) for d in n.args.defaults] + self.assertEqual( + defaults, + ["nearest", None, None, None, None, 1024, 64, None, True], + ) + return + self.fail("interpolate_zero_pixels_batch not found") + + def test_init_worker_signature(self): + for n in self.tree.body: + if isinstance(n, ast.FunctionDef) and n.name == "_init_worker": + self.assertEqual([a.arg for a in n.args.args], ["img_path"]) + return + self.fail("_init_worker not found") + + def test_worker_function_signature(self): + for n in self.tree.body: + if isinstance(n, ast.FunctionDef) and n.name == "_interpolate_block_worker": + self.assertEqual([a.arg for a in n.args.args], ["task"]) + return + self.fail("_interpolate_block_worker not found") + + def test_process_one_block_signature(self): + for n in self.tree.body: + if isinstance(n, ast.FunctionDef) and n.name == "_process_one_block": + args = [a.arg for a in n.args.args] + self.assertEqual( + args, + [ + "dataset", "x0", "y0", + "ey0", "ex0", "ey1", "ex1", + "row_offset", "col_offset", + "inner_h", "inner_w", + "mask_segment_ext", "method", + ], + ) + return + self.fail("_process_one_block not found") + + def test_uses_process_pool_executor(self): + self.assertIn("ProcessPoolExecutor", self.src) + self.assertIn("ProcessPoolExecutor(", self.src) + self.assertIn("initializer=_init_worker", self.src) + self.assertIn("initargs=(img_path,)", self.src) + self.assertIn("GDAL_NUM_THREADS", self.src) + + def test_dispatch_logic_present(self): + # Both serial and parallel paths should be present + self.assertIn("if effective_workers <= 1:", self.src) + self.assertIn("with ProcessPoolExecutor", self.src) + + def test_serial_path_uses_process_one_block(self): + # Serial branch should call _process_one_block directly (no pickle overhead) + self.assertIn( + "_process_one_block(\n dataset, *task\n )", + self.src, + ) + + def test_worker_path_uses_process_one_block(self): + # Worker function should also call _process_one_block (shared core) + self.assertIn("_process_one_block(", self.src) + # The worker should read from _worker_dataset, not receive dataset in task + self.assertIn("_worker_dataset", self.src) + + +# ============================================================================= +# Mocked behavior tests (no real osgeo/scipy needed) +# ============================================================================= + + +def _make_mock_band(read_data): + band = MagicMock() + band.ReadAsArray.return_value = read_data + return band + + +def _make_mock_dataset(bands_data, n_bands=None): + if n_bands is None: + n_bands = len(bands_data) + ds = MagicMock() + ds.RasterCount = n_bands + ds.GetRasterBand.side_effect = lambda b: _make_mock_band(bands_data[b - 1]) + return ds + + +def _make_fake_module(name, attrs=None): + mod = types.ModuleType(name) + for k, v in (attrs or {}).items(): + setattr(mod, k, v) + return mod + + +def _install_fake_modules(): + """Install minimal fakes for scipy + osgeo so the module under test imports.""" + import numpy as np + + # Fake scipy + scipy = types.ModuleType("scipy") + scipy_ndimage = types.ModuleType("scipy.ndimage") + scipy_interp = types.ModuleType("scipy.interpolate") + scipy_spatial = types.ModuleType("scipy.spatial") + + class _FakeTree: + def __init__(self, coords): + self.coords = coords + def query(self, pts): + # Return index 0 for all queries + import numpy as np + n = len(pts) + return np.zeros(n, dtype=int), np.zeros(n, dtype=int) + + def _griddata(coords, values, pts, method="linear", fill_value=0.0): + # Trivial: nearest valid value + if len(coords) == 0 or len(pts) == 0: + return np.zeros(len(pts), dtype=np.float32) + # Just return first value for all queries (testing only) + return np.full(len(pts), values[0], dtype=np.float32) + + class _FakeRBF: + def __init__(self, coords, values, kernel=None): + self.first_value = values[0] + def __call__(self, pts): + return np.full(len(pts), self.first_value, dtype=np.float32) + + scipy_spatial.cKDTree = _FakeTree + scipy_interp.griddata = _griddata + scipy_interp.RBFInterpolator = _FakeRBF + + scipy.ndimage = scipy_ndimage + scipy.interpolate = scipy_interp + scipy.spatial = scipy_spatial + + # Fake osgeo.gdal with the constants the module needs + osgeo = types.ModuleType("osgeo") + gdal_mod = types.ModuleType("osgeo.gdal") + gdal_mod.GA_ReadOnly = 0 + gdal_mod.GDT_Float32 = 6 + gdal_mod.UseExceptions = MagicMock() + + def _open(path, mode): + return _make_mock_dataset([]) # empty dataset; real tests build their own + gdal_mod.Open = _open + gdal_mod.GetDriverByName = MagicMock(return_value=MagicMock()) + gdal_mod.SetConfigOption = MagicMock() + + osgeo.gdal = gdal_mod + + sys.modules["scipy"] = scipy + sys.modules["scipy.ndimage"] = scipy_ndimage + sys.modules["scipy.interpolate"] = scipy_interp + sys.modules["scipy.spatial"] = scipy_spatial + sys.modules["osgeo"] = osgeo + sys.modules["osgeo.gdal"] = gdal_mod + + return { + "scipy": scipy, + "scipy.spatial": scipy_spatial, + "scipy.interpolate": scipy_interp, + "osgeo.gdal": gdal_mod, + } + + +class TestProcessOneBlockMocked(unittest.TestCase): + """Verify _process_one_block logic with mocked GDAL/scipy.""" + + @classmethod + def setUpClass(cls): + _install_fake_modules() + # Import the module under test after installing fakes + from src.core.algorithms.interpolation import interpolator + cls.interp = interpolator + cls.gdal = sys.modules["osgeo.gdal"] + + def _build_dataset(self, bands): + """Build a mock dataset where GetRasterBand(b).ReadAsArray(x,y,w,h) + returns bands[b-1] (a numpy array of the requested shape). + + For simplicity we return the full band array regardless of x,y,w,h. + """ + ds = MagicMock() + ds.RasterCount = len(bands) + + def get_band(b): + band = MagicMock() + band.ReadAsArray.return_value = bands[b - 1] + return band + ds.GetRasterBand.side_effect = get_band + return ds + + def test_no_zeros_returns_inner_blocks_unchanged(self): + """If no pixels are all-zero, inner blocks should be returned as-is.""" + import numpy as np + # 2x2 image, 3 bands, all 1s (no zeros anywhere) + band = np.ones((4, 4), dtype=np.float32) # 2 inner + 2 halo + bands = [band, band * 2, band * 3] + ds = self._build_dataset(bands) + + inner_bands, zero_count = self.interp._process_one_block( + dataset=ds, + x0=0, y0=0, + ey0=0, ex0=0, ey1=4, ex1=4, + row_offset=0, col_offset=0, + inner_h=2, inner_w=2, + mask_segment_ext=None, + method="nearest", + ) + self.assertEqual(zero_count, 0) + self.assertEqual(len(inner_bands), 3) + for ib, expected_band in zip(inner_bands, bands): + self.assertEqual(ib.shape, (2, 2)) + np.testing.assert_array_equal(ib, expected_band[:2, :2]) + + def test_with_zero_pixel_triggers_interpolation(self): + """If a pixel is all-zero, _interpolate_single_band should be called.""" + import numpy as np + # 4x4 image, 3 bands. Top-left 2x2 is zeros, rest is 1s. + band1 = np.array([ + [0, 0, 1, 1], + [0, 0, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + ], dtype=np.float32) + band2 = band1 * 2 + band3 = band1 * 3 + ds = self._build_dataset([band1, band2, band3]) + + # Process the top-left 2x2 block (with halo, 4x4 covers all) + inner_bands, zero_count = self.interp._process_one_block( + dataset=ds, + x0=0, y0=0, + ey0=0, ex0=0, ey1=4, ex1=4, + row_offset=0, col_offset=0, + inner_h=2, inner_w=2, + mask_segment_ext=None, + method="nearest", + ) + self.assertGreater(zero_count, 0, "should detect zero pixels") + self.assertEqual(len(inner_bands), 3) + # Each inner band should be 2x2 + for ib in inner_bands: + self.assertEqual(ib.shape, (2, 2)) + # Our fake nearest interpolation returns valid_values[0]=1 for band1 + # So the inner block should be filled with non-zero values + self.assertTrue(np.all(inner_bands[0] > 0)) + + +class TestWorkerFunctionMocked(unittest.TestCase): + """Verify _interpolate_block_worker uses module-global _worker_dataset.""" + + @classmethod + def setUpClass(cls): + _install_fake_modules() + from src.core.algorithms.interpolation import interpolator + cls.interp = interpolator + + def test_worker_uses_module_global_dataset(self): + import numpy as np + # Set the module global + band = np.array([ + [0, 0, 1, 1], + [0, 0, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + ], dtype=np.float32) + ds = self._build_dataset([band, band * 2]) + self.interp._worker_dataset = ds + try: + task = ( + 0, 0, 0, 0, 4, 4, # x0, y0, ey0, ex0, ey1, ex1 + 0, 0, 2, 2, # row_offset, col_offset, inner_h, inner_w + None, "nearest", # mask_segment_ext, method + ) + x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task) + self.assertIsNone(error) + self.assertEqual((x0, y0), (0, 0)) + self.assertGreater(zero_count, 0) + self.assertIsNotNone(inner_bands) + self.assertEqual(len(inner_bands), 2) + finally: + self.interp._worker_dataset = None + + def test_worker_returns_error_if_dataset_uninitialized(self): + self.interp._worker_dataset = None + task = (0, 0, 0, 0, 2, 2, 0, 0, 2, 2, None, "nearest") + x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task) + self.assertIsNotNone(error) + self.assertIn("not initialized", error) + self.assertIsNone(inner_bands) + + def _build_dataset(self, bands): + ds = MagicMock() + ds.RasterCount = len(bands) + def get_band(b): + band = MagicMock() + band.ReadAsArray.return_value = bands[b - 1] + return band + ds.GetRasterBand.side_effect = get_band + return ds + + +class TestInitWorkerMocked(unittest.TestCase): + """Verify _init_worker sets module global and config option.""" + + @classmethod + def setUpClass(cls): + _install_fake_modules() + from src.core.algorithms.interpolation import interpolator + cls.interp = interpolator + cls.gdal_mod = sys.modules["osgeo.gdal"] + + def setUp(self): + self.interp._worker_dataset = None + + def tearDown(self): + self.interp._worker_dataset = None + + def test_init_worker_opens_and_caches(self): + fake_ds = MagicMock() + with patch.object(self.gdal_mod, "Open", return_value=fake_ds) as open_mock, \ + patch.object(self.gdal_mod, "SetConfigOption") as cfg_mock: + self.interp._init_worker("/fake/path.bsq") + self.assertIs(self.interp._worker_dataset, fake_ds) + open_mock.assert_called_once_with("/fake/path.bsq", 0) + # Should have set GDAL_NUM_THREADS=1 + cfg_mock.assert_any_call("GDAL_NUM_THREADS", "1") + + def test_init_worker_raises_if_open_fails(self): + with patch.object(self.gdal_mod, "Open", return_value=None): + with self.assertRaises(RuntimeError): + self.interp._init_worker("/bad/path.bsq") + + +if __name__ == "__main__": + unittest.main(verbosity=2)