Step3 插值算法 OOM 修复 + 多进程加速 + 全链路累积改动(14 文件)
This commit is contained in:
@ -3,8 +3,24 @@
|
||||
|
||||
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
|
||||
支持多种插值方法:nearest, bilinear, spline (RBF), kriging。
|
||||
|
||||
本模块使用多进程并行分块 IO 加速(Plan A):
|
||||
- ProcessPoolExecutor 为每个 worker 进程打开一次源影像(initializer 阶段),
|
||||
避免每块重复 gdal.Open 带来的开销(Windows 上 ~50ms/次)
|
||||
- 主进程统一负责输出文件的写入,避免多进程写锁竞争
|
||||
- 分块大小(block_size)默认 1024,内存充足可调至 2048 / 4096
|
||||
|
||||
注意:
|
||||
- GDAL Dataset / Rasterio Dataset 对象不能跨进程传递(picking 不支持),
|
||||
所以 worker 必须在 init 阶段自己独立打开源文件
|
||||
- 每个 worker 强制设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 多线程
|
||||
造成的 CPU 过订阅
|
||||
- 关闭多进程:传 ``use_multiprocessing=False`` 或 ``n_workers=1``
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Union, Tuple, List
|
||||
from pathlib import Path
|
||||
@ -24,6 +40,9 @@ except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
_worker_dataset: Optional["gdal.Dataset"] = None
|
||||
|
||||
|
||||
def interpolate_pixels(
|
||||
image_stack: np.ndarray,
|
||||
zero_coords: np.ndarray,
|
||||
@ -52,7 +71,6 @@ def interpolate_pixels(
|
||||
height, width, n_bands = image_stack.shape
|
||||
result = image_stack.copy()
|
||||
|
||||
# 兼容中文和各种格式的method参数
|
||||
raw_method = str(interpolation_method).lower()
|
||||
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
|
||||
method = 'nearest'
|
||||
@ -181,39 +199,271 @@ def _interpolate_single_band(
|
||||
return np.zeros(len(zero_coords))
|
||||
|
||||
|
||||
def _normalize_interpolation_method(method: str) -> str:
|
||||
"""将中文/英文混用的插值方法名归一化为内部标准名
|
||||
|
||||
支持: 'nearest'/'邻近'/'最邻近','bilinear'/'线性'/'双线性',
|
||||
'spline'/'样条'/'rbf','kriging'/'克里金'。
|
||||
"""
|
||||
raw = str(method).lower()
|
||||
if 'nearest' in raw or '邻近' in raw or '最邻近' in raw:
|
||||
return 'nearest'
|
||||
if 'bilinear' in raw or '线性' in raw or '双线性' in raw:
|
||||
return 'bilinear'
|
||||
if 'spline' in raw or '样条' in raw or 'rbf' in raw:
|
||||
return 'spline'
|
||||
if 'kriging' in raw or '克里金' in raw:
|
||||
return 'kriging'
|
||||
return 'nearest'
|
||||
|
||||
|
||||
def _read_water_mask_to_array(
|
||||
water_mask: Optional[Union[str, np.ndarray]],
|
||||
expected_height: int,
|
||||
expected_width: int,
|
||||
) -> Optional[np.ndarray]:
|
||||
"""读取水域掩膜为 numpy 数组(单波段,bool/int 均可)
|
||||
|
||||
None 或空字符串直接返回 None。形状不匹配时给出告警但不抛错,
|
||||
让调用方按"无掩膜"路径继续。
|
||||
"""
|
||||
if water_mask is None:
|
||||
return None
|
||||
if isinstance(water_mask, str):
|
||||
if not water_mask.strip():
|
||||
return None
|
||||
mask_ds = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_ds is None:
|
||||
print(f" [warn] 无法打开水域掩膜 {water_mask},按无掩膜处理")
|
||||
return None
|
||||
try:
|
||||
mask_array = mask_ds.GetRasterBand(1).ReadAsArray()
|
||||
finally:
|
||||
mask_ds = None
|
||||
elif isinstance(water_mask, np.ndarray):
|
||||
mask_array = water_mask
|
||||
else:
|
||||
return None
|
||||
|
||||
if mask_array.shape != (expected_height, expected_width):
|
||||
print(
|
||||
f" [warn] 水域掩膜形状 {mask_array.shape} 与影像 "
|
||||
f"({expected_height}, {expected_width}) 不匹配,按无掩膜处理"
|
||||
)
|
||||
return None
|
||||
return mask_array
|
||||
|
||||
|
||||
def _init_worker(img_path: str) -> None:
|
||||
"""ProcessPoolExecutor initializer: 每个 worker 进程只调用一次
|
||||
|
||||
在 worker 进程启动时打开源影像 dataset 并缓存在模块全局变量
|
||||
``_worker_dataset`` 中。后续所有块处理直接复用这个 dataset,
|
||||
避免每块重复 ``gdal.Open``(Windows 上约 50ms/次,100 块即 5s)。
|
||||
|
||||
同时设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 默认多线程
|
||||
造成的 CPU 过订阅。
|
||||
"""
|
||||
global _worker_dataset
|
||||
gdal.SetConfigOption('GDAL_NUM_THREADS', '1')
|
||||
if hasattr(gdal, 'UseExceptions'):
|
||||
gdal.UseExceptions()
|
||||
_worker_dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if _worker_dataset is None:
|
||||
raise RuntimeError(f"Worker failed to open source image: {img_path}")
|
||||
|
||||
|
||||
def _interpolate_block_worker(task: tuple) -> tuple:
|
||||
"""ProcessPoolExecutor worker: 处理单个块并返回结果
|
||||
|
||||
该函数必须保持模块级(可被 pickle),不持有任何外部状态——
|
||||
源 dataset 通过 ``_worker_dataset`` 模块全局变量获取。
|
||||
|
||||
Returns:
|
||||
``(x0, y0, inner_bands, zero_count, error_msg)`` 元组:
|
||||
- x0, y0: 块在影像中的写入起点
|
||||
- inner_bands: ``List[np.ndarray]``,每个元素是 (inner_h, inner_w)
|
||||
float32 数组(每个波段一个),或失败时为 None
|
||||
- zero_count: 该扩展块中识别到的零像素数(含 halo 范围)
|
||||
- error_msg: None 表示成功,str 表示错误信息
|
||||
"""
|
||||
(
|
||||
x0, y0, ey0, ex0, ey1, ex1,
|
||||
row_offset, col_offset, inner_h, inner_w,
|
||||
mask_segment_ext, method,
|
||||
) = task
|
||||
if _worker_dataset is None:
|
||||
return (x0, y0, None, 0, "Worker dataset not initialized")
|
||||
try:
|
||||
inner_bands, zero_count = _process_one_block(
|
||||
_worker_dataset, x0, y0, ey0, ex0, ey1, ex1,
|
||||
row_offset, col_offset, inner_h, inner_w,
|
||||
mask_segment_ext, method,
|
||||
)
|
||||
return (x0, y0, inner_bands, zero_count, None)
|
||||
except Exception as e:
|
||||
return (x0, y0, None, 0, str(e))
|
||||
|
||||
|
||||
def _process_one_block(
|
||||
dataset: "gdal.Dataset",
|
||||
x0: int, y0: int,
|
||||
ey0: int, ex0: int, ey1: int, ex1: int,
|
||||
row_offset: int, col_offset: int,
|
||||
inner_h: int, inner_w: int,
|
||||
mask_segment_ext: Optional[np.ndarray],
|
||||
method: str,
|
||||
) -> Tuple[List[np.ndarray], int]:
|
||||
"""处理单个扩展块(纯计算核心,dataset 显式传入)
|
||||
|
||||
串行模式和并行模式共用此函数。并行模式下 dataset 来自 worker 的
|
||||
缓存(``_worker_dataset``),串行模式下 dataset 由主函数传入。
|
||||
|
||||
Args:
|
||||
dataset: 已打开的源影像 dataset
|
||||
x0, y0: 内部块左上角(写入位置)
|
||||
ey0, ex0, ey1, ex1: 扩展块(含 halo)坐标
|
||||
row_offset, col_offset: 内部块在扩展块中的偏移
|
||||
inner_h, inner_w: 内部块尺寸
|
||||
mask_segment_ext: 扩展块对应的水域掩膜(None 表示不应用)
|
||||
method: 插值方法(已归一化)
|
||||
|
||||
Returns:
|
||||
``(inner_bands, zero_count)`` 元组:
|
||||
- inner_bands: ``List[np.ndarray]``,长度 = n_bands,每个元素形状为
|
||||
``(inner_h, inner_w)`` 的 float32 数组
|
||||
- zero_count: 扩展块中识别到的零像素数
|
||||
"""
|
||||
n_bands = dataset.RasterCount
|
||||
ext_bands: List[np.ndarray] = []
|
||||
for b in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(b)
|
||||
ext_bands.append(
|
||||
band.ReadAsArray(ex0, ey0, ex1 - ex0, ey1 - ey0).astype(np.float32)
|
||||
)
|
||||
band = None
|
||||
|
||||
try:
|
||||
ext_h, ext_w = ey1 - ey0, ex1 - ex0
|
||||
|
||||
all_zero_ext = np.ones((ext_h, ext_w), dtype=bool)
|
||||
for b_data in ext_bands:
|
||||
all_zero_ext &= (b_data == 0)
|
||||
|
||||
if mask_segment_ext is not None:
|
||||
all_zero_ext &= (mask_segment_ext > 0)
|
||||
|
||||
zero_count = int(np.sum(all_zero_ext))
|
||||
|
||||
if zero_count == 0:
|
||||
inner_bands = [
|
||||
ext_bands[b][
|
||||
row_offset:row_offset + inner_h,
|
||||
col_offset:col_offset + inner_w,
|
||||
]
|
||||
for b in range(n_bands)
|
||||
]
|
||||
return inner_bands, 0
|
||||
|
||||
zero_y, zero_x = np.where(all_zero_ext)
|
||||
zero_coords = np.column_stack([zero_x, zero_y])
|
||||
|
||||
valid_mask = ~all_zero_ext
|
||||
valid_y, valid_x = np.where(valid_mask)
|
||||
valid_coords = np.column_stack([valid_x, valid_y])
|
||||
|
||||
if len(valid_coords) == 0:
|
||||
print(
|
||||
f" [warn] 块 (y={y0}-{y0 + inner_h}, x={x0}-{x0 + inner_w}) "
|
||||
f"无有效像素可作插值上下文,已跳过"
|
||||
)
|
||||
inner_bands = [
|
||||
ext_bands[b][
|
||||
row_offset:row_offset + inner_h,
|
||||
col_offset:col_offset + inner_w,
|
||||
]
|
||||
for b in range(n_bands)
|
||||
]
|
||||
return inner_bands, zero_count
|
||||
|
||||
for b in range(n_bands):
|
||||
ext_band = ext_bands[b]
|
||||
valid_values_band = ext_band[valid_mask]
|
||||
if len(valid_values_band) == 0:
|
||||
continue
|
||||
band_result = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values_band, method
|
||||
)
|
||||
ext_band[zero_y, zero_x] = band_result
|
||||
|
||||
inner_bands = [
|
||||
ext_bands[b][
|
||||
row_offset:row_offset + inner_h,
|
||||
col_offset:col_offset + inner_w,
|
||||
]
|
||||
for b in range(n_bands)
|
||||
]
|
||||
return inner_bands, zero_count
|
||||
finally:
|
||||
del ext_bands
|
||||
|
||||
|
||||
def interpolate_zero_pixels_batch(
|
||||
img_path: str,
|
||||
interpolation_method: str = 'nearest',
|
||||
output_path: Optional[str] = None,
|
||||
water_mask: Optional[Union[str, np.ndarray]] = None,
|
||||
deglint_dir: Optional[str] = None,
|
||||
callback_progress: Optional[callable] = None
|
||||
callback_progress: Optional[callable] = None,
|
||||
block_size: int = 1024,
|
||||
halo_size: int = 64,
|
||||
n_workers: Optional[int] = None,
|
||||
use_multiprocessing: bool = True,
|
||||
) -> Tuple[str, Optional[np.ndarray]]:
|
||||
"""
|
||||
对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O)
|
||||
对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O)。
|
||||
|
||||
采用 **分块 IO + 多进程并行** 策略:
|
||||
1. 影像按 ``block_size`` × ``block_size`` 分块,每块边界外扩展
|
||||
``halo_size`` 像素作为插值上下文,避免块边缘插值退化
|
||||
2. 多进程并行(默认 ``ProcessPoolExecutor``,worker 数 = CPU 核心数)
|
||||
并发处理所有块;GDAL Dataset 不能跨进程传递,所以每个 worker
|
||||
在 ``initializer`` 阶段独立打开源文件一次并缓存
|
||||
3. 主进程按块序接收处理结果并统一写入输出文件,避免写锁竞争
|
||||
4. 该方案可彻底避免一次性读取 50 波段整景影像时的 OOM 隐患
|
||||
(50 波段 × 4000×4000 × float32 ≈ 3GB 的 np.dstack)
|
||||
|
||||
Args:
|
||||
img_path: 输入影像文件路径
|
||||
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging'
|
||||
output_path: 输出文件路径(如果为None,自动生成)
|
||||
water_mask: 水域掩膜(文件路径或数组)
|
||||
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline',
|
||||
'kriging' 及其中文别名('邻近'/'最邻近'/'线性'/'双线性'/'样条'/'克里金')
|
||||
output_path: 输出文件路径(如果为 None 且 deglint_dir 提供,自动生成)
|
||||
water_mask: 水域掩膜(文件路径或数组),形状须与影像高宽一致
|
||||
deglint_dir: 去耀斑目录(用于生成默认输出路径)
|
||||
callback_progress: 进度回调函数
|
||||
callback_progress: 进度回调函数,签名 ``callback(msg: str)``
|
||||
block_size: 分块大小(像素),默认 1024;内存充足可调 2048/4096
|
||||
halo_size: 上下文 halo 宽度(像素),默认 64
|
||||
n_workers: 并行 worker 进程数;None = ``multiprocessing.cpu_count()``;
|
||||
传 1 等价于串行模式
|
||||
use_multiprocessing: 是否启用多进程;False 时强制串行
|
||||
|
||||
Returns:
|
||||
(output_path, interpolated_image_stack) 元组
|
||||
``(output_path, None)`` 元组。第二个值固定为 ``None``(与原版语义保留
|
||||
兼容;返回完整内存堆叠会重新引入 OOM 风险,故不再提供)。
|
||||
"""
|
||||
if not SCIPY_AVAILABLE:
|
||||
raise ImportError("scipy未安装,无法进行0值像素插值")
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
# 确定输出路径
|
||||
if output_path is None and deglint_dir is not None:
|
||||
output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
|
||||
method = _normalize_interpolation_method(interpolation_method)
|
||||
|
||||
# 检查文件是否已存在
|
||||
if output_path and Path(output_path).exists():
|
||||
if output_path is None and deglint_dir is not None:
|
||||
output_path = str(Path(deglint_dir) / f"interpolated_{method}.bsq")
|
||||
if output_path is None:
|
||||
raise ValueError("output_path 和 deglint_dir 至少需要指定一个")
|
||||
|
||||
if Path(output_path).exists():
|
||||
return output_path, None
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
@ -227,94 +477,125 @@ def interpolate_zero_pixels_batch(
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
# 读取所有波段数据
|
||||
all_bands = []
|
||||
for band_idx in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(band_idx)
|
||||
band_data = band.ReadAsArray().astype(np.float32)
|
||||
all_bands.append(band_data)
|
||||
|
||||
image_stack = np.dstack(all_bands)
|
||||
|
||||
# 读取水域掩膜
|
||||
mask_array = None
|
||||
if water_mask is not None:
|
||||
if isinstance(water_mask, str):
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset:
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
elif isinstance(water_mask, np.ndarray):
|
||||
mask_array = water_mask
|
||||
|
||||
# 找出所有波段都为0的像素点
|
||||
all_bands_zero = np.all(image_stack == 0, axis=2)
|
||||
|
||||
if mask_array is not None:
|
||||
all_bands_zero = all_bands_zero & (mask_array > 0)
|
||||
|
||||
zero_pixel_count = np.sum(all_bands_zero)
|
||||
if zero_pixel_count == 0:
|
||||
# 无需插值,直接保存
|
||||
if output_path:
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
for i, band_data in enumerate(all_bands):
|
||||
out_band = out_dataset.GetRasterBand(i + 1)
|
||||
out_band.WriteArray(band_data)
|
||||
out_band.FlushCache()
|
||||
out_dataset = None
|
||||
return output_path, image_stack
|
||||
|
||||
# 获取坐标
|
||||
zero_y, zero_x = np.where(all_bands_zero)
|
||||
zero_coords = np.column_stack([zero_x, zero_y])
|
||||
|
||||
valid_mask = ~all_bands_zero
|
||||
valid_y, valid_x = np.where(valid_mask)
|
||||
valid_coords = np.column_stack([valid_x, valid_y])
|
||||
|
||||
if len(valid_coords) == 0:
|
||||
raise ValueError("没有有效像素可用于插值")
|
||||
|
||||
# 逐波段插值
|
||||
interpolated_bands = []
|
||||
for band_idx in range(n_bands):
|
||||
if callback_progress:
|
||||
callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
|
||||
band_data = all_bands[band_idx].copy()
|
||||
valid_values_band = band_data[valid_mask]
|
||||
|
||||
if len(valid_values_band) == 0:
|
||||
interpolated_bands.append(band_data)
|
||||
continue
|
||||
|
||||
band_result = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values_band, interpolation_method
|
||||
if width <= 0 or height <= 0 or n_bands <= 0:
|
||||
raise ValueError(
|
||||
f"影像尺寸异常: width={width}, height={height}, n_bands={n_bands}"
|
||||
)
|
||||
band_data[all_bands_zero] = band_result
|
||||
interpolated_bands.append(band_data)
|
||||
|
||||
# 保存结果
|
||||
if output_path:
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
for i, band_data in enumerate(interpolated_bands):
|
||||
out_band = out_dataset.GetRasterBand(i + 1)
|
||||
out_band.WriteArray(band_data)
|
||||
out_band.FlushCache()
|
||||
mask_array = _read_water_mask_to_array(water_mask, height, width)
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
if driver is None:
|
||||
raise RuntimeError("未找到可用的栅格驱动(ENVI / GTiff 都不存在)")
|
||||
|
||||
out_dataset = driver.Create(
|
||||
output_path, width, height, n_bands, gdal.GDT_Float32
|
||||
)
|
||||
if out_dataset is None:
|
||||
raise RuntimeError(f"无法创建输出文件: {output_path}")
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
|
||||
try:
|
||||
if not use_multiprocessing:
|
||||
effective_workers = 1
|
||||
elif n_workers is not None and n_workers >= 1:
|
||||
effective_workers = int(n_workers)
|
||||
else:
|
||||
try:
|
||||
cpu_count = multiprocessing.cpu_count() or 1
|
||||
except (NotImplementedError, OSError):
|
||||
cpu_count = 1
|
||||
effective_workers = max(1, cpu_count)
|
||||
|
||||
n_blocks_y = (height + block_size - 1) // block_size
|
||||
n_blocks_x = (width + block_size - 1) // block_size
|
||||
total_blocks = n_blocks_y * n_blocks_x
|
||||
|
||||
tasks = []
|
||||
for by in range(n_blocks_y):
|
||||
y0 = by * block_size
|
||||
y1 = min(y0 + block_size, height)
|
||||
inner_h = y1 - y0
|
||||
ey0 = max(0, y0 - halo_size)
|
||||
ey1 = min(height, y1 + halo_size)
|
||||
for bx in range(n_blocks_x):
|
||||
x0 = bx * block_size
|
||||
x1 = min(x0 + block_size, width)
|
||||
inner_w = x1 - x0
|
||||
ex0 = max(0, x0 - halo_size)
|
||||
ex1 = min(width, x1 + halo_size)
|
||||
row_offset = y0 - ey0
|
||||
col_offset = x0 - ex0
|
||||
mask_segment_ext = None
|
||||
if mask_array is not None:
|
||||
mask_segment_ext = mask_array[ey0:ey1, ex0:ex1]
|
||||
tasks.append((
|
||||
x0, y0, ey0, ex0, ey1, ex1,
|
||||
row_offset, col_offset, inner_h, inner_w,
|
||||
mask_segment_ext, method,
|
||||
))
|
||||
|
||||
if callback_progress:
|
||||
callback_progress(
|
||||
f"分块插值开始: 共 {total_blocks} 块 "
|
||||
f"(block_size={block_size}, halo={halo_size}, method={method}, "
|
||||
f"workers={effective_workers})"
|
||||
)
|
||||
|
||||
total_zero_pixels = 0
|
||||
|
||||
if effective_workers <= 1:
|
||||
for block_idx, task in enumerate(tasks, 1):
|
||||
x0_t, y0_t = task[0], task[1]
|
||||
if callback_progress:
|
||||
callback_progress(
|
||||
f"块 {block_idx}/{total_blocks} "
|
||||
f"y=[{y0_t},{y0_t + task[8]}) x=[{x0_t},{x0_t + task[9]})"
|
||||
)
|
||||
inner_bands, zero_count = _process_one_block(
|
||||
dataset, *task
|
||||
)
|
||||
for b_idx, band_data in enumerate(inner_bands):
|
||||
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
|
||||
band_data, xoff=x0_t, yoff=y0_t
|
||||
)
|
||||
total_zero_pixels += zero_count
|
||||
else:
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=effective_workers,
|
||||
initializer=_init_worker,
|
||||
initargs=(img_path,),
|
||||
) as executor:
|
||||
futures = [
|
||||
executor.submit(_interpolate_block_worker, task)
|
||||
for task in tasks
|
||||
]
|
||||
for block_idx, future in enumerate(futures, 1):
|
||||
x0_t, y0_t, inner_bands, zero_count, error = future.result()
|
||||
if error is not None:
|
||||
raise RuntimeError(
|
||||
f"块 (y={y0_t}, x={x0_t}) 处理失败: {error}"
|
||||
)
|
||||
if inner_bands is not None:
|
||||
for b_idx, band_data in enumerate(inner_bands):
|
||||
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
|
||||
band_data, xoff=x0_t, yoff=y0_t
|
||||
)
|
||||
total_zero_pixels += zero_count
|
||||
if callback_progress:
|
||||
callback_progress(f"已写入块 {block_idx}/{total_blocks}")
|
||||
|
||||
if callback_progress:
|
||||
callback_progress(
|
||||
f"分块插值完成: 共处理 {total_zero_pixels} 个零像素 "
|
||||
f"({total_blocks} 块,方法 {method},workers={effective_workers})"
|
||||
)
|
||||
|
||||
return output_path, None
|
||||
finally:
|
||||
out_dataset = None
|
||||
|
||||
result_stack = np.dstack(interpolated_bands)
|
||||
return output_path, result_stack
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
@ -899,7 +899,7 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
|
||||
if __name__ == '__main__':
|
||||
# 在这里直接设置参数
|
||||
imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"# BIL格式影像文件路径
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
output_path = r"E:\code\WQ\封装\test/yangdian_output.csv" # CSV格式输出文件路径
|
||||
|
||||
radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均
|
||||
|
||||
@ -806,8 +806,8 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
|
||||
if __name__ == '__main__':
|
||||
# 在这里直接设置参数
|
||||
imgpath = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq" # BIL格式影像文件路径
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
output_path = r"E:\code\WQ\封装\work_dir\5_training_spectra/yangdian_output.csv" # CSV格式输出文件路径
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
output_path = r"E:\code\WQ\封装\work_dir\6_Spectral_Feature_Extraction/yangdian_output.csv" # CSV格式输出文件路径
|
||||
|
||||
radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均
|
||||
flare_path = r"E:\code\WQ\封装\work_dir\2_Glint_Detection\severe_glint_area.dat" # 耀斑掩膜文件路径(可选,None表示不使用)
|
||||
|
||||
@ -315,7 +315,7 @@ def main():
|
||||
|
||||
# 示例1: 使用所有回归方法分析光谱指数
|
||||
print("\n1. 光谱指数与叶绿素a的回归分析:")
|
||||
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\water_quality_results.csv")
|
||||
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\water_quality_results.csv")
|
||||
spectral_indices = ['Al10SABI','Am092Bsub']
|
||||
|
||||
results1 = analyzer.batch_single_variable_regression(
|
||||
@ -323,7 +323,7 @@ def main():
|
||||
x_columns=spectral_indices,
|
||||
y_column='Chlorophyll',
|
||||
methods='all',
|
||||
output_file=r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\spectral_indices_regression.csv'
|
||||
output_file=r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\spectral_indices_regression.csv'
|
||||
)
|
||||
|
||||
# # 示例2: 使用特定方法分析反射率波段
|
||||
@ -343,7 +343,7 @@ def main():
|
||||
best_models = analyzer.get_best_models_summary()
|
||||
if not best_models.empty:
|
||||
print(best_models[['x_variable', 'regression_method', 'r_squared', 'equation']].to_string(index=False))
|
||||
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\best_models_summary.csv', index=False)
|
||||
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\best_models_summary.csv', index=False)
|
||||
print("\n最佳模型汇总已保存到 'best_models_summary.csv'")
|
||||
#
|
||||
# def advanced_usage_example():
|
||||
|
||||
@ -246,7 +246,7 @@ def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, outp
|
||||
|
||||
if __name__ == "__main__":
|
||||
algorithm= "chl_a"
|
||||
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\8_non_empirical_models\SS\SS_chl_a.json"
|
||||
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\8_non_empirical_models\SS\SS_chl_a.json"
|
||||
coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv"
|
||||
output_path= r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv"
|
||||
wave_radius=5.0
|
||||
|
||||
@ -98,7 +98,7 @@ PIPELINE_STEPS: List[StepSpec] = [
|
||||
step_id="step4", method_name="step5_process_csv",
|
||||
requires=["csv_path"], produces=["processed_csv_path"],
|
||||
required_input_files=["csv_path"],
|
||||
output_file="{work_dir}/4_processed_data/processed_data.csv",
|
||||
output_file="{work_dir}/5_Data_Cleaning/processed_data.csv",
|
||||
description="CSV 异常值清洗",
|
||||
),
|
||||
StepSpec(
|
||||
@ -111,21 +111,21 @@ PIPELINE_STEPS: List[StepSpec] = [
|
||||
},
|
||||
skip_when_missing=False,
|
||||
required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"],
|
||||
output_file="{work_dir}/5_training_spectra/training_spectra.csv",
|
||||
output_file="{work_dir}/6_Spectral_Feature_Extraction/training_spectra.csv",
|
||||
description="实测样本点光谱提取",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step7", method_name="step7_calc_indices",
|
||||
requires=["training_csv_path"], produces=["indices_path", "trad_indices_dir"],
|
||||
required_input_files=["training_csv_path"],
|
||||
output_file="{work_dir}/6_water_quality_indices/training_spectra_indices.csv",
|
||||
output_file="{work_dir}/7_Water_Quality_Indices/training_spectra_indices.csv",
|
||||
description="水质参数指数计算(双轨输出:A轨宽表 + B轨单文件)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step8", method_name="step8_train_ml",
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
required_input_files=["training_csv_path"],
|
||||
output_file="{work_dir}/7_Supervised_Model_Training/best_models.pkl",
|
||||
output_file="{work_dir}/8_Supervised_Model_Training/best_models.pkl",
|
||||
description="ML 建模(GridSearchCV / AutoML)",
|
||||
),
|
||||
StepSpec(
|
||||
@ -134,7 +134,7 @@ PIPELINE_STEPS: List[StepSpec] = [
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
parameter_map={"training_csv_path": "csv_path"},
|
||||
required_input_files=["training_csv_path"],
|
||||
output_file="{work_dir}/8_Regression_Modeling/non_empirical_models.pkl",
|
||||
output_file="{work_dir}/8_Non_Empirical_Regression/non_empirical_models.pkl",
|
||||
description="非经验统计回归",
|
||||
),
|
||||
StepSpec(
|
||||
|
||||
@ -328,7 +328,7 @@ def train_with_automl(
|
||||
split_method = split_methods[0]
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = "./7_Supervised_Model_Training_AutoML"
|
||||
output_dir = "./8_Supervised_Model_Training_AutoML"
|
||||
out_dir = Path(output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
preproc_dir = out_dir / preproc
|
||||
@ -519,7 +519,7 @@ if __name__ == "__main__":
|
||||
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
|
||||
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
|
||||
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
|
||||
p.add_argument("--out", default="./7_Supervised_Model_Training_AutoML")
|
||||
p.add_argument("--out", default="./8_Supervised_Model_Training_AutoML")
|
||||
args = p.parse_args()
|
||||
|
||||
# 智能推断 feature_start_column 类型
|
||||
|
||||
@ -21,7 +21,7 @@ class DataPreparationStep:
|
||||
@staticmethod
|
||||
def process_csv(
|
||||
csv_path: str,
|
||||
output_dir: Union[str, Path] = "./4_processed_data",
|
||||
output_dir: Union[str, Path] = "./5_Data_Cleaning",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""处理CSV文件(筛选剔除异常值)"""
|
||||
@ -61,7 +61,7 @@ class DataPreparationStep:
|
||||
boundary_path: Optional[str] = None,
|
||||
glint_mask_path: Optional[str] = None,
|
||||
water_mask_path: Optional[str] = None,
|
||||
output_dir: Union[str, Path] = "./5_training_spectra",
|
||||
output_dir: Union[str, Path] = "./6_Spectral_Feature_Extraction",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""根据采样点坐标在去耀斑影像中提取平均光谱"""
|
||||
@ -131,7 +131,7 @@ class DataPreparationStep:
|
||||
formula_names: Optional[List[str]] = None,
|
||||
output_file: Optional[str] = None,
|
||||
enabled: bool = True,
|
||||
output_dir: Union[str, Path] = "./6_water_quality_indices",
|
||||
output_dir: Union[str, Path] = "./7_Water_Quality_Indices",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> Optional[str]:
|
||||
"""根据训练光谱计算水质光谱指数(使用 band_math 方法)"""
|
||||
|
||||
@ -135,7 +135,7 @@ class ModelingStep:
|
||||
split_methods: Optional[List[str]] = None,
|
||||
cv_folds: int = 5,
|
||||
training_csv_path: Optional[str] = None,
|
||||
output_dir: Union[str, Path] = "./7_Supervised_Model_Training",
|
||||
output_dir: Union[str, Path] = "./8_Supervised_Model_Training",
|
||||
callback: Optional[Callable] = None,
|
||||
_report_generator=None,
|
||||
) -> str:
|
||||
@ -251,7 +251,7 @@ class ModelingStep:
|
||||
if output_dir is not None:
|
||||
non_empirical_dir = Path(output_dir)
|
||||
else:
|
||||
non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
|
||||
non_empirical_dir = Path.cwd() / "8_Non_Empirical_Regression"
|
||||
non_empirical_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if preprocessing_methods is None:
|
||||
@ -430,7 +430,7 @@ def _apply_preprocessing_internal(
|
||||
|
||||
save_path = None
|
||||
if preprocess_method == "SS":
|
||||
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training"
|
||||
models_dir = output_dir.parent.parent / "8_Supervised_Model_Training"
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_path = str(models_dir / "scaler_params.pkl")
|
||||
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
||||
|
||||
@ -259,7 +259,7 @@ class PredictionStep:
|
||||
if non_empirical_models_dir is not None:
|
||||
final_models_dir = non_empirical_models_dir
|
||||
else:
|
||||
default_models_dir = str(Path(work_dir) / "8_Regression_Modeling")
|
||||
default_models_dir = str(Path(work_dir) / "8_Non_Empirical_Regression")
|
||||
if Path(default_models_dir).exists():
|
||||
final_models_dir = default_models_dir
|
||||
else:
|
||||
|
||||
@ -138,11 +138,11 @@ class WaterQualityInversionPipeline:
|
||||
self.water_mask_dir = self.work_dir / "1_water_mask"
|
||||
self.glint_dir = self.work_dir / "2_Glint_Detection"
|
||||
self.deglint_dir = self.work_dir / "3_deglint"
|
||||
self.processed_data_dir = self.work_dir / "4_processed_data"
|
||||
self.training_spectra_dir = self.work_dir / "5_training_spectra"
|
||||
self.indices_dir = self.work_dir / "6_water_quality_indices"
|
||||
self.models_dir = self.work_dir / "7_Supervised_Model_Training"
|
||||
self.non_empirical_models_dir = self.work_dir / "8_Regression_Modeling"
|
||||
self.processed_data_dir = self.work_dir / "5_Data_Cleaning"
|
||||
self.training_spectra_dir = self.work_dir / "6_Spectral_Feature_Extraction"
|
||||
self.indices_dir = self.work_dir / "7_Water_Quality_Indices"
|
||||
self.models_dir = self.work_dir / "8_Supervised_Model_Training"
|
||||
self.non_empirical_models_dir = self.work_dir / "8_Non_Empirical_Regression"
|
||||
self.custom_regression_dir = self.work_dir / "13_Custom_Regression"
|
||||
self.sampling_dir = self.work_dir / "4_sampling"
|
||||
self.prediction_dir = self.work_dir / "11_12_13_predictions"
|
||||
@ -764,7 +764,7 @@ class WaterQualityInversionPipeline:
|
||||
if not spectrum_csv or not os.path.exists(spectrum_csv):
|
||||
# 回退:扫描 work_dir 下 step5 的产物目录,找第一个 .csv
|
||||
fallback_candidates = []
|
||||
step5_dir = os.path.join(self.work_dir, "5_Training_Spectra")
|
||||
step5_dir = os.path.join(self.work_dir, "6_Spectral_Feature_Extraction")
|
||||
if os.path.isdir(step5_dir):
|
||||
for f in sorted(os.listdir(step5_dir)):
|
||||
if f.lower().endswith('.csv'):
|
||||
@ -2023,10 +2023,10 @@ class WaterQualityInversionPipeline:
|
||||
# 应用预处理 - 使用spectral_Preprocessing模块
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
|
||||
# 为SS预处理提供scaler保存路径,保存在工作目录的7_Supervised_Model_Training中
|
||||
# 为SS预处理提供scaler保存路径,保存在工作目录的8_Supervised_Model_Training中
|
||||
save_path = None
|
||||
if preprocess_method == 'SS':
|
||||
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training" # 向上两级到工作目录
|
||||
models_dir = output_dir.parent.parent / "8_Supervised_Model_Training" # 向上两级到工作目录
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_path = str(models_dir / "scaler_params.pkl")
|
||||
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step5 面板 - 光谱提取
|
||||
Step6 面板 - 光谱特征提取
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -27,7 +27,7 @@ class Step6FeaturePanel(QWidget):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
title = QLabel("步骤5:训练样本光谱提取")
|
||||
title = QLabel("步骤6:光谱特征提取")
|
||||
title.setFont(QFont("Arial", 12, QFont.Bold))
|
||||
layout.addWidget(title)
|
||||
|
||||
@ -58,12 +58,12 @@ class Step6FeaturePanel(QWidget):
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.glint_mask_file)
|
||||
step5_glint_hint = QLabel(
|
||||
step6_glint_hint = QLabel(
|
||||
"提示:独立运行本步骤时必须选择耀斑掩膜(通常为步骤2输出的 severe_glint_area.dat),用于在采样时避开耀斑像元。"
|
||||
)
|
||||
step5_glint_hint.setWordWrap(True)
|
||||
step5_glint_hint.setStyleSheet("color: #666; font-size: 10px;")
|
||||
layout.addWidget(step5_glint_hint)
|
||||
step6_glint_hint.setWordWrap(True)
|
||||
step6_glint_hint.setStyleSheet("color: #666; font-size: 10px;")
|
||||
layout.addWidget(step6_glint_hint)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("提取参数")
|
||||
@ -200,20 +200,22 @@ class Step6FeaturePanel(QWidget):
|
||||
else:
|
||||
self.output_file.set_path("")
|
||||
|
||||
# 5. 尝试从 Step4 界面读取已处理的水质参数 CSV 路径,自动填入本面板
|
||||
# 5. 尝试从 Step5 Clean 界面读取已处理的清洗后 CSV 路径,自动填入本面板
|
||||
main_window = self.window()
|
||||
if main_window and hasattr(main_window, 'step5_panel'):
|
||||
step4_output_path = main_window.step5_panel.output_file.get_path()
|
||||
if step4_output_path:
|
||||
if main_window and hasattr(main_window, 'step5_clean_panel'):
|
||||
step5_clean_output_path = main_window.step5_clean_panel.output_file.get_path()
|
||||
if step5_clean_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step4_output_path):
|
||||
step4_output_path = os.path.join(self.work_dir or '', step4_output_path).replace('\\', '/')
|
||||
if not os.path.isabs(step5_clean_output_path):
|
||||
step5_clean_output_path = os.path.join(
|
||||
self.work_dir or '', step5_clean_output_path
|
||||
).replace('\\', '/')
|
||||
existing_csv = self.csv_file.get_path()
|
||||
if not existing_csv or not existing_csv.strip():
|
||||
self.csv_file.set_path(step4_output_path)
|
||||
self.csv_file.set_path(step5_clean_output_path)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤5"""
|
||||
"""独立运行步骤6"""
|
||||
# 验证输入
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
csv_path = self.csv_file.get_path()
|
||||
|
||||
@ -1393,9 +1393,7 @@ class WaterQualityGUI(QMainWindow):
|
||||
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'),
|
||||
'water_mask_path': ('step1', 'water_mask', 'water_mask_file')
|
||||
},
|
||||
'step5_clean': {
|
||||
'csv_path': ('step4_sampling', 'sampling_spectra', 'csv_file') # step5 寻找 step4 的采样点
|
||||
},
|
||||
# 'step5_clean': 业务要求保持输入源独立,不自动抓取 step4_sampling 的输出;用户手动浏览导入
|
||||
'step6_feature': {
|
||||
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'),
|
||||
'csv_path': ('step5_clean', 'processed_data', 'csv_file'),
|
||||
@ -2252,21 +2250,32 @@ class WaterQualityGUI(QMainWindow):
|
||||
# 检查面板是否有对应的属性
|
||||
if not hasattr(panel, panel_attr):
|
||||
continue
|
||||
|
||||
|
||||
file_widget = getattr(panel, panel_attr)
|
||||
|
||||
|
||||
# ★ 兼容 FileSelectWidget 与原生 QLineEdit
|
||||
current_text = (
|
||||
file_widget.get_path().strip()
|
||||
if hasattr(file_widget, 'get_path')
|
||||
else file_widget.text().strip()
|
||||
)
|
||||
|
||||
# 如果输入框已经有内容,跳过自动填充
|
||||
if file_widget.get_path().strip():
|
||||
if current_text:
|
||||
continue
|
||||
|
||||
|
||||
# 查找依赖步骤的输出文件
|
||||
output_path = self.find_step_output(work_path, dep_step, output_type)
|
||||
|
||||
|
||||
if output_path and Path(output_path).exists():
|
||||
file_widget.set_path(output_path)
|
||||
# ★ 兼容 FileSelectWidget 与原生 QLineEdit
|
||||
if hasattr(file_widget, 'set_path'):
|
||||
file_widget.set_path(str(output_path))
|
||||
else:
|
||||
file_widget.setText(str(output_path))
|
||||
self.log_message(f"自动填充 {step_id}.{input_field}: {output_path}", "info")
|
||||
filled_count += 1
|
||||
|
||||
|
||||
return filled_count
|
||||
|
||||
def get_step_panel(self, step_id):
|
||||
|
||||
411
tests/test_interpolator_refactor.py
Normal file
411
tests/test_interpolator_refactor.py
Normal file
@ -0,0 +1,411 @@
|
||||
"""
|
||||
interpolator.py 多进程重构的行为测试(mock-based,无 osgeo 依赖)
|
||||
|
||||
验证:
|
||||
1. 静态结构:模块级函数集合、签名、新参数、_worker_dataset 全局
|
||||
2. 行为逻辑:_process_one_block 在 mock dataset 上的零像素识别 + 插值
|
||||
3. 行为逻辑:_interpolate_block_worker 通过模块全局 _worker_dataset 工作
|
||||
4. 行为逻辑:interpolate_zero_pixels_batch 的串行路径(不依赖 osgeo)
|
||||
5. 向后兼容:现有 8 个 caller 参数保留默认值不变
|
||||
|
||||
如果本机有 osgeo,可补一个真实数据集的 smoke test。
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import ast
|
||||
import types
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Ensure the project src is on path so we can import the module under test
|
||||
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
|
||||
INTERP_PATH = os.path.join(
|
||||
PROJECT_ROOT, "src", "core", "algorithms", "interpolation", "interpolator.py"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Static structure tests
|
||||
# =============================================================================
|
||||
class TestInterpolatorStructure(unittest.TestCase):
|
||||
def setUp(self):
|
||||
with open(INTERP_PATH, "r", encoding="utf-8") as f:
|
||||
self.src = f.read()
|
||||
self.tree = ast.parse(self.src)
|
||||
|
||||
def test_module_level_functions(self):
|
||||
mod_funcs = {n.name for n in self.tree.body if isinstance(n, ast.FunctionDef)}
|
||||
expected = {
|
||||
"interpolate_pixels",
|
||||
"_interpolate_single_band",
|
||||
"_normalize_interpolation_method",
|
||||
"_read_water_mask_to_array",
|
||||
"_init_worker",
|
||||
"_interpolate_block_worker",
|
||||
"_process_one_block",
|
||||
"interpolate_zero_pixels_batch",
|
||||
}
|
||||
self.assertEqual(mod_funcs, expected)
|
||||
self.assertNotIn("_process_block_with_buffer", mod_funcs)
|
||||
|
||||
def test_worker_dataset_module_global(self):
|
||||
mod_globals = set()
|
||||
for n in self.tree.body:
|
||||
if isinstance(n, ast.Assign) and isinstance(n.targets[0], ast.Name):
|
||||
mod_globals.add(n.targets[0].id)
|
||||
elif isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name):
|
||||
mod_globals.add(n.target.id)
|
||||
self.assertIn("_worker_dataset", mod_globals)
|
||||
|
||||
def test_interpolate_zero_pixels_batch_signature(self):
|
||||
for n in self.tree.body:
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "interpolate_zero_pixels_batch":
|
||||
args = [a.arg for a in n.args.args]
|
||||
# Backward compat: all 8 existing params + 2 new
|
||||
self.assertEqual(
|
||||
args,
|
||||
[
|
||||
"img_path", "interpolation_method", "output_path",
|
||||
"water_mask", "deglint_dir", "callback_progress",
|
||||
"block_size", "halo_size", "n_workers", "use_multiprocessing",
|
||||
],
|
||||
)
|
||||
defaults = [getattr(d, "value", None) for d in n.args.defaults]
|
||||
self.assertEqual(
|
||||
defaults,
|
||||
["nearest", None, None, None, None, 1024, 64, None, True],
|
||||
)
|
||||
return
|
||||
self.fail("interpolate_zero_pixels_batch not found")
|
||||
|
||||
def test_init_worker_signature(self):
|
||||
for n in self.tree.body:
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "_init_worker":
|
||||
self.assertEqual([a.arg for a in n.args.args], ["img_path"])
|
||||
return
|
||||
self.fail("_init_worker not found")
|
||||
|
||||
def test_worker_function_signature(self):
|
||||
for n in self.tree.body:
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "_interpolate_block_worker":
|
||||
self.assertEqual([a.arg for a in n.args.args], ["task"])
|
||||
return
|
||||
self.fail("_interpolate_block_worker not found")
|
||||
|
||||
def test_process_one_block_signature(self):
|
||||
for n in self.tree.body:
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "_process_one_block":
|
||||
args = [a.arg for a in n.args.args]
|
||||
self.assertEqual(
|
||||
args,
|
||||
[
|
||||
"dataset", "x0", "y0",
|
||||
"ey0", "ex0", "ey1", "ex1",
|
||||
"row_offset", "col_offset",
|
||||
"inner_h", "inner_w",
|
||||
"mask_segment_ext", "method",
|
||||
],
|
||||
)
|
||||
return
|
||||
self.fail("_process_one_block not found")
|
||||
|
||||
def test_uses_process_pool_executor(self):
|
||||
self.assertIn("ProcessPoolExecutor", self.src)
|
||||
self.assertIn("ProcessPoolExecutor(", self.src)
|
||||
self.assertIn("initializer=_init_worker", self.src)
|
||||
self.assertIn("initargs=(img_path,)", self.src)
|
||||
self.assertIn("GDAL_NUM_THREADS", self.src)
|
||||
|
||||
def test_dispatch_logic_present(self):
|
||||
# Both serial and parallel paths should be present
|
||||
self.assertIn("if effective_workers <= 1:", self.src)
|
||||
self.assertIn("with ProcessPoolExecutor", self.src)
|
||||
|
||||
def test_serial_path_uses_process_one_block(self):
|
||||
# Serial branch should call _process_one_block directly (no pickle overhead)
|
||||
self.assertIn(
|
||||
"_process_one_block(\n dataset, *task\n )",
|
||||
self.src,
|
||||
)
|
||||
|
||||
def test_worker_path_uses_process_one_block(self):
|
||||
# Worker function should also call _process_one_block (shared core)
|
||||
self.assertIn("_process_one_block(", self.src)
|
||||
# The worker should read from _worker_dataset, not receive dataset in task
|
||||
self.assertIn("_worker_dataset", self.src)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mocked behavior tests (no real osgeo/scipy needed)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _make_mock_band(read_data):
|
||||
band = MagicMock()
|
||||
band.ReadAsArray.return_value = read_data
|
||||
return band
|
||||
|
||||
|
||||
def _make_mock_dataset(bands_data, n_bands=None):
|
||||
if n_bands is None:
|
||||
n_bands = len(bands_data)
|
||||
ds = MagicMock()
|
||||
ds.RasterCount = n_bands
|
||||
ds.GetRasterBand.side_effect = lambda b: _make_mock_band(bands_data[b - 1])
|
||||
return ds
|
||||
|
||||
|
||||
def _make_fake_module(name, attrs=None):
|
||||
mod = types.ModuleType(name)
|
||||
for k, v in (attrs or {}).items():
|
||||
setattr(mod, k, v)
|
||||
return mod
|
||||
|
||||
|
||||
def _install_fake_modules():
|
||||
"""Install minimal fakes for scipy + osgeo so the module under test imports."""
|
||||
import numpy as np
|
||||
|
||||
# Fake scipy
|
||||
scipy = types.ModuleType("scipy")
|
||||
scipy_ndimage = types.ModuleType("scipy.ndimage")
|
||||
scipy_interp = types.ModuleType("scipy.interpolate")
|
||||
scipy_spatial = types.ModuleType("scipy.spatial")
|
||||
|
||||
class _FakeTree:
|
||||
def __init__(self, coords):
|
||||
self.coords = coords
|
||||
def query(self, pts):
|
||||
# Return index 0 for all queries
|
||||
import numpy as np
|
||||
n = len(pts)
|
||||
return np.zeros(n, dtype=int), np.zeros(n, dtype=int)
|
||||
|
||||
def _griddata(coords, values, pts, method="linear", fill_value=0.0):
|
||||
# Trivial: nearest valid value
|
||||
if len(coords) == 0 or len(pts) == 0:
|
||||
return np.zeros(len(pts), dtype=np.float32)
|
||||
# Just return first value for all queries (testing only)
|
||||
return np.full(len(pts), values[0], dtype=np.float32)
|
||||
|
||||
class _FakeRBF:
|
||||
def __init__(self, coords, values, kernel=None):
|
||||
self.first_value = values[0]
|
||||
def __call__(self, pts):
|
||||
return np.full(len(pts), self.first_value, dtype=np.float32)
|
||||
|
||||
scipy_spatial.cKDTree = _FakeTree
|
||||
scipy_interp.griddata = _griddata
|
||||
scipy_interp.RBFInterpolator = _FakeRBF
|
||||
|
||||
scipy.ndimage = scipy_ndimage
|
||||
scipy.interpolate = scipy_interp
|
||||
scipy.spatial = scipy_spatial
|
||||
|
||||
# Fake osgeo.gdal with the constants the module needs
|
||||
osgeo = types.ModuleType("osgeo")
|
||||
gdal_mod = types.ModuleType("osgeo.gdal")
|
||||
gdal_mod.GA_ReadOnly = 0
|
||||
gdal_mod.GDT_Float32 = 6
|
||||
gdal_mod.UseExceptions = MagicMock()
|
||||
|
||||
def _open(path, mode):
|
||||
return _make_mock_dataset([]) # empty dataset; real tests build their own
|
||||
gdal_mod.Open = _open
|
||||
gdal_mod.GetDriverByName = MagicMock(return_value=MagicMock())
|
||||
gdal_mod.SetConfigOption = MagicMock()
|
||||
|
||||
osgeo.gdal = gdal_mod
|
||||
|
||||
sys.modules["scipy"] = scipy
|
||||
sys.modules["scipy.ndimage"] = scipy_ndimage
|
||||
sys.modules["scipy.interpolate"] = scipy_interp
|
||||
sys.modules["scipy.spatial"] = scipy_spatial
|
||||
sys.modules["osgeo"] = osgeo
|
||||
sys.modules["osgeo.gdal"] = gdal_mod
|
||||
|
||||
return {
|
||||
"scipy": scipy,
|
||||
"scipy.spatial": scipy_spatial,
|
||||
"scipy.interpolate": scipy_interp,
|
||||
"osgeo.gdal": gdal_mod,
|
||||
}
|
||||
|
||||
|
||||
class TestProcessOneBlockMocked(unittest.TestCase):
|
||||
"""Verify _process_one_block logic with mocked GDAL/scipy."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
_install_fake_modules()
|
||||
# Import the module under test after installing fakes
|
||||
from src.core.algorithms.interpolation import interpolator
|
||||
cls.interp = interpolator
|
||||
cls.gdal = sys.modules["osgeo.gdal"]
|
||||
|
||||
def _build_dataset(self, bands):
|
||||
"""Build a mock dataset where GetRasterBand(b).ReadAsArray(x,y,w,h)
|
||||
returns bands[b-1] (a numpy array of the requested shape).
|
||||
|
||||
For simplicity we return the full band array regardless of x,y,w,h.
|
||||
"""
|
||||
ds = MagicMock()
|
||||
ds.RasterCount = len(bands)
|
||||
|
||||
def get_band(b):
|
||||
band = MagicMock()
|
||||
band.ReadAsArray.return_value = bands[b - 1]
|
||||
return band
|
||||
ds.GetRasterBand.side_effect = get_band
|
||||
return ds
|
||||
|
||||
def test_no_zeros_returns_inner_blocks_unchanged(self):
|
||||
"""If no pixels are all-zero, inner blocks should be returned as-is."""
|
||||
import numpy as np
|
||||
# 2x2 image, 3 bands, all 1s (no zeros anywhere)
|
||||
band = np.ones((4, 4), dtype=np.float32) # 2 inner + 2 halo
|
||||
bands = [band, band * 2, band * 3]
|
||||
ds = self._build_dataset(bands)
|
||||
|
||||
inner_bands, zero_count = self.interp._process_one_block(
|
||||
dataset=ds,
|
||||
x0=0, y0=0,
|
||||
ey0=0, ex0=0, ey1=4, ex1=4,
|
||||
row_offset=0, col_offset=0,
|
||||
inner_h=2, inner_w=2,
|
||||
mask_segment_ext=None,
|
||||
method="nearest",
|
||||
)
|
||||
self.assertEqual(zero_count, 0)
|
||||
self.assertEqual(len(inner_bands), 3)
|
||||
for ib, expected_band in zip(inner_bands, bands):
|
||||
self.assertEqual(ib.shape, (2, 2))
|
||||
np.testing.assert_array_equal(ib, expected_band[:2, :2])
|
||||
|
||||
def test_with_zero_pixel_triggers_interpolation(self):
|
||||
"""If a pixel is all-zero, _interpolate_single_band should be called."""
|
||||
import numpy as np
|
||||
# 4x4 image, 3 bands. Top-left 2x2 is zeros, rest is 1s.
|
||||
band1 = np.array([
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 1, 1],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1],
|
||||
], dtype=np.float32)
|
||||
band2 = band1 * 2
|
||||
band3 = band1 * 3
|
||||
ds = self._build_dataset([band1, band2, band3])
|
||||
|
||||
# Process the top-left 2x2 block (with halo, 4x4 covers all)
|
||||
inner_bands, zero_count = self.interp._process_one_block(
|
||||
dataset=ds,
|
||||
x0=0, y0=0,
|
||||
ey0=0, ex0=0, ey1=4, ex1=4,
|
||||
row_offset=0, col_offset=0,
|
||||
inner_h=2, inner_w=2,
|
||||
mask_segment_ext=None,
|
||||
method="nearest",
|
||||
)
|
||||
self.assertGreater(zero_count, 0, "should detect zero pixels")
|
||||
self.assertEqual(len(inner_bands), 3)
|
||||
# Each inner band should be 2x2
|
||||
for ib in inner_bands:
|
||||
self.assertEqual(ib.shape, (2, 2))
|
||||
# Our fake nearest interpolation returns valid_values[0]=1 for band1
|
||||
# So the inner block should be filled with non-zero values
|
||||
self.assertTrue(np.all(inner_bands[0] > 0))
|
||||
|
||||
|
||||
class TestWorkerFunctionMocked(unittest.TestCase):
|
||||
"""Verify _interpolate_block_worker uses module-global _worker_dataset."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
_install_fake_modules()
|
||||
from src.core.algorithms.interpolation import interpolator
|
||||
cls.interp = interpolator
|
||||
|
||||
def test_worker_uses_module_global_dataset(self):
|
||||
import numpy as np
|
||||
# Set the module global
|
||||
band = np.array([
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 1, 1],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1],
|
||||
], dtype=np.float32)
|
||||
ds = self._build_dataset([band, band * 2])
|
||||
self.interp._worker_dataset = ds
|
||||
try:
|
||||
task = (
|
||||
0, 0, 0, 0, 4, 4, # x0, y0, ey0, ex0, ey1, ex1
|
||||
0, 0, 2, 2, # row_offset, col_offset, inner_h, inner_w
|
||||
None, "nearest", # mask_segment_ext, method
|
||||
)
|
||||
x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)
|
||||
self.assertIsNone(error)
|
||||
self.assertEqual((x0, y0), (0, 0))
|
||||
self.assertGreater(zero_count, 0)
|
||||
self.assertIsNotNone(inner_bands)
|
||||
self.assertEqual(len(inner_bands), 2)
|
||||
finally:
|
||||
self.interp._worker_dataset = None
|
||||
|
||||
def test_worker_returns_error_if_dataset_uninitialized(self):
|
||||
self.interp._worker_dataset = None
|
||||
task = (0, 0, 0, 0, 2, 2, 0, 0, 2, 2, None, "nearest")
|
||||
x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)
|
||||
self.assertIsNotNone(error)
|
||||
self.assertIn("not initialized", error)
|
||||
self.assertIsNone(inner_bands)
|
||||
|
||||
def _build_dataset(self, bands):
|
||||
ds = MagicMock()
|
||||
ds.RasterCount = len(bands)
|
||||
def get_band(b):
|
||||
band = MagicMock()
|
||||
band.ReadAsArray.return_value = bands[b - 1]
|
||||
return band
|
||||
ds.GetRasterBand.side_effect = get_band
|
||||
return ds
|
||||
|
||||
|
||||
class TestInitWorkerMocked(unittest.TestCase):
|
||||
"""Verify _init_worker sets module global and config option."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
_install_fake_modules()
|
||||
from src.core.algorithms.interpolation import interpolator
|
||||
cls.interp = interpolator
|
||||
cls.gdal_mod = sys.modules["osgeo.gdal"]
|
||||
|
||||
def setUp(self):
|
||||
self.interp._worker_dataset = None
|
||||
|
||||
def tearDown(self):
|
||||
self.interp._worker_dataset = None
|
||||
|
||||
def test_init_worker_opens_and_caches(self):
|
||||
fake_ds = MagicMock()
|
||||
with patch.object(self.gdal_mod, "Open", return_value=fake_ds) as open_mock, \
|
||||
patch.object(self.gdal_mod, "SetConfigOption") as cfg_mock:
|
||||
self.interp._init_worker("/fake/path.bsq")
|
||||
self.assertIs(self.interp._worker_dataset, fake_ds)
|
||||
open_mock.assert_called_once_with("/fake/path.bsq", 0)
|
||||
# Should have set GDAL_NUM_THREADS=1
|
||||
cfg_mock.assert_any_call("GDAL_NUM_THREADS", "1")
|
||||
|
||||
def test_init_worker_raises_if_open_fails(self):
|
||||
with patch.object(self.gdal_mod, "Open", return_value=None):
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.interp._init_worker("/bad/path.bsq")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user