Step3 插值算法 OOM 修复 + 多进程加速 + 全链路累积改动(14 文件)
This commit is contained in:
@ -3,8 +3,24 @@
|
|||||||
|
|
||||||
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
|
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
|
||||||
支持多种插值方法:nearest, bilinear, spline (RBF), kriging。
|
支持多种插值方法:nearest, bilinear, spline (RBF), kriging。
|
||||||
|
|
||||||
|
本模块使用多进程并行分块 IO 加速(Plan A):
|
||||||
|
- ProcessPoolExecutor 为每个 worker 进程打开一次源影像(initializer 阶段),
|
||||||
|
避免每块重复 gdal.Open 带来的开销(Windows 上 ~50ms/次)
|
||||||
|
- 主进程统一负责输出文件的写入,避免多进程写锁竞争
|
||||||
|
- 分块大小(block_size)默认 1024,内存充足可调至 2048 / 4096
|
||||||
|
|
||||||
|
注意:
|
||||||
|
- GDAL Dataset / Rasterio Dataset 对象不能跨进程传递(不支持 pickle 序列化),
|
||||||
|
所以 worker 必须在 init 阶段自己独立打开源文件
|
||||||
|
- 每个 worker 强制设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 多线程
|
||||||
|
造成的 CPU 过订阅
|
||||||
|
- 关闭多进程:传 ``use_multiprocessing=False`` 或 ``n_workers=1``
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Optional, Union, Tuple, List
|
from typing import Optional, Union, Tuple, List
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -24,6 +40,9 @@ except ImportError:
|
|||||||
GDAL_AVAILABLE = False
|
GDAL_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
_worker_dataset: Optional["gdal.Dataset"] = None
|
||||||
|
|
||||||
|
|
||||||
def interpolate_pixels(
|
def interpolate_pixels(
|
||||||
image_stack: np.ndarray,
|
image_stack: np.ndarray,
|
||||||
zero_coords: np.ndarray,
|
zero_coords: np.ndarray,
|
||||||
@ -52,7 +71,6 @@ def interpolate_pixels(
|
|||||||
height, width, n_bands = image_stack.shape
|
height, width, n_bands = image_stack.shape
|
||||||
result = image_stack.copy()
|
result = image_stack.copy()
|
||||||
|
|
||||||
# 兼容中文和各种格式的method参数
|
|
||||||
raw_method = str(interpolation_method).lower()
|
raw_method = str(interpolation_method).lower()
|
||||||
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
|
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
|
||||||
method = 'nearest'
|
method = 'nearest'
|
||||||
@ -181,39 +199,271 @@ def _interpolate_single_band(
|
|||||||
return np.zeros(len(zero_coords))
|
return np.zeros(len(zero_coords))
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_interpolation_method(method: str) -> str:
|
||||||
|
"""将中文/英文混用的插值方法名归一化为内部标准名
|
||||||
|
|
||||||
|
支持: 'nearest'/'邻近'/'最邻近','bilinear'/'线性'/'双线性',
|
||||||
|
'spline'/'样条'/'rbf','kriging'/'克里金'。
|
||||||
|
"""
|
||||||
|
raw = str(method).lower()
|
||||||
|
if 'nearest' in raw or '邻近' in raw or '最邻近' in raw:
|
||||||
|
return 'nearest'
|
||||||
|
if 'bilinear' in raw or '线性' in raw or '双线性' in raw:
|
||||||
|
return 'bilinear'
|
||||||
|
if 'spline' in raw or '样条' in raw or 'rbf' in raw:
|
||||||
|
return 'spline'
|
||||||
|
if 'kriging' in raw or '克里金' in raw:
|
||||||
|
return 'kriging'
|
||||||
|
return 'nearest'
|
||||||
|
|
||||||
|
|
||||||
|
def _read_water_mask_to_array(
|
||||||
|
water_mask: Optional[Union[str, np.ndarray]],
|
||||||
|
expected_height: int,
|
||||||
|
expected_width: int,
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
|
"""读取水域掩膜为 numpy 数组(单波段,bool/int 均可)
|
||||||
|
|
||||||
|
None 或空字符串直接返回 None。形状不匹配时给出告警但不抛错,
|
||||||
|
让调用方按"无掩膜"路径继续。
|
||||||
|
"""
|
||||||
|
if water_mask is None:
|
||||||
|
return None
|
||||||
|
if isinstance(water_mask, str):
|
||||||
|
if not water_mask.strip():
|
||||||
|
return None
|
||||||
|
mask_ds = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||||
|
if mask_ds is None:
|
||||||
|
print(f" [warn] 无法打开水域掩膜 {water_mask},按无掩膜处理")
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
mask_array = mask_ds.GetRasterBand(1).ReadAsArray()
|
||||||
|
finally:
|
||||||
|
mask_ds = None
|
||||||
|
elif isinstance(water_mask, np.ndarray):
|
||||||
|
mask_array = water_mask
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if mask_array.shape != (expected_height, expected_width):
|
||||||
|
print(
|
||||||
|
f" [warn] 水域掩膜形状 {mask_array.shape} 与影像 "
|
||||||
|
f"({expected_height}, {expected_width}) 不匹配,按无掩膜处理"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return mask_array
|
||||||
|
|
||||||
|
|
||||||
|
def _init_worker(img_path: str) -> None:
    """ProcessPoolExecutor initializer, executed once per worker process.

    Steps, in order:

    1. Set the GDAL config option ``GDAL_NUM_THREADS`` to ``'1'`` (the module
       notes describe this as a guard against CPU oversubscription when many
       workers run at once).
    2. Call ``gdal.UseExceptions()`` if the binding exposes it.
    3. Open ``img_path`` read-only and store the dataset in the module global
       ``_worker_dataset`` so block tasks in this process can reuse it
       instead of re-opening the file per block.

    Args:
        img_path: Path of the source raster to open.

    Raises:
        RuntimeError: If ``gdal.Open`` returns ``None`` for ``img_path``.
    """
    global _worker_dataset

    gdal.SetConfigOption('GDAL_NUM_THREADS', '1')

    enable_exceptions = getattr(gdal, 'UseExceptions', None)
    if enable_exceptions is not None:
        enable_exceptions()

    opened = gdal.Open(img_path, gdal.GA_ReadOnly)
    _worker_dataset = opened
    if opened is None:
        raise RuntimeError(f"Worker failed to open source image: {img_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def _interpolate_block_worker(task: tuple) -> tuple:
    """Process one block inside a pool worker and return a picklable result.

    The function is module-level so it can be pickled for
    ``ProcessPoolExecutor``. It reads the source dataset from the module
    global ``_worker_dataset`` and delegates the computation to
    ``_process_one_block``.

    Args:
        task: 12-tuple ``(x0, y0, ey0, ex0, ey1, ex1, row_offset, col_offset,
            inner_h, inner_w, mask_segment_ext, method)``, i.e. the positional
            arguments of ``_process_one_block`` after ``dataset``.

    Returns:
        ``(x0, y0, inner_bands, zero_count, error_msg)``:

        * ``x0``, ``y0``: write origin of the block, echoed from ``task``.
        * ``inner_bands``: per-band arrays returned by ``_process_one_block``,
          or ``None`` on failure.
        * ``zero_count``: zero-pixel count returned by ``_process_one_block``,
          or ``0`` on failure.
        * ``error_msg``: ``None`` on success, otherwise a string of the form
          ``"ExceptionType: message"`` (or a fixed message when the worker
          dataset was never initialised).
    """
    (
        x0, y0, ey0, ex0, ey1, ex1,
        row_offset, col_offset, inner_h, inner_w,
        mask_segment_ext, method,
    ) = task

    if _worker_dataset is None:
        return (x0, y0, None, 0, "Worker dataset not initialized")

    try:
        inner_bands, zero_count = _process_one_block(
            _worker_dataset, x0, y0, ey0, ex0, ey1, ex1,
            row_offset, col_offset, inner_h, inner_w,
            mask_segment_ext, method,
        )
    except Exception as exc:
        # Process boundary: the failure is reported as data so the parent can
        # attach block coordinates. The exception type is included because
        # str(exc) alone is empty or cryptic for e.g. MemoryError() or
        # KeyError('x').
        return (x0, y0, None, 0, f"{type(exc).__name__}: {exc}")

    return (x0, y0, inner_bands, zero_count, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _process_one_block(
|
||||||
|
dataset: "gdal.Dataset",
|
||||||
|
x0: int, y0: int,
|
||||||
|
ey0: int, ex0: int, ey1: int, ex1: int,
|
||||||
|
row_offset: int, col_offset: int,
|
||||||
|
inner_h: int, inner_w: int,
|
||||||
|
mask_segment_ext: Optional[np.ndarray],
|
||||||
|
method: str,
|
||||||
|
) -> Tuple[List[np.ndarray], int]:
|
||||||
|
"""处理单个扩展块(纯计算核心,dataset 显式传入)
|
||||||
|
|
||||||
|
串行模式和并行模式共用此函数。并行模式下 dataset 来自 worker 的
|
||||||
|
缓存(``_worker_dataset``),串行模式下 dataset 由主函数传入。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: 已打开的源影像 dataset
|
||||||
|
x0, y0: 内部块左上角(写入位置)
|
||||||
|
ey0, ex0, ey1, ex1: 扩展块(含 halo)坐标
|
||||||
|
row_offset, col_offset: 内部块在扩展块中的偏移
|
||||||
|
inner_h, inner_w: 内部块尺寸
|
||||||
|
mask_segment_ext: 扩展块对应的水域掩膜(None 表示不应用)
|
||||||
|
method: 插值方法(已归一化)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``(inner_bands, zero_count)`` 元组:
|
||||||
|
- inner_bands: ``List[np.ndarray]``,长度 = n_bands,每个元素形状为
|
||||||
|
``(inner_h, inner_w)`` 的 float32 数组
|
||||||
|
- zero_count: 扩展块中识别到的零像素数
|
||||||
|
"""
|
||||||
|
n_bands = dataset.RasterCount
|
||||||
|
ext_bands: List[np.ndarray] = []
|
||||||
|
for b in range(1, n_bands + 1):
|
||||||
|
band = dataset.GetRasterBand(b)
|
||||||
|
ext_bands.append(
|
||||||
|
band.ReadAsArray(ex0, ey0, ex1 - ex0, ey1 - ey0).astype(np.float32)
|
||||||
|
)
|
||||||
|
band = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
ext_h, ext_w = ey1 - ey0, ex1 - ex0
|
||||||
|
|
||||||
|
all_zero_ext = np.ones((ext_h, ext_w), dtype=bool)
|
||||||
|
for b_data in ext_bands:
|
||||||
|
all_zero_ext &= (b_data == 0)
|
||||||
|
|
||||||
|
if mask_segment_ext is not None:
|
||||||
|
all_zero_ext &= (mask_segment_ext > 0)
|
||||||
|
|
||||||
|
zero_count = int(np.sum(all_zero_ext))
|
||||||
|
|
||||||
|
if zero_count == 0:
|
||||||
|
inner_bands = [
|
||||||
|
ext_bands[b][
|
||||||
|
row_offset:row_offset + inner_h,
|
||||||
|
col_offset:col_offset + inner_w,
|
||||||
|
]
|
||||||
|
for b in range(n_bands)
|
||||||
|
]
|
||||||
|
return inner_bands, 0
|
||||||
|
|
||||||
|
zero_y, zero_x = np.where(all_zero_ext)
|
||||||
|
zero_coords = np.column_stack([zero_x, zero_y])
|
||||||
|
|
||||||
|
valid_mask = ~all_zero_ext
|
||||||
|
valid_y, valid_x = np.where(valid_mask)
|
||||||
|
valid_coords = np.column_stack([valid_x, valid_y])
|
||||||
|
|
||||||
|
if len(valid_coords) == 0:
|
||||||
|
print(
|
||||||
|
f" [warn] 块 (y={y0}-{y0 + inner_h}, x={x0}-{x0 + inner_w}) "
|
||||||
|
f"无有效像素可作插值上下文,已跳过"
|
||||||
|
)
|
||||||
|
inner_bands = [
|
||||||
|
ext_bands[b][
|
||||||
|
row_offset:row_offset + inner_h,
|
||||||
|
col_offset:col_offset + inner_w,
|
||||||
|
]
|
||||||
|
for b in range(n_bands)
|
||||||
|
]
|
||||||
|
return inner_bands, zero_count
|
||||||
|
|
||||||
|
for b in range(n_bands):
|
||||||
|
ext_band = ext_bands[b]
|
||||||
|
valid_values_band = ext_band[valid_mask]
|
||||||
|
if len(valid_values_band) == 0:
|
||||||
|
continue
|
||||||
|
band_result = _interpolate_single_band(
|
||||||
|
zero_coords, valid_coords, valid_values_band, method
|
||||||
|
)
|
||||||
|
ext_band[zero_y, zero_x] = band_result
|
||||||
|
|
||||||
|
inner_bands = [
|
||||||
|
ext_bands[b][
|
||||||
|
row_offset:row_offset + inner_h,
|
||||||
|
col_offset:col_offset + inner_w,
|
||||||
|
]
|
||||||
|
for b in range(n_bands)
|
||||||
|
]
|
||||||
|
return inner_bands, zero_count
|
||||||
|
finally:
|
||||||
|
del ext_bands
|
||||||
|
|
||||||
|
|
||||||
def interpolate_zero_pixels_batch(
|
def interpolate_zero_pixels_batch(
|
||||||
img_path: str,
|
img_path: str,
|
||||||
interpolation_method: str = 'nearest',
|
interpolation_method: str = 'nearest',
|
||||||
output_path: Optional[str] = None,
|
output_path: Optional[str] = None,
|
||||||
water_mask: Optional[Union[str, np.ndarray]] = None,
|
water_mask: Optional[Union[str, np.ndarray]] = None,
|
||||||
deglint_dir: Optional[str] = None,
|
deglint_dir: Optional[str] = None,
|
||||||
callback_progress: Optional[callable] = None
|
callback_progress: Optional[callable] = None,
|
||||||
|
block_size: int = 1024,
|
||||||
|
halo_size: int = 64,
|
||||||
|
n_workers: Optional[int] = None,
|
||||||
|
use_multiprocessing: bool = True,
|
||||||
) -> Tuple[str, Optional[np.ndarray]]:
|
) -> Tuple[str, Optional[np.ndarray]]:
|
||||||
"""
|
"""
|
||||||
对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O)
|
对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O)。
|
||||||
|
|
||||||
|
采用 **分块 IO + 多进程并行** 策略:
|
||||||
|
1. 影像按 ``block_size`` × ``block_size`` 分块,每块边界外扩展
|
||||||
|
``halo_size`` 像素作为插值上下文,避免块边缘插值退化
|
||||||
|
2. 多进程并行(默认 ``ProcessPoolExecutor``,worker 数 = CPU 核心数)
|
||||||
|
并发处理所有块;GDAL Dataset 不能跨进程传递,所以每个 worker
|
||||||
|
在 ``initializer`` 阶段独立打开源文件一次并缓存
|
||||||
|
3. 主进程按块序接收处理结果并统一写入输出文件,避免写锁竞争
|
||||||
|
4. 该方案可彻底避免一次性读取 50 波段整景影像时的 OOM 隐患
|
||||||
|
(50 波段 × 4000×4000 × float32 ≈ 3GB 的 np.dstack)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img_path: 输入影像文件路径
|
img_path: 输入影像文件路径
|
||||||
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging'
|
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline',
|
||||||
output_path: 输出文件路径(如果为None,自动生成)
|
'kriging' 及其中文别名('邻近'/'最邻近'/'线性'/'双线性'/'样条'/'克里金')
|
||||||
water_mask: 水域掩膜(文件路径或数组)
|
output_path: 输出文件路径(如果为 None 且 deglint_dir 提供,自动生成)
|
||||||
|
water_mask: 水域掩膜(文件路径或数组),形状须与影像高宽一致
|
||||||
deglint_dir: 去耀斑目录(用于生成默认输出路径)
|
deglint_dir: 去耀斑目录(用于生成默认输出路径)
|
||||||
callback_progress: 进度回调函数
|
callback_progress: 进度回调函数,签名 ``callback(msg: str)``
|
||||||
|
block_size: 分块大小(像素),默认 1024;内存充足可调 2048/4096
|
||||||
|
halo_size: 上下文 halo 宽度(像素),默认 64
|
||||||
|
n_workers: 并行 worker 进程数;None = ``multiprocessing.cpu_count()``;
|
||||||
|
传 1 等价于串行模式
|
||||||
|
use_multiprocessing: 是否启用多进程;False 时强制串行
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(output_path, interpolated_image_stack) 元组
|
``(output_path, None)`` 元组。第二个值固定为 ``None``(与原版语义保留
|
||||||
|
兼容;返回完整内存堆叠会重新引入 OOM 风险,故不再提供)。
|
||||||
"""
|
"""
|
||||||
if not SCIPY_AVAILABLE:
|
if not SCIPY_AVAILABLE:
|
||||||
raise ImportError("scipy未安装,无法进行0值像素插值")
|
raise ImportError("scipy未安装,无法进行0值像素插值")
|
||||||
if not GDAL_AVAILABLE:
|
if not GDAL_AVAILABLE:
|
||||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||||
|
|
||||||
# 确定输出路径
|
method = _normalize_interpolation_method(interpolation_method)
|
||||||
if output_path is None and deglint_dir is not None:
|
|
||||||
output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
|
|
||||||
|
|
||||||
# 检查文件是否已存在
|
if output_path is None and deglint_dir is not None:
|
||||||
if output_path and Path(output_path).exists():
|
output_path = str(Path(deglint_dir) / f"interpolated_{method}.bsq")
|
||||||
|
if output_path is None:
|
||||||
|
raise ValueError("output_path 和 deglint_dir 至少需要指定一个")
|
||||||
|
|
||||||
|
if Path(output_path).exists():
|
||||||
return output_path, None
|
return output_path, None
|
||||||
|
|
||||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||||
@ -227,94 +477,125 @@ def interpolate_zero_pixels_batch(
|
|||||||
geotransform = dataset.GetGeoTransform()
|
geotransform = dataset.GetGeoTransform()
|
||||||
projection = dataset.GetProjection()
|
projection = dataset.GetProjection()
|
||||||
|
|
||||||
# 读取所有波段数据
|
if width <= 0 or height <= 0 or n_bands <= 0:
|
||||||
all_bands = []
|
raise ValueError(
|
||||||
for band_idx in range(1, n_bands + 1):
|
f"影像尺寸异常: width={width}, height={height}, n_bands={n_bands}"
|
||||||
band = dataset.GetRasterBand(band_idx)
|
|
||||||
band_data = band.ReadAsArray().astype(np.float32)
|
|
||||||
all_bands.append(band_data)
|
|
||||||
|
|
||||||
image_stack = np.dstack(all_bands)
|
|
||||||
|
|
||||||
# 读取水域掩膜
|
|
||||||
mask_array = None
|
|
||||||
if water_mask is not None:
|
|
||||||
if isinstance(water_mask, str):
|
|
||||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
|
||||||
if mask_dataset:
|
|
||||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
|
||||||
mask_dataset = None
|
|
||||||
elif isinstance(water_mask, np.ndarray):
|
|
||||||
mask_array = water_mask
|
|
||||||
|
|
||||||
# 找出所有波段都为0的像素点
|
|
||||||
all_bands_zero = np.all(image_stack == 0, axis=2)
|
|
||||||
|
|
||||||
if mask_array is not None:
|
|
||||||
all_bands_zero = all_bands_zero & (mask_array > 0)
|
|
||||||
|
|
||||||
zero_pixel_count = np.sum(all_bands_zero)
|
|
||||||
if zero_pixel_count == 0:
|
|
||||||
# 无需插值,直接保存
|
|
||||||
if output_path:
|
|
||||||
driver = gdal.GetDriverByName('ENVI')
|
|
||||||
if driver is None:
|
|
||||||
driver = gdal.GetDriverByName('GTiff')
|
|
||||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
|
||||||
out_dataset.SetGeoTransform(geotransform)
|
|
||||||
out_dataset.SetProjection(projection)
|
|
||||||
for i, band_data in enumerate(all_bands):
|
|
||||||
out_band = out_dataset.GetRasterBand(i + 1)
|
|
||||||
out_band.WriteArray(band_data)
|
|
||||||
out_band.FlushCache()
|
|
||||||
out_dataset = None
|
|
||||||
return output_path, image_stack
|
|
||||||
|
|
||||||
# 获取坐标
|
|
||||||
zero_y, zero_x = np.where(all_bands_zero)
|
|
||||||
zero_coords = np.column_stack([zero_x, zero_y])
|
|
||||||
|
|
||||||
valid_mask = ~all_bands_zero
|
|
||||||
valid_y, valid_x = np.where(valid_mask)
|
|
||||||
valid_coords = np.column_stack([valid_x, valid_y])
|
|
||||||
|
|
||||||
if len(valid_coords) == 0:
|
|
||||||
raise ValueError("没有有效像素可用于插值")
|
|
||||||
|
|
||||||
# 逐波段插值
|
|
||||||
interpolated_bands = []
|
|
||||||
for band_idx in range(n_bands):
|
|
||||||
if callback_progress:
|
|
||||||
callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
|
|
||||||
band_data = all_bands[band_idx].copy()
|
|
||||||
valid_values_band = band_data[valid_mask]
|
|
||||||
|
|
||||||
if len(valid_values_band) == 0:
|
|
||||||
interpolated_bands.append(band_data)
|
|
||||||
continue
|
|
||||||
|
|
||||||
band_result = _interpolate_single_band(
|
|
||||||
zero_coords, valid_coords, valid_values_band, interpolation_method
|
|
||||||
)
|
)
|
||||||
band_data[all_bands_zero] = band_result
|
|
||||||
interpolated_bands.append(band_data)
|
|
||||||
|
|
||||||
# 保存结果
|
mask_array = _read_water_mask_to_array(water_mask, height, width)
|
||||||
if output_path:
|
|
||||||
driver = gdal.GetDriverByName('ENVI')
|
driver = gdal.GetDriverByName('ENVI')
|
||||||
if driver is None:
|
if driver is None:
|
||||||
driver = gdal.GetDriverByName('GTiff')
|
driver = gdal.GetDriverByName('GTiff')
|
||||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
if driver is None:
|
||||||
out_dataset.SetGeoTransform(geotransform)
|
raise RuntimeError("未找到可用的栅格驱动(ENVI / GTiff 都不存在)")
|
||||||
out_dataset.SetProjection(projection)
|
|
||||||
for i, band_data in enumerate(interpolated_bands):
|
out_dataset = driver.Create(
|
||||||
out_band = out_dataset.GetRasterBand(i + 1)
|
output_path, width, height, n_bands, gdal.GDT_Float32
|
||||||
out_band.WriteArray(band_data)
|
)
|
||||||
out_band.FlushCache()
|
if out_dataset is None:
|
||||||
|
raise RuntimeError(f"无法创建输出文件: {output_path}")
|
||||||
|
out_dataset.SetGeoTransform(geotransform)
|
||||||
|
out_dataset.SetProjection(projection)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not use_multiprocessing:
|
||||||
|
effective_workers = 1
|
||||||
|
elif n_workers is not None and n_workers >= 1:
|
||||||
|
effective_workers = int(n_workers)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
cpu_count = multiprocessing.cpu_count() or 1
|
||||||
|
except (NotImplementedError, OSError):
|
||||||
|
cpu_count = 1
|
||||||
|
effective_workers = max(1, cpu_count)
|
||||||
|
|
||||||
|
n_blocks_y = (height + block_size - 1) // block_size
|
||||||
|
n_blocks_x = (width + block_size - 1) // block_size
|
||||||
|
total_blocks = n_blocks_y * n_blocks_x
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
for by in range(n_blocks_y):
|
||||||
|
y0 = by * block_size
|
||||||
|
y1 = min(y0 + block_size, height)
|
||||||
|
inner_h = y1 - y0
|
||||||
|
ey0 = max(0, y0 - halo_size)
|
||||||
|
ey1 = min(height, y1 + halo_size)
|
||||||
|
for bx in range(n_blocks_x):
|
||||||
|
x0 = bx * block_size
|
||||||
|
x1 = min(x0 + block_size, width)
|
||||||
|
inner_w = x1 - x0
|
||||||
|
ex0 = max(0, x0 - halo_size)
|
||||||
|
ex1 = min(width, x1 + halo_size)
|
||||||
|
row_offset = y0 - ey0
|
||||||
|
col_offset = x0 - ex0
|
||||||
|
mask_segment_ext = None
|
||||||
|
if mask_array is not None:
|
||||||
|
mask_segment_ext = mask_array[ey0:ey1, ex0:ex1]
|
||||||
|
tasks.append((
|
||||||
|
x0, y0, ey0, ex0, ey1, ex1,
|
||||||
|
row_offset, col_offset, inner_h, inner_w,
|
||||||
|
mask_segment_ext, method,
|
||||||
|
))
|
||||||
|
|
||||||
|
if callback_progress:
|
||||||
|
callback_progress(
|
||||||
|
f"分块插值开始: 共 {total_blocks} 块 "
|
||||||
|
f"(block_size={block_size}, halo={halo_size}, method={method}, "
|
||||||
|
f"workers={effective_workers})"
|
||||||
|
)
|
||||||
|
|
||||||
|
total_zero_pixels = 0
|
||||||
|
|
||||||
|
if effective_workers <= 1:
|
||||||
|
for block_idx, task in enumerate(tasks, 1):
|
||||||
|
x0_t, y0_t = task[0], task[1]
|
||||||
|
if callback_progress:
|
||||||
|
callback_progress(
|
||||||
|
f"块 {block_idx}/{total_blocks} "
|
||||||
|
f"y=[{y0_t},{y0_t + task[8]}) x=[{x0_t},{x0_t + task[9]})"
|
||||||
|
)
|
||||||
|
inner_bands, zero_count = _process_one_block(
|
||||||
|
dataset, *task
|
||||||
|
)
|
||||||
|
for b_idx, band_data in enumerate(inner_bands):
|
||||||
|
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
|
||||||
|
band_data, xoff=x0_t, yoff=y0_t
|
||||||
|
)
|
||||||
|
total_zero_pixels += zero_count
|
||||||
|
else:
|
||||||
|
with ProcessPoolExecutor(
|
||||||
|
max_workers=effective_workers,
|
||||||
|
initializer=_init_worker,
|
||||||
|
initargs=(img_path,),
|
||||||
|
) as executor:
|
||||||
|
futures = [
|
||||||
|
executor.submit(_interpolate_block_worker, task)
|
||||||
|
for task in tasks
|
||||||
|
]
|
||||||
|
for block_idx, future in enumerate(futures, 1):
|
||||||
|
x0_t, y0_t, inner_bands, zero_count, error = future.result()
|
||||||
|
if error is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"块 (y={y0_t}, x={x0_t}) 处理失败: {error}"
|
||||||
|
)
|
||||||
|
if inner_bands is not None:
|
||||||
|
for b_idx, band_data in enumerate(inner_bands):
|
||||||
|
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
|
||||||
|
band_data, xoff=x0_t, yoff=y0_t
|
||||||
|
)
|
||||||
|
total_zero_pixels += zero_count
|
||||||
|
if callback_progress:
|
||||||
|
callback_progress(f"已写入块 {block_idx}/{total_blocks}")
|
||||||
|
|
||||||
|
if callback_progress:
|
||||||
|
callback_progress(
|
||||||
|
f"分块插值完成: 共处理 {total_zero_pixels} 个零像素 "
|
||||||
|
f"({total_blocks} 块,方法 {method},workers={effective_workers})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return output_path, None
|
||||||
|
finally:
|
||||||
out_dataset = None
|
out_dataset = None
|
||||||
|
|
||||||
result_stack = np.dstack(interpolated_bands)
|
|
||||||
return output_path, result_stack
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
dataset = None
|
dataset = None
|
||||||
|
|||||||
@ -899,7 +899,7 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 在这里直接设置参数
|
# 在这里直接设置参数
|
||||||
imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"# BIL格式影像文件路径
|
imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"# BIL格式影像文件路径
|
||||||
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||||
output_path = r"E:\code\WQ\封装\test/yangdian_output.csv" # CSV格式输出文件路径
|
output_path = r"E:\code\WQ\封装\test/yangdian_output.csv" # CSV格式输出文件路径
|
||||||
|
|
||||||
radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均
|
radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均
|
||||||
|
|||||||
@ -806,8 +806,8 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 在这里直接设置参数
|
# 在这里直接设置参数
|
||||||
imgpath = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq" # BIL格式影像文件路径
|
imgpath = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq" # BIL格式影像文件路径
|
||||||
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||||
output_path = r"E:\code\WQ\封装\work_dir\5_training_spectra/yangdian_output.csv" # CSV格式输出文件路径
|
output_path = r"E:\code\WQ\封装\work_dir\6_Spectral_Feature_Extraction/yangdian_output.csv" # CSV格式输出文件路径
|
||||||
|
|
||||||
radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均
|
radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均
|
||||||
flare_path = r"E:\code\WQ\封装\work_dir\2_Glint_Detection\severe_glint_area.dat" # 耀斑掩膜文件路径(可选,None表示不使用)
|
flare_path = r"E:\code\WQ\封装\work_dir\2_Glint_Detection\severe_glint_area.dat" # 耀斑掩膜文件路径(可选,None表示不使用)
|
||||||
|
|||||||
@ -315,7 +315,7 @@ def main():
|
|||||||
|
|
||||||
# 示例1: 使用所有回归方法分析光谱指数
|
# 示例1: 使用所有回归方法分析光谱指数
|
||||||
print("\n1. 光谱指数与叶绿素a的回归分析:")
|
print("\n1. 光谱指数与叶绿素a的回归分析:")
|
||||||
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\water_quality_results.csv")
|
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\water_quality_results.csv")
|
||||||
spectral_indices = ['Al10SABI','Am092Bsub']
|
spectral_indices = ['Al10SABI','Am092Bsub']
|
||||||
|
|
||||||
results1 = analyzer.batch_single_variable_regression(
|
results1 = analyzer.batch_single_variable_regression(
|
||||||
@ -323,7 +323,7 @@ def main():
|
|||||||
x_columns=spectral_indices,
|
x_columns=spectral_indices,
|
||||||
y_column='Chlorophyll',
|
y_column='Chlorophyll',
|
||||||
methods='all',
|
methods='all',
|
||||||
output_file=r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\spectral_indices_regression.csv'
|
output_file=r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\spectral_indices_regression.csv'
|
||||||
)
|
)
|
||||||
|
|
||||||
# # 示例2: 使用特定方法分析反射率波段
|
# # 示例2: 使用特定方法分析反射率波段
|
||||||
@ -343,7 +343,7 @@ def main():
|
|||||||
best_models = analyzer.get_best_models_summary()
|
best_models = analyzer.get_best_models_summary()
|
||||||
if not best_models.empty:
|
if not best_models.empty:
|
||||||
print(best_models[['x_variable', 'regression_method', 'r_squared', 'equation']].to_string(index=False))
|
print(best_models[['x_variable', 'regression_method', 'r_squared', 'equation']].to_string(index=False))
|
||||||
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\best_models_summary.csv', index=False)
|
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\best_models_summary.csv', index=False)
|
||||||
print("\n最佳模型汇总已保存到 'best_models_summary.csv'")
|
print("\n最佳模型汇总已保存到 'best_models_summary.csv'")
|
||||||
#
|
#
|
||||||
# def advanced_usage_example():
|
# def advanced_usage_example():
|
||||||
|
|||||||
@ -246,7 +246,7 @@ def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, outp
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
algorithm= "chl_a"
|
algorithm= "chl_a"
|
||||||
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\8_non_empirical_models\SS\SS_chl_a.json"
|
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\8_non_empirical_models\SS\SS_chl_a.json"
|
||||||
coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv"
|
coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv"
|
||||||
output_path= r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv"
|
output_path= r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv"
|
||||||
wave_radius=5.0
|
wave_radius=5.0
|
||||||
|
|||||||
@ -98,7 +98,7 @@ PIPELINE_STEPS: List[StepSpec] = [
|
|||||||
step_id="step4", method_name="step5_process_csv",
|
step_id="step4", method_name="step5_process_csv",
|
||||||
requires=["csv_path"], produces=["processed_csv_path"],
|
requires=["csv_path"], produces=["processed_csv_path"],
|
||||||
required_input_files=["csv_path"],
|
required_input_files=["csv_path"],
|
||||||
output_file="{work_dir}/4_processed_data/processed_data.csv",
|
output_file="{work_dir}/5_Data_Cleaning/processed_data.csv",
|
||||||
description="CSV 异常值清洗",
|
description="CSV 异常值清洗",
|
||||||
),
|
),
|
||||||
StepSpec(
|
StepSpec(
|
||||||
@ -111,21 +111,21 @@ PIPELINE_STEPS: List[StepSpec] = [
|
|||||||
},
|
},
|
||||||
skip_when_missing=False,
|
skip_when_missing=False,
|
||||||
required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"],
|
required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"],
|
||||||
output_file="{work_dir}/5_training_spectra/training_spectra.csv",
|
output_file="{work_dir}/6_Spectral_Feature_Extraction/training_spectra.csv",
|
||||||
description="实测样本点光谱提取",
|
description="实测样本点光谱提取",
|
||||||
),
|
),
|
||||||
StepSpec(
|
StepSpec(
|
||||||
step_id="step7", method_name="step7_calc_indices",
|
step_id="step7", method_name="step7_calc_indices",
|
||||||
requires=["training_csv_path"], produces=["indices_path", "trad_indices_dir"],
|
requires=["training_csv_path"], produces=["indices_path", "trad_indices_dir"],
|
||||||
required_input_files=["training_csv_path"],
|
required_input_files=["training_csv_path"],
|
||||||
output_file="{work_dir}/6_water_quality_indices/training_spectra_indices.csv",
|
output_file="{work_dir}/7_Water_Quality_Indices/training_spectra_indices.csv",
|
||||||
description="水质参数指数计算(双轨输出:A轨宽表 + B轨单文件)",
|
description="水质参数指数计算(双轨输出:A轨宽表 + B轨单文件)",
|
||||||
),
|
),
|
||||||
StepSpec(
|
StepSpec(
|
||||||
step_id="step8", method_name="step8_train_ml",
|
step_id="step8", method_name="step8_train_ml",
|
||||||
requires=["training_csv_path"], produces=["models_dir"],
|
requires=["training_csv_path"], produces=["models_dir"],
|
||||||
required_input_files=["training_csv_path"],
|
required_input_files=["training_csv_path"],
|
||||||
output_file="{work_dir}/7_Supervised_Model_Training/best_models.pkl",
|
output_file="{work_dir}/8_Supervised_Model_Training/best_models.pkl",
|
||||||
description="ML 建模(GridSearchCV / AutoML)",
|
description="ML 建模(GridSearchCV / AutoML)",
|
||||||
),
|
),
|
||||||
StepSpec(
|
StepSpec(
|
||||||
@ -134,7 +134,7 @@ PIPELINE_STEPS: List[StepSpec] = [
|
|||||||
requires=["training_csv_path"], produces=["models_dir"],
|
requires=["training_csv_path"], produces=["models_dir"],
|
||||||
parameter_map={"training_csv_path": "csv_path"},
|
parameter_map={"training_csv_path": "csv_path"},
|
||||||
required_input_files=["training_csv_path"],
|
required_input_files=["training_csv_path"],
|
||||||
output_file="{work_dir}/8_Regression_Modeling/non_empirical_models.pkl",
|
output_file="{work_dir}/8_Non_Empirical_Regression/non_empirical_models.pkl",
|
||||||
description="非经验统计回归",
|
description="非经验统计回归",
|
||||||
),
|
),
|
||||||
StepSpec(
|
StepSpec(
|
||||||
|
|||||||
@ -328,7 +328,7 @@ def train_with_automl(
|
|||||||
split_method = split_methods[0]
|
split_method = split_methods[0]
|
||||||
|
|
||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
output_dir = "./7_Supervised_Model_Training_AutoML"
|
output_dir = "./8_Supervised_Model_Training_AutoML"
|
||||||
out_dir = Path(output_dir)
|
out_dir = Path(output_dir)
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
preproc_dir = out_dir / preproc
|
preproc_dir = out_dir / preproc
|
||||||
@ -519,7 +519,7 @@ if __name__ == "__main__":
|
|||||||
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
|
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
|
||||||
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
|
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
|
||||||
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
|
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
|
||||||
p.add_argument("--out", default="./7_Supervised_Model_Training_AutoML")
|
p.add_argument("--out", default="./8_Supervised_Model_Training_AutoML")
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|
||||||
# 智能推断 feature_start_column 类型
|
# 智能推断 feature_start_column 类型
|
||||||
|
|||||||
@ -21,7 +21,7 @@ class DataPreparationStep:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def process_csv(
|
def process_csv(
|
||||||
csv_path: str,
|
csv_path: str,
|
||||||
output_dir: Union[str, Path] = "./4_processed_data",
|
output_dir: Union[str, Path] = "./5_Data_Cleaning",
|
||||||
callback: Optional[Callable] = None,
|
callback: Optional[Callable] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""处理CSV文件(筛选剔除异常值)"""
|
"""处理CSV文件(筛选剔除异常值)"""
|
||||||
@ -61,7 +61,7 @@ class DataPreparationStep:
|
|||||||
boundary_path: Optional[str] = None,
|
boundary_path: Optional[str] = None,
|
||||||
glint_mask_path: Optional[str] = None,
|
glint_mask_path: Optional[str] = None,
|
||||||
water_mask_path: Optional[str] = None,
|
water_mask_path: Optional[str] = None,
|
||||||
output_dir: Union[str, Path] = "./5_training_spectra",
|
output_dir: Union[str, Path] = "./6_Spectral_Feature_Extraction",
|
||||||
callback: Optional[Callable] = None,
|
callback: Optional[Callable] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""根据采样点坐标在去耀斑影像中提取平均光谱"""
|
"""根据采样点坐标在去耀斑影像中提取平均光谱"""
|
||||||
@ -131,7 +131,7 @@ class DataPreparationStep:
|
|||||||
formula_names: Optional[List[str]] = None,
|
formula_names: Optional[List[str]] = None,
|
||||||
output_file: Optional[str] = None,
|
output_file: Optional[str] = None,
|
||||||
enabled: bool = True,
|
enabled: bool = True,
|
||||||
output_dir: Union[str, Path] = "./6_water_quality_indices",
|
output_dir: Union[str, Path] = "./7_Water_Quality_Indices",
|
||||||
callback: Optional[Callable] = None,
|
callback: Optional[Callable] = None,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""根据训练光谱计算水质光谱指数(使用 band_math 方法)"""
|
"""根据训练光谱计算水质光谱指数(使用 band_math 方法)"""
|
||||||
|
|||||||
@ -135,7 +135,7 @@ class ModelingStep:
|
|||||||
split_methods: Optional[List[str]] = None,
|
split_methods: Optional[List[str]] = None,
|
||||||
cv_folds: int = 5,
|
cv_folds: int = 5,
|
||||||
training_csv_path: Optional[str] = None,
|
training_csv_path: Optional[str] = None,
|
||||||
output_dir: Union[str, Path] = "./7_Supervised_Model_Training",
|
output_dir: Union[str, Path] = "./8_Supervised_Model_Training",
|
||||||
callback: Optional[Callable] = None,
|
callback: Optional[Callable] = None,
|
||||||
_report_generator=None,
|
_report_generator=None,
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -251,7 +251,7 @@ class ModelingStep:
|
|||||||
if output_dir is not None:
|
if output_dir is not None:
|
||||||
non_empirical_dir = Path(output_dir)
|
non_empirical_dir = Path(output_dir)
|
||||||
else:
|
else:
|
||||||
non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
|
non_empirical_dir = Path.cwd() / "8_Non_Empirical_Regression"
|
||||||
non_empirical_dir.mkdir(parents=True, exist_ok=True)
|
non_empirical_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if preprocessing_methods is None:
|
if preprocessing_methods is None:
|
||||||
@ -430,7 +430,7 @@ def _apply_preprocessing_internal(
|
|||||||
|
|
||||||
save_path = None
|
save_path = None
|
||||||
if preprocess_method == "SS":
|
if preprocess_method == "SS":
|
||||||
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training"
|
models_dir = output_dir.parent.parent / "8_Supervised_Model_Training"
|
||||||
models_dir.mkdir(parents=True, exist_ok=True)
|
models_dir.mkdir(parents=True, exist_ok=True)
|
||||||
save_path = str(models_dir / "scaler_params.pkl")
|
save_path = str(models_dir / "scaler_params.pkl")
|
||||||
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
||||||
|
|||||||
@ -259,7 +259,7 @@ class PredictionStep:
|
|||||||
if non_empirical_models_dir is not None:
|
if non_empirical_models_dir is not None:
|
||||||
final_models_dir = non_empirical_models_dir
|
final_models_dir = non_empirical_models_dir
|
||||||
else:
|
else:
|
||||||
default_models_dir = str(Path(work_dir) / "8_Regression_Modeling")
|
default_models_dir = str(Path(work_dir) / "8_Non_Empirical_Regression")
|
||||||
if Path(default_models_dir).exists():
|
if Path(default_models_dir).exists():
|
||||||
final_models_dir = default_models_dir
|
final_models_dir = default_models_dir
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -138,11 +138,11 @@ class WaterQualityInversionPipeline:
|
|||||||
self.water_mask_dir = self.work_dir / "1_water_mask"
|
self.water_mask_dir = self.work_dir / "1_water_mask"
|
||||||
self.glint_dir = self.work_dir / "2_Glint_Detection"
|
self.glint_dir = self.work_dir / "2_Glint_Detection"
|
||||||
self.deglint_dir = self.work_dir / "3_deglint"
|
self.deglint_dir = self.work_dir / "3_deglint"
|
||||||
self.processed_data_dir = self.work_dir / "4_processed_data"
|
self.processed_data_dir = self.work_dir / "5_Data_Cleaning"
|
||||||
self.training_spectra_dir = self.work_dir / "5_training_spectra"
|
self.training_spectra_dir = self.work_dir / "6_Spectral_Feature_Extraction"
|
||||||
self.indices_dir = self.work_dir / "6_water_quality_indices"
|
self.indices_dir = self.work_dir / "7_Water_Quality_Indices"
|
||||||
self.models_dir = self.work_dir / "7_Supervised_Model_Training"
|
self.models_dir = self.work_dir / "8_Supervised_Model_Training"
|
||||||
self.non_empirical_models_dir = self.work_dir / "8_Regression_Modeling"
|
self.non_empirical_models_dir = self.work_dir / "8_Non_Empirical_Regression"
|
||||||
self.custom_regression_dir = self.work_dir / "13_Custom_Regression"
|
self.custom_regression_dir = self.work_dir / "13_Custom_Regression"
|
||||||
self.sampling_dir = self.work_dir / "4_sampling"
|
self.sampling_dir = self.work_dir / "4_sampling"
|
||||||
self.prediction_dir = self.work_dir / "11_12_13_predictions"
|
self.prediction_dir = self.work_dir / "11_12_13_predictions"
|
||||||
@ -764,7 +764,7 @@ class WaterQualityInversionPipeline:
|
|||||||
if not spectrum_csv or not os.path.exists(spectrum_csv):
|
if not spectrum_csv or not os.path.exists(spectrum_csv):
|
||||||
# 回退:扫描 work_dir 下 step5 的产物目录,找第一个 .csv
|
# 回退:扫描 work_dir 下 step5 的产物目录,找第一个 .csv
|
||||||
fallback_candidates = []
|
fallback_candidates = []
|
||||||
step5_dir = os.path.join(self.work_dir, "5_Training_Spectra")
|
step5_dir = os.path.join(self.work_dir, "6_Spectral_Feature_Extraction")
|
||||||
if os.path.isdir(step5_dir):
|
if os.path.isdir(step5_dir):
|
||||||
for f in sorted(os.listdir(step5_dir)):
|
for f in sorted(os.listdir(step5_dir)):
|
||||||
if f.lower().endswith('.csv'):
|
if f.lower().endswith('.csv'):
|
||||||
@ -2023,10 +2023,10 @@ class WaterQualityInversionPipeline:
|
|||||||
# 应用预处理 - 使用spectral_Preprocessing模块
|
# 应用预处理 - 使用spectral_Preprocessing模块
|
||||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||||
|
|
||||||
# 为SS预处理提供scaler保存路径,保存在工作目录的7_Supervised_Model_Training中
|
# 为SS预处理提供scaler保存路径,保存在工作目录的8_Supervised_Model_Training中
|
||||||
save_path = None
|
save_path = None
|
||||||
if preprocess_method == 'SS':
|
if preprocess_method == 'SS':
|
||||||
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training" # 向上两级到工作目录
|
models_dir = output_dir.parent.parent / "8_Supervised_Model_Training" # 向上两级到工作目录
|
||||||
models_dir.mkdir(parents=True, exist_ok=True)
|
models_dir.mkdir(parents=True, exist_ok=True)
|
||||||
save_path = str(models_dir / "scaler_params.pkl")
|
save_path = str(models_dir / "scaler_params.pkl")
|
||||||
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""
|
"""
|
||||||
Step5 面板 - 光谱提取
|
Step6 面板 - 光谱特征提取
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -27,7 +27,7 @@ class Step6FeaturePanel(QWidget):
|
|||||||
layout = QVBoxLayout()
|
layout = QVBoxLayout()
|
||||||
|
|
||||||
# 标题
|
# 标题
|
||||||
title = QLabel("步骤5:训练样本光谱提取")
|
title = QLabel("步骤6:光谱特征提取")
|
||||||
title.setFont(QFont("Arial", 12, QFont.Bold))
|
title.setFont(QFont("Arial", 12, QFont.Bold))
|
||||||
layout.addWidget(title)
|
layout.addWidget(title)
|
||||||
|
|
||||||
@ -58,12 +58,12 @@ class Step6FeaturePanel(QWidget):
|
|||||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||||
)
|
)
|
||||||
layout.addWidget(self.glint_mask_file)
|
layout.addWidget(self.glint_mask_file)
|
||||||
step5_glint_hint = QLabel(
|
step6_glint_hint = QLabel(
|
||||||
"提示:独立运行本步骤时必须选择耀斑掩膜(通常为步骤2输出的 severe_glint_area.dat),用于在采样时避开耀斑像元。"
|
"提示:独立运行本步骤时必须选择耀斑掩膜(通常为步骤2输出的 severe_glint_area.dat),用于在采样时避开耀斑像元。"
|
||||||
)
|
)
|
||||||
step5_glint_hint.setWordWrap(True)
|
step6_glint_hint.setWordWrap(True)
|
||||||
step5_glint_hint.setStyleSheet("color: #666; font-size: 10px;")
|
step6_glint_hint.setStyleSheet("color: #666; font-size: 10px;")
|
||||||
layout.addWidget(step5_glint_hint)
|
layout.addWidget(step6_glint_hint)
|
||||||
|
|
||||||
# 参数设置
|
# 参数设置
|
||||||
params_group = QGroupBox("提取参数")
|
params_group = QGroupBox("提取参数")
|
||||||
@ -200,20 +200,22 @@ class Step6FeaturePanel(QWidget):
|
|||||||
else:
|
else:
|
||||||
self.output_file.set_path("")
|
self.output_file.set_path("")
|
||||||
|
|
||||||
# 5. 尝试从 Step4 界面读取已处理的水质参数 CSV 路径,自动填入本面板
|
# 5. 尝试从 Step5 Clean 界面读取已处理的清洗后 CSV 路径,自动填入本面板
|
||||||
main_window = self.window()
|
main_window = self.window()
|
||||||
if main_window and hasattr(main_window, 'step5_panel'):
|
if main_window and hasattr(main_window, 'step5_clean_panel'):
|
||||||
step4_output_path = main_window.step5_panel.output_file.get_path()
|
step5_clean_output_path = main_window.step5_clean_panel.output_file.get_path()
|
||||||
if step4_output_path:
|
if step5_clean_output_path:
|
||||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||||
if not os.path.isabs(step4_output_path):
|
if not os.path.isabs(step5_clean_output_path):
|
||||||
step4_output_path = os.path.join(self.work_dir or '', step4_output_path).replace('\\', '/')
|
step5_clean_output_path = os.path.join(
|
||||||
|
self.work_dir or '', step5_clean_output_path
|
||||||
|
).replace('\\', '/')
|
||||||
existing_csv = self.csv_file.get_path()
|
existing_csv = self.csv_file.get_path()
|
||||||
if not existing_csv or not existing_csv.strip():
|
if not existing_csv or not existing_csv.strip():
|
||||||
self.csv_file.set_path(step4_output_path)
|
self.csv_file.set_path(step5_clean_output_path)
|
||||||
|
|
||||||
def run_step(self):
|
def run_step(self):
|
||||||
"""独立运行步骤5"""
|
"""独立运行步骤6"""
|
||||||
# 验证输入
|
# 验证输入
|
||||||
deglint_img_path = self.deglint_img_file.get_path()
|
deglint_img_path = self.deglint_img_file.get_path()
|
||||||
csv_path = self.csv_file.get_path()
|
csv_path = self.csv_file.get_path()
|
||||||
|
|||||||
@ -1393,9 +1393,7 @@ class WaterQualityGUI(QMainWindow):
|
|||||||
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'),
|
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'),
|
||||||
'water_mask_path': ('step1', 'water_mask', 'water_mask_file')
|
'water_mask_path': ('step1', 'water_mask', 'water_mask_file')
|
||||||
},
|
},
|
||||||
'step5_clean': {
|
# 'step5_clean': 业务要求保持输入源独立,不自动抓取 step4_sampling 的输出;用户手动浏览导入
|
||||||
'csv_path': ('step4_sampling', 'sampling_spectra', 'csv_file') # step5 寻找 step4 的采样点
|
|
||||||
},
|
|
||||||
'step6_feature': {
|
'step6_feature': {
|
||||||
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'),
|
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'),
|
||||||
'csv_path': ('step5_clean', 'processed_data', 'csv_file'),
|
'csv_path': ('step5_clean', 'processed_data', 'csv_file'),
|
||||||
@ -2255,15 +2253,26 @@ class WaterQualityGUI(QMainWindow):
|
|||||||
|
|
||||||
file_widget = getattr(panel, panel_attr)
|
file_widget = getattr(panel, panel_attr)
|
||||||
|
|
||||||
|
# ★ 兼容 FileSelectWidget 与原生 QLineEdit
|
||||||
|
current_text = (
|
||||||
|
file_widget.get_path().strip()
|
||||||
|
if hasattr(file_widget, 'get_path')
|
||||||
|
else file_widget.text().strip()
|
||||||
|
)
|
||||||
|
|
||||||
# 如果输入框已经有内容,跳过自动填充
|
# 如果输入框已经有内容,跳过自动填充
|
||||||
if file_widget.get_path().strip():
|
if current_text:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 查找依赖步骤的输出文件
|
# 查找依赖步骤的输出文件
|
||||||
output_path = self.find_step_output(work_path, dep_step, output_type)
|
output_path = self.find_step_output(work_path, dep_step, output_type)
|
||||||
|
|
||||||
if output_path and Path(output_path).exists():
|
if output_path and Path(output_path).exists():
|
||||||
file_widget.set_path(output_path)
|
# ★ 兼容 FileSelectWidget 与原生 QLineEdit
|
||||||
|
if hasattr(file_widget, 'set_path'):
|
||||||
|
file_widget.set_path(str(output_path))
|
||||||
|
else:
|
||||||
|
file_widget.setText(str(output_path))
|
||||||
self.log_message(f"自动填充 {step_id}.{input_field}: {output_path}", "info")
|
self.log_message(f"自动填充 {step_id}.{input_field}: {output_path}", "info")
|
||||||
filled_count += 1
|
filled_count += 1
|
||||||
|
|
||||||
|
|||||||
411
tests/test_interpolator_refactor.py
Normal file
411
tests/test_interpolator_refactor.py
Normal file
@ -0,0 +1,411 @@
|
|||||||
|
"""
|
||||||
|
interpolator.py 多进程重构的行为测试(mock-based,无 osgeo 依赖)
|
||||||
|
|
||||||
|
验证:
|
||||||
|
1. 静态结构:模块级函数集合、签名、新参数、_worker_dataset 全局
|
||||||
|
2. 行为逻辑:_process_one_block 在 mock dataset 上的零像素识别 + 插值
|
||||||
|
3. 行为逻辑:_interpolate_block_worker 通过模块全局 _worker_dataset 工作
|
||||||
|
4. 行为逻辑:interpolate_zero_pixels_batch 的串行路径(不依赖 osgeo)
|
||||||
|
5. 向后兼容:现有 8 个 caller 参数保留默认值不变
|
||||||
|
|
||||||
|
如果本机有 osgeo,可补一个真实数据集的 smoke test。
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import ast
|
||||||
|
import types
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
# Put the repository root (the parent of this tests directory) on sys.path so
# that ``src.*`` imports resolve to the module under test.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)


# Location of the source file inspected by the static-structure tests.
INTERP_PATH = os.path.join(
    PROJECT_ROOT,
    "src",
    "core",
    "algorithms",
    "interpolation",
    "interpolator.py",
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Static structure tests
|
||||||
|
# =============================================================================
|
||||||
|
class TestInterpolatorStructure(unittest.TestCase):
    """Static checks on ``interpolator.py`` using its source text and AST.

    The module under test is never imported by this class, so these tests
    need neither osgeo nor scipy.
    """

    @classmethod
    def setUpClass(cls):
        # Every test only reads the source, so load and parse it once per
        # class instead of once per test method.
        with open(INTERP_PATH, "r", encoding="utf-8") as f:
            cls.src = f.read()
        cls.tree = ast.parse(cls.src)

    def _function_node(self, name):
        """Return the module-level ``def`` named *name*; fail the test if absent."""
        for node in self.tree.body:
            if isinstance(node, ast.FunctionDef) and node.name == name:
                return node
        self.fail(f"{name} not found")

    def _arg_names(self, name):
        """Return the positional parameter names of module-level function *name*."""
        return [a.arg for a in self._function_node(name).args.args]

    def test_module_level_functions(self):
        mod_funcs = {n.name for n in self.tree.body if isinstance(n, ast.FunctionDef)}
        expected = {
            "interpolate_pixels",
            "_interpolate_single_band",
            "_normalize_interpolation_method",
            "_read_water_mask_to_array",
            "_init_worker",
            "_interpolate_block_worker",
            "_process_one_block",
            "interpolate_zero_pixels_batch",
        }
        self.assertEqual(mod_funcs, expected)
        # Already implied by the equality above; kept as an explicit statement
        # that this helper must not exist at module level.
        self.assertNotIn("_process_block_with_buffer", mod_funcs)

    def test_worker_dataset_module_global(self):
        mod_globals = set()
        for n in self.tree.body:
            if isinstance(n, ast.Assign):
                # Inspect every target so chained assignments (a = b = ...)
                # are recognised too, not just the first name.
                mod_globals.update(
                    t.id for t in n.targets if isinstance(t, ast.Name)
                )
            elif isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name):
                mod_globals.add(n.target.id)
        self.assertIn("_worker_dataset", mod_globals)

    def test_interpolate_zero_pixels_batch_signature(self):
        node = self._function_node("interpolate_zero_pixels_batch")
        # Backward compat: all 8 existing params + 2 new
        self.assertEqual(
            [a.arg for a in node.args.args],
            [
                "img_path", "interpolation_method", "output_path",
                "water_mask", "deglint_dir", "callback_progress",
                "block_size", "halo_size", "n_workers", "use_multiprocessing",
            ],
        )
        # Constant defaults expose ``.value``; any other expression maps to None.
        defaults = [getattr(d, "value", None) for d in node.args.defaults]
        self.assertEqual(
            defaults,
            ["nearest", None, None, None, None, 1024, 64, None, True],
        )

    def test_init_worker_signature(self):
        self.assertEqual(self._arg_names("_init_worker"), ["img_path"])

    def test_worker_function_signature(self):
        self.assertEqual(self._arg_names("_interpolate_block_worker"), ["task"])

    def test_process_one_block_signature(self):
        self.assertEqual(
            self._arg_names("_process_one_block"),
            [
                "dataset", "x0", "y0",
                "ey0", "ex0", "ey1", "ex1",
                "row_offset", "col_offset",
                "inner_h", "inner_w",
                "mask_segment_ext", "method",
            ],
        )

    def test_uses_process_pool_executor(self):
        self.assertIn("ProcessPoolExecutor", self.src)
        self.assertIn("ProcessPoolExecutor(", self.src)
        self.assertIn("initializer=_init_worker", self.src)
        self.assertIn("initargs=(img_path,)", self.src)
        self.assertIn("GDAL_NUM_THREADS", self.src)

    def test_dispatch_logic_present(self):
        # Both serial and parallel paths should be present
        self.assertIn("if effective_workers <= 1:", self.src)
        self.assertIn("with ProcessPoolExecutor", self.src)

    def test_serial_path_uses_process_one_block(self):
        # Serial branch should call _process_one_block directly (no pickle
        # overhead). Match the call shape while ignoring whitespace and an
        # optional trailing comma, so reformatting the source cannot break
        # the check.
        self.assertRegex(
            self.src,
            r"_process_one_block\(\s*dataset,\s*\*task\s*,?\s*\)",
        )

    def test_worker_path_uses_process_one_block(self):
        # Worker function should also call _process_one_block (shared core)
        self.assertIn("_process_one_block(", self.src)
        # The worker should read from _worker_dataset, not receive dataset in task
        self.assertIn("_worker_dataset", self.src)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Mocked behavior tests (no real osgeo/scipy needed)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_band(read_data):
|
||||||
|
band = MagicMock()
|
||||||
|
band.ReadAsArray.return_value = read_data
|
||||||
|
return band
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_dataset(bands_data, n_bands=None):
    """Return a mock dataset backed by the per-band arrays in *bands_data*.

    Args:
        bands_data: Sequence of arrays; ``GetRasterBand(b)`` serves
            ``bands_data[b - 1]`` (band numbers are 1-based).
        n_bands: Value reported as ``RasterCount``; defaults to
            ``len(bands_data)``.

    Returns:
        A ``MagicMock`` exposing ``RasterCount`` and ``GetRasterBand``. A new
        mock band is created on every ``GetRasterBand`` call.
    """
    def _band_for(b):
        return _make_mock_band(bands_data[b - 1])

    band_count = len(bands_data) if n_bands is None else n_bands
    dataset = MagicMock()
    dataset.RasterCount = band_count
    dataset.GetRasterBand.side_effect = _band_for
    return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fake_module(name, attrs=None):
|
||||||
|
mod = types.ModuleType(name)
|
||||||
|
for k, v in (attrs or {}).items():
|
||||||
|
setattr(mod, k, v)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
def _install_fake_modules():
|
||||||
|
"""Install minimal fakes for scipy + osgeo so the module under test imports."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Fake scipy
|
||||||
|
scipy = types.ModuleType("scipy")
|
||||||
|
scipy_ndimage = types.ModuleType("scipy.ndimage")
|
||||||
|
scipy_interp = types.ModuleType("scipy.interpolate")
|
||||||
|
scipy_spatial = types.ModuleType("scipy.spatial")
|
||||||
|
|
||||||
|
class _FakeTree:
|
||||||
|
def __init__(self, coords):
|
||||||
|
self.coords = coords
|
||||||
|
def query(self, pts):
|
||||||
|
# Return index 0 for all queries
|
||||||
|
import numpy as np
|
||||||
|
n = len(pts)
|
||||||
|
return np.zeros(n, dtype=int), np.zeros(n, dtype=int)
|
||||||
|
|
||||||
|
def _griddata(coords, values, pts, method="linear", fill_value=0.0):
|
||||||
|
# Trivial: nearest valid value
|
||||||
|
if len(coords) == 0 or len(pts) == 0:
|
||||||
|
return np.zeros(len(pts), dtype=np.float32)
|
||||||
|
# Just return first value for all queries (testing only)
|
||||||
|
return np.full(len(pts), values[0], dtype=np.float32)
|
||||||
|
|
||||||
|
class _FakeRBF:
|
||||||
|
def __init__(self, coords, values, kernel=None):
|
||||||
|
self.first_value = values[0]
|
||||||
|
def __call__(self, pts):
|
||||||
|
return np.full(len(pts), self.first_value, dtype=np.float32)
|
||||||
|
|
||||||
|
scipy_spatial.cKDTree = _FakeTree
|
||||||
|
scipy_interp.griddata = _griddata
|
||||||
|
scipy_interp.RBFInterpolator = _FakeRBF
|
||||||
|
|
||||||
|
scipy.ndimage = scipy_ndimage
|
||||||
|
scipy.interpolate = scipy_interp
|
||||||
|
scipy.spatial = scipy_spatial
|
||||||
|
|
||||||
|
# Fake osgeo.gdal with the constants the module needs
|
||||||
|
osgeo = types.ModuleType("osgeo")
|
||||||
|
gdal_mod = types.ModuleType("osgeo.gdal")
|
||||||
|
gdal_mod.GA_ReadOnly = 0
|
||||||
|
gdal_mod.GDT_Float32 = 6
|
||||||
|
gdal_mod.UseExceptions = MagicMock()
|
||||||
|
|
||||||
|
def _open(path, mode):
|
||||||
|
return _make_mock_dataset([]) # empty dataset; real tests build their own
|
||||||
|
gdal_mod.Open = _open
|
||||||
|
gdal_mod.GetDriverByName = MagicMock(return_value=MagicMock())
|
||||||
|
gdal_mod.SetConfigOption = MagicMock()
|
||||||
|
|
||||||
|
osgeo.gdal = gdal_mod
|
||||||
|
|
||||||
|
sys.modules["scipy"] = scipy
|
||||||
|
sys.modules["scipy.ndimage"] = scipy_ndimage
|
||||||
|
sys.modules["scipy.interpolate"] = scipy_interp
|
||||||
|
sys.modules["scipy.spatial"] = scipy_spatial
|
||||||
|
sys.modules["osgeo"] = osgeo
|
||||||
|
sys.modules["osgeo.gdal"] = gdal_mod
|
||||||
|
|
||||||
|
return {
|
||||||
|
"scipy": scipy,
|
||||||
|
"scipy.spatial": scipy_spatial,
|
||||||
|
"scipy.interpolate": scipy_interp,
|
||||||
|
"osgeo.gdal": gdal_mod,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessOneBlockMocked(unittest.TestCase):
    """Verify _process_one_block logic with mocked GDAL/scipy."""

    @classmethod
    def setUpClass(cls):
        _install_fake_modules()
        # Import the module under test after installing fakes
        from src.core.algorithms.interpolation import interpolator
        cls.interp = interpolator
        cls.gdal = sys.modules["osgeo.gdal"]

    def _build_dataset(self, bands):
        """Build a mock dataset where GetRasterBand(b).ReadAsArray(x, y, w, h)
        returns bands[b-1].

        For simplicity the full band array is returned regardless of
        x, y, w, h. Delegates to the module-level ``_make_mock_dataset`` so
        there is a single definition of the mock dataset shape.
        """
        return _make_mock_dataset(bands)

    def _process_top_left_block(self, ds):
        """Run ``_process_one_block`` on the 2x2 inner block at the origin.

        The extended window is rows/cols 0..4, so with 4x4 band arrays the
        window covers the whole mock band.
        """
        return self.interp._process_one_block(
            dataset=ds,
            x0=0, y0=0,
            ey0=0, ex0=0, ey1=4, ex1=4,
            row_offset=0, col_offset=0,
            inner_h=2, inner_w=2,
            mask_segment_ext=None,
            method="nearest",
        )

    def test_no_zeros_returns_inner_blocks_unchanged(self):
        """If no pixels are all-zero, inner blocks should be returned as-is."""
        import numpy as np

        # 4x4 extended window (2x2 inner block plus halo), 3 bands,
        # no zeros anywhere
        band = np.ones((4, 4), dtype=np.float32)
        bands = [band, band * 2, band * 3]
        ds = self._build_dataset(bands)

        inner_bands, zero_count = self._process_top_left_block(ds)

        self.assertEqual(zero_count, 0)
        self.assertEqual(len(inner_bands), 3)
        for ib, expected_band in zip(inner_bands, bands):
            self.assertEqual(ib.shape, (2, 2))
            np.testing.assert_array_equal(ib, expected_band[:2, :2])

    def test_with_zero_pixel_triggers_interpolation(self):
        """If a pixel is all-zero, _interpolate_single_band should be called."""
        import numpy as np

        # 4x4 image, 3 bands. Top-left 2x2 is zeros, rest is 1s.
        band1 = np.array([
            [0, 0, 1, 1],
            [0, 0, 1, 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1],
        ], dtype=np.float32)
        band2 = band1 * 2
        band3 = band1 * 3
        ds = self._build_dataset([band1, band2, band3])

        # Process the top-left 2x2 block (with halo, 4x4 covers all)
        inner_bands, zero_count = self._process_top_left_block(ds)

        self.assertGreater(zero_count, 0, "should detect zero pixels")
        self.assertEqual(len(inner_bands), 3)
        # Each inner band should be 2x2
        for ib in inner_bands:
            self.assertEqual(ib.shape, (2, 2))
        # The fake KD-tree answers every query with index 0, i.e. the first
        # valid value (1 for band1), so the inner block must come back
        # filled with non-zero values.
        self.assertTrue(np.all(inner_bands[0] > 0))
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkerFunctionMocked(unittest.TestCase):
    """Verify _interpolate_block_worker uses module-global _worker_dataset."""

    @classmethod
    def setUpClass(cls):
        _install_fake_modules()
        from src.core.algorithms.interpolation import interpolator
        cls.interp = interpolator

    def setUp(self):
        # Every test starts with no worker dataset configured
        self.interp._worker_dataset = None

    def tearDown(self):
        # Never leak a mock dataset into other tests through the module global
        self.interp._worker_dataset = None

    def _build_dataset(self, bands):
        """Return a mock dataset serving ``bands[b - 1]`` for ``GetRasterBand(b)``.

        Delegates to the module-level ``_make_mock_dataset`` instead of
        duplicating the mock construction.
        """
        return _make_mock_dataset(bands)

    def test_worker_uses_module_global_dataset(self):
        import numpy as np

        band = np.array([
            [0, 0, 1, 1],
            [0, 0, 1, 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1],
        ], dtype=np.float32)
        # The worker is expected to read the dataset from the module global
        self.interp._worker_dataset = self._build_dataset([band, band * 2])

        task = (
            0, 0, 0, 0, 4, 4,  # x0, y0, ey0, ex0, ey1, ex1
            0, 0, 2, 2,        # row_offset, col_offset, inner_h, inner_w
            None, "nearest",   # mask_segment_ext, method
        )
        x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)

        self.assertIsNone(error)
        self.assertEqual((x0, y0), (0, 0))
        self.assertGreater(zero_count, 0)
        self.assertIsNotNone(inner_bands)
        self.assertEqual(len(inner_bands), 2)

    def test_worker_returns_error_if_dataset_uninitialized(self):
        # setUp leaves _worker_dataset as None, i.e. no initializer has run
        task = (0, 0, 0, 0, 2, 2, 0, 0, 2, 2, None, "nearest")
        x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)

        self.assertIsNotNone(error)
        self.assertIn("not initialized", error)
        self.assertIsNone(inner_bands)
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitWorkerMocked(unittest.TestCase):
    """Verify _init_worker sets module global and config option."""

    @classmethod
    def setUpClass(cls):
        _install_fake_modules()
        from src.core.algorithms.interpolation import interpolator
        cls.interp = interpolator
        # Patch the gdal object the module under test actually holds.
        # ``sys.modules["osgeo.gdal"]`` is not guaranteed to be that object
        # (for example if the fakes were re-installed after ``interpolator``
        # had been imported by an earlier test class), in which case patches
        # on it would have no effect on ``_init_worker``.
        # NOTE(review): assumes interpolator exposes its GDAL binding as the
        # module attribute ``gdal``; falls back to sys.modules otherwise.
        cls.gdal_mod = (
            getattr(interpolator, "gdal", None) or sys.modules["osgeo.gdal"]
        )

    def setUp(self):
        self.interp._worker_dataset = None

    def tearDown(self):
        self.interp._worker_dataset = None

    def test_init_worker_opens_and_caches(self):
        fake_ds = MagicMock()
        with patch.object(self.gdal_mod, "Open", return_value=fake_ds) as open_mock, \
                patch.object(self.gdal_mod, "SetConfigOption") as cfg_mock:
            self.interp._init_worker("/fake/path.bsq")

            self.assertIs(self.interp._worker_dataset, fake_ds)
            # Use the module's own read-only constant (0 in the fake) rather
            # than a bare literal.
            open_mock.assert_called_once_with(
                "/fake/path.bsq", self.gdal_mod.GA_ReadOnly
            )
            # Should have set GDAL_NUM_THREADS=1
            cfg_mock.assert_any_call("GDAL_NUM_THREADS", "1")

    def test_init_worker_raises_if_open_fails(self):
        # A None return from Open must surface as RuntimeError
        with patch.object(self.gdal_mod, "Open", return_value=None):
            with self.assertRaises(RuntimeError):
                self.interp._init_worker("/bad/path.bsq")
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this file directly; verbosity=2 lists each test by name.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||||
Reference in New Issue
Block a user