Step3 插值算法 OOM 修复 + 多进程加速 + 全链路累积改动(14 文件)
This commit is contained in:
@ -3,8 +3,24 @@
|
||||
|
||||
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
|
||||
支持多种插值方法:nearest, bilinear, spline (RBF), kriging。
|
||||
|
||||
本模块使用多进程并行分块 IO 加速(Plan A):
|
||||
- ProcessPoolExecutor 为每个 worker 进程打开一次源影像(initializer 阶段),
|
||||
避免每块重复 gdal.Open 带来的开销(Windows 上 ~50ms/次)
|
||||
- 主进程统一负责输出文件的写入,避免多进程写锁竞争
|
||||
- 分块大小(block_size)默认 1024,内存充足可调至 2048 / 4096
|
||||
|
||||
注意:
|
||||
- GDAL Dataset / Rasterio Dataset 对象不能跨进程传递(picking 不支持),
|
||||
所以 worker 必须在 init 阶段自己独立打开源文件
|
||||
- 每个 worker 强制设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 多线程
|
||||
造成的 CPU 过订阅
|
||||
- 关闭多进程:传 ``use_multiprocessing=False`` 或 ``n_workers=1``
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Union, Tuple, List
|
||||
from pathlib import Path
|
||||
@ -24,6 +40,9 @@ except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
_worker_dataset: Optional["gdal.Dataset"] = None
|
||||
|
||||
|
||||
def interpolate_pixels(
|
||||
image_stack: np.ndarray,
|
||||
zero_coords: np.ndarray,
|
||||
@ -52,7 +71,6 @@ def interpolate_pixels(
|
||||
height, width, n_bands = image_stack.shape
|
||||
result = image_stack.copy()
|
||||
|
||||
# 兼容中文和各种格式的method参数
|
||||
raw_method = str(interpolation_method).lower()
|
||||
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
|
||||
method = 'nearest'
|
||||
@ -181,39 +199,271 @@ def _interpolate_single_band(
|
||||
return np.zeros(len(zero_coords))
|
||||
|
||||
|
||||
def _normalize_interpolation_method(method: str) -> str:
|
||||
"""将中文/英文混用的插值方法名归一化为内部标准名
|
||||
|
||||
支持: 'nearest'/'邻近'/'最邻近','bilinear'/'线性'/'双线性',
|
||||
'spline'/'样条'/'rbf','kriging'/'克里金'。
|
||||
"""
|
||||
raw = str(method).lower()
|
||||
if 'nearest' in raw or '邻近' in raw or '最邻近' in raw:
|
||||
return 'nearest'
|
||||
if 'bilinear' in raw or '线性' in raw or '双线性' in raw:
|
||||
return 'bilinear'
|
||||
if 'spline' in raw or '样条' in raw or 'rbf' in raw:
|
||||
return 'spline'
|
||||
if 'kriging' in raw or '克里金' in raw:
|
||||
return 'kriging'
|
||||
return 'nearest'
|
||||
|
||||
|
||||
def _read_water_mask_to_array(
|
||||
water_mask: Optional[Union[str, np.ndarray]],
|
||||
expected_height: int,
|
||||
expected_width: int,
|
||||
) -> Optional[np.ndarray]:
|
||||
"""读取水域掩膜为 numpy 数组(单波段,bool/int 均可)
|
||||
|
||||
None 或空字符串直接返回 None。形状不匹配时给出告警但不抛错,
|
||||
让调用方按"无掩膜"路径继续。
|
||||
"""
|
||||
if water_mask is None:
|
||||
return None
|
||||
if isinstance(water_mask, str):
|
||||
if not water_mask.strip():
|
||||
return None
|
||||
mask_ds = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_ds is None:
|
||||
print(f" [warn] 无法打开水域掩膜 {water_mask},按无掩膜处理")
|
||||
return None
|
||||
try:
|
||||
mask_array = mask_ds.GetRasterBand(1).ReadAsArray()
|
||||
finally:
|
||||
mask_ds = None
|
||||
elif isinstance(water_mask, np.ndarray):
|
||||
mask_array = water_mask
|
||||
else:
|
||||
return None
|
||||
|
||||
if mask_array.shape != (expected_height, expected_width):
|
||||
print(
|
||||
f" [warn] 水域掩膜形状 {mask_array.shape} 与影像 "
|
||||
f"({expected_height}, {expected_width}) 不匹配,按无掩膜处理"
|
||||
)
|
||||
return None
|
||||
return mask_array
|
||||
|
||||
|
||||
def _init_worker(img_path: str) -> None:
|
||||
"""ProcessPoolExecutor initializer: 每个 worker 进程只调用一次
|
||||
|
||||
在 worker 进程启动时打开源影像 dataset 并缓存在模块全局变量
|
||||
``_worker_dataset`` 中。后续所有块处理直接复用这个 dataset,
|
||||
避免每块重复 ``gdal.Open``(Windows 上约 50ms/次,100 块即 5s)。
|
||||
|
||||
同时设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 默认多线程
|
||||
造成的 CPU 过订阅。
|
||||
"""
|
||||
global _worker_dataset
|
||||
gdal.SetConfigOption('GDAL_NUM_THREADS', '1')
|
||||
if hasattr(gdal, 'UseExceptions'):
|
||||
gdal.UseExceptions()
|
||||
_worker_dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if _worker_dataset is None:
|
||||
raise RuntimeError(f"Worker failed to open source image: {img_path}")
|
||||
|
||||
|
||||
def _interpolate_block_worker(task: tuple) -> tuple:
|
||||
"""ProcessPoolExecutor worker: 处理单个块并返回结果
|
||||
|
||||
该函数必须保持模块级(可被 pickle),不持有任何外部状态——
|
||||
源 dataset 通过 ``_worker_dataset`` 模块全局变量获取。
|
||||
|
||||
Returns:
|
||||
``(x0, y0, inner_bands, zero_count, error_msg)`` 元组:
|
||||
- x0, y0: 块在影像中的写入起点
|
||||
- inner_bands: ``List[np.ndarray]``,每个元素是 (inner_h, inner_w)
|
||||
float32 数组(每个波段一个),或失败时为 None
|
||||
- zero_count: 该扩展块中识别到的零像素数(含 halo 范围)
|
||||
- error_msg: None 表示成功,str 表示错误信息
|
||||
"""
|
||||
(
|
||||
x0, y0, ey0, ex0, ey1, ex1,
|
||||
row_offset, col_offset, inner_h, inner_w,
|
||||
mask_segment_ext, method,
|
||||
) = task
|
||||
if _worker_dataset is None:
|
||||
return (x0, y0, None, 0, "Worker dataset not initialized")
|
||||
try:
|
||||
inner_bands, zero_count = _process_one_block(
|
||||
_worker_dataset, x0, y0, ey0, ex0, ey1, ex1,
|
||||
row_offset, col_offset, inner_h, inner_w,
|
||||
mask_segment_ext, method,
|
||||
)
|
||||
return (x0, y0, inner_bands, zero_count, None)
|
||||
except Exception as e:
|
||||
return (x0, y0, None, 0, str(e))
|
||||
|
||||
|
||||
def _process_one_block(
|
||||
dataset: "gdal.Dataset",
|
||||
x0: int, y0: int,
|
||||
ey0: int, ex0: int, ey1: int, ex1: int,
|
||||
row_offset: int, col_offset: int,
|
||||
inner_h: int, inner_w: int,
|
||||
mask_segment_ext: Optional[np.ndarray],
|
||||
method: str,
|
||||
) -> Tuple[List[np.ndarray], int]:
|
||||
"""处理单个扩展块(纯计算核心,dataset 显式传入)
|
||||
|
||||
串行模式和并行模式共用此函数。并行模式下 dataset 来自 worker 的
|
||||
缓存(``_worker_dataset``),串行模式下 dataset 由主函数传入。
|
||||
|
||||
Args:
|
||||
dataset: 已打开的源影像 dataset
|
||||
x0, y0: 内部块左上角(写入位置)
|
||||
ey0, ex0, ey1, ex1: 扩展块(含 halo)坐标
|
||||
row_offset, col_offset: 内部块在扩展块中的偏移
|
||||
inner_h, inner_w: 内部块尺寸
|
||||
mask_segment_ext: 扩展块对应的水域掩膜(None 表示不应用)
|
||||
method: 插值方法(已归一化)
|
||||
|
||||
Returns:
|
||||
``(inner_bands, zero_count)`` 元组:
|
||||
- inner_bands: ``List[np.ndarray]``,长度 = n_bands,每个元素形状为
|
||||
``(inner_h, inner_w)`` 的 float32 数组
|
||||
- zero_count: 扩展块中识别到的零像素数
|
||||
"""
|
||||
n_bands = dataset.RasterCount
|
||||
ext_bands: List[np.ndarray] = []
|
||||
for b in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(b)
|
||||
ext_bands.append(
|
||||
band.ReadAsArray(ex0, ey0, ex1 - ex0, ey1 - ey0).astype(np.float32)
|
||||
)
|
||||
band = None
|
||||
|
||||
try:
|
||||
ext_h, ext_w = ey1 - ey0, ex1 - ex0
|
||||
|
||||
all_zero_ext = np.ones((ext_h, ext_w), dtype=bool)
|
||||
for b_data in ext_bands:
|
||||
all_zero_ext &= (b_data == 0)
|
||||
|
||||
if mask_segment_ext is not None:
|
||||
all_zero_ext &= (mask_segment_ext > 0)
|
||||
|
||||
zero_count = int(np.sum(all_zero_ext))
|
||||
|
||||
if zero_count == 0:
|
||||
inner_bands = [
|
||||
ext_bands[b][
|
||||
row_offset:row_offset + inner_h,
|
||||
col_offset:col_offset + inner_w,
|
||||
]
|
||||
for b in range(n_bands)
|
||||
]
|
||||
return inner_bands, 0
|
||||
|
||||
zero_y, zero_x = np.where(all_zero_ext)
|
||||
zero_coords = np.column_stack([zero_x, zero_y])
|
||||
|
||||
valid_mask = ~all_zero_ext
|
||||
valid_y, valid_x = np.where(valid_mask)
|
||||
valid_coords = np.column_stack([valid_x, valid_y])
|
||||
|
||||
if len(valid_coords) == 0:
|
||||
print(
|
||||
f" [warn] 块 (y={y0}-{y0 + inner_h}, x={x0}-{x0 + inner_w}) "
|
||||
f"无有效像素可作插值上下文,已跳过"
|
||||
)
|
||||
inner_bands = [
|
||||
ext_bands[b][
|
||||
row_offset:row_offset + inner_h,
|
||||
col_offset:col_offset + inner_w,
|
||||
]
|
||||
for b in range(n_bands)
|
||||
]
|
||||
return inner_bands, zero_count
|
||||
|
||||
for b in range(n_bands):
|
||||
ext_band = ext_bands[b]
|
||||
valid_values_band = ext_band[valid_mask]
|
||||
if len(valid_values_band) == 0:
|
||||
continue
|
||||
band_result = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values_band, method
|
||||
)
|
||||
ext_band[zero_y, zero_x] = band_result
|
||||
|
||||
inner_bands = [
|
||||
ext_bands[b][
|
||||
row_offset:row_offset + inner_h,
|
||||
col_offset:col_offset + inner_w,
|
||||
]
|
||||
for b in range(n_bands)
|
||||
]
|
||||
return inner_bands, zero_count
|
||||
finally:
|
||||
del ext_bands
|
||||
|
||||
|
||||
def interpolate_zero_pixels_batch(
|
||||
img_path: str,
|
||||
interpolation_method: str = 'nearest',
|
||||
output_path: Optional[str] = None,
|
||||
water_mask: Optional[Union[str, np.ndarray]] = None,
|
||||
deglint_dir: Optional[str] = None,
|
||||
callback_progress: Optional[callable] = None
|
||||
callback_progress: Optional[callable] = None,
|
||||
block_size: int = 1024,
|
||||
halo_size: int = 64,
|
||||
n_workers: Optional[int] = None,
|
||||
use_multiprocessing: bool = True,
|
||||
) -> Tuple[str, Optional[np.ndarray]]:
|
||||
"""
|
||||
对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O)
|
||||
对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O)。
|
||||
|
||||
采用 **分块 IO + 多进程并行** 策略:
|
||||
1. 影像按 ``block_size`` × ``block_size`` 分块,每块边界外扩展
|
||||
``halo_size`` 像素作为插值上下文,避免块边缘插值退化
|
||||
2. 多进程并行(默认 ``ProcessPoolExecutor``,worker 数 = CPU 核心数)
|
||||
并发处理所有块;GDAL Dataset 不能跨进程传递,所以每个 worker
|
||||
在 ``initializer`` 阶段独立打开源文件一次并缓存
|
||||
3. 主进程按块序接收处理结果并统一写入输出文件,避免写锁竞争
|
||||
4. 该方案可彻底避免一次性读取 50 波段整景影像时的 OOM 隐患
|
||||
(50 波段 × 4000×4000 × float32 ≈ 3GB 的 np.dstack)
|
||||
|
||||
Args:
|
||||
img_path: 输入影像文件路径
|
||||
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging'
|
||||
output_path: 输出文件路径(如果为None,自动生成)
|
||||
water_mask: 水域掩膜(文件路径或数组)
|
||||
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline',
|
||||
'kriging' 及其中文别名('邻近'/'最邻近'/'线性'/'双线性'/'样条'/'克里金')
|
||||
output_path: 输出文件路径(如果为 None 且 deglint_dir 提供,自动生成)
|
||||
water_mask: 水域掩膜(文件路径或数组),形状须与影像高宽一致
|
||||
deglint_dir: 去耀斑目录(用于生成默认输出路径)
|
||||
callback_progress: 进度回调函数
|
||||
callback_progress: 进度回调函数,签名 ``callback(msg: str)``
|
||||
block_size: 分块大小(像素),默认 1024;内存充足可调 2048/4096
|
||||
halo_size: 上下文 halo 宽度(像素),默认 64
|
||||
n_workers: 并行 worker 进程数;None = ``multiprocessing.cpu_count()``;
|
||||
传 1 等价于串行模式
|
||||
use_multiprocessing: 是否启用多进程;False 时强制串行
|
||||
|
||||
Returns:
|
||||
(output_path, interpolated_image_stack) 元组
|
||||
``(output_path, None)`` 元组。第二个值固定为 ``None``(与原版语义保留
|
||||
兼容;返回完整内存堆叠会重新引入 OOM 风险,故不再提供)。
|
||||
"""
|
||||
if not SCIPY_AVAILABLE:
|
||||
raise ImportError("scipy未安装,无法进行0值像素插值")
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
# 确定输出路径
|
||||
if output_path is None and deglint_dir is not None:
|
||||
output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
|
||||
method = _normalize_interpolation_method(interpolation_method)
|
||||
|
||||
# 检查文件是否已存在
|
||||
if output_path and Path(output_path).exists():
|
||||
if output_path is None and deglint_dir is not None:
|
||||
output_path = str(Path(deglint_dir) / f"interpolated_{method}.bsq")
|
||||
if output_path is None:
|
||||
raise ValueError("output_path 和 deglint_dir 至少需要指定一个")
|
||||
|
||||
if Path(output_path).exists():
|
||||
return output_path, None
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
@ -227,94 +477,125 @@ def interpolate_zero_pixels_batch(
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
# 读取所有波段数据
|
||||
all_bands = []
|
||||
for band_idx in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(band_idx)
|
||||
band_data = band.ReadAsArray().astype(np.float32)
|
||||
all_bands.append(band_data)
|
||||
|
||||
image_stack = np.dstack(all_bands)
|
||||
|
||||
# 读取水域掩膜
|
||||
mask_array = None
|
||||
if water_mask is not None:
|
||||
if isinstance(water_mask, str):
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset:
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
elif isinstance(water_mask, np.ndarray):
|
||||
mask_array = water_mask
|
||||
|
||||
# 找出所有波段都为0的像素点
|
||||
all_bands_zero = np.all(image_stack == 0, axis=2)
|
||||
|
||||
if mask_array is not None:
|
||||
all_bands_zero = all_bands_zero & (mask_array > 0)
|
||||
|
||||
zero_pixel_count = np.sum(all_bands_zero)
|
||||
if zero_pixel_count == 0:
|
||||
# 无需插值,直接保存
|
||||
if output_path:
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
for i, band_data in enumerate(all_bands):
|
||||
out_band = out_dataset.GetRasterBand(i + 1)
|
||||
out_band.WriteArray(band_data)
|
||||
out_band.FlushCache()
|
||||
out_dataset = None
|
||||
return output_path, image_stack
|
||||
|
||||
# 获取坐标
|
||||
zero_y, zero_x = np.where(all_bands_zero)
|
||||
zero_coords = np.column_stack([zero_x, zero_y])
|
||||
|
||||
valid_mask = ~all_bands_zero
|
||||
valid_y, valid_x = np.where(valid_mask)
|
||||
valid_coords = np.column_stack([valid_x, valid_y])
|
||||
|
||||
if len(valid_coords) == 0:
|
||||
raise ValueError("没有有效像素可用于插值")
|
||||
|
||||
# 逐波段插值
|
||||
interpolated_bands = []
|
||||
for band_idx in range(n_bands):
|
||||
if callback_progress:
|
||||
callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
|
||||
band_data = all_bands[band_idx].copy()
|
||||
valid_values_band = band_data[valid_mask]
|
||||
|
||||
if len(valid_values_band) == 0:
|
||||
interpolated_bands.append(band_data)
|
||||
continue
|
||||
|
||||
band_result = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values_band, interpolation_method
|
||||
if width <= 0 or height <= 0 or n_bands <= 0:
|
||||
raise ValueError(
|
||||
f"影像尺寸异常: width={width}, height={height}, n_bands={n_bands}"
|
||||
)
|
||||
band_data[all_bands_zero] = band_result
|
||||
interpolated_bands.append(band_data)
|
||||
|
||||
# 保存结果
|
||||
if output_path:
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
for i, band_data in enumerate(interpolated_bands):
|
||||
out_band = out_dataset.GetRasterBand(i + 1)
|
||||
out_band.WriteArray(band_data)
|
||||
out_band.FlushCache()
|
||||
mask_array = _read_water_mask_to_array(water_mask, height, width)
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
if driver is None:
|
||||
raise RuntimeError("未找到可用的栅格驱动(ENVI / GTiff 都不存在)")
|
||||
|
||||
out_dataset = driver.Create(
|
||||
output_path, width, height, n_bands, gdal.GDT_Float32
|
||||
)
|
||||
if out_dataset is None:
|
||||
raise RuntimeError(f"无法创建输出文件: {output_path}")
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
|
||||
try:
|
||||
if not use_multiprocessing:
|
||||
effective_workers = 1
|
||||
elif n_workers is not None and n_workers >= 1:
|
||||
effective_workers = int(n_workers)
|
||||
else:
|
||||
try:
|
||||
cpu_count = multiprocessing.cpu_count() or 1
|
||||
except (NotImplementedError, OSError):
|
||||
cpu_count = 1
|
||||
effective_workers = max(1, cpu_count)
|
||||
|
||||
n_blocks_y = (height + block_size - 1) // block_size
|
||||
n_blocks_x = (width + block_size - 1) // block_size
|
||||
total_blocks = n_blocks_y * n_blocks_x
|
||||
|
||||
tasks = []
|
||||
for by in range(n_blocks_y):
|
||||
y0 = by * block_size
|
||||
y1 = min(y0 + block_size, height)
|
||||
inner_h = y1 - y0
|
||||
ey0 = max(0, y0 - halo_size)
|
||||
ey1 = min(height, y1 + halo_size)
|
||||
for bx in range(n_blocks_x):
|
||||
x0 = bx * block_size
|
||||
x1 = min(x0 + block_size, width)
|
||||
inner_w = x1 - x0
|
||||
ex0 = max(0, x0 - halo_size)
|
||||
ex1 = min(width, x1 + halo_size)
|
||||
row_offset = y0 - ey0
|
||||
col_offset = x0 - ex0
|
||||
mask_segment_ext = None
|
||||
if mask_array is not None:
|
||||
mask_segment_ext = mask_array[ey0:ey1, ex0:ex1]
|
||||
tasks.append((
|
||||
x0, y0, ey0, ex0, ey1, ex1,
|
||||
row_offset, col_offset, inner_h, inner_w,
|
||||
mask_segment_ext, method,
|
||||
))
|
||||
|
||||
if callback_progress:
|
||||
callback_progress(
|
||||
f"分块插值开始: 共 {total_blocks} 块 "
|
||||
f"(block_size={block_size}, halo={halo_size}, method={method}, "
|
||||
f"workers={effective_workers})"
|
||||
)
|
||||
|
||||
total_zero_pixels = 0
|
||||
|
||||
if effective_workers <= 1:
|
||||
for block_idx, task in enumerate(tasks, 1):
|
||||
x0_t, y0_t = task[0], task[1]
|
||||
if callback_progress:
|
||||
callback_progress(
|
||||
f"块 {block_idx}/{total_blocks} "
|
||||
f"y=[{y0_t},{y0_t + task[8]}) x=[{x0_t},{x0_t + task[9]})"
|
||||
)
|
||||
inner_bands, zero_count = _process_one_block(
|
||||
dataset, *task
|
||||
)
|
||||
for b_idx, band_data in enumerate(inner_bands):
|
||||
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
|
||||
band_data, xoff=x0_t, yoff=y0_t
|
||||
)
|
||||
total_zero_pixels += zero_count
|
||||
else:
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=effective_workers,
|
||||
initializer=_init_worker,
|
||||
initargs=(img_path,),
|
||||
) as executor:
|
||||
futures = [
|
||||
executor.submit(_interpolate_block_worker, task)
|
||||
for task in tasks
|
||||
]
|
||||
for block_idx, future in enumerate(futures, 1):
|
||||
x0_t, y0_t, inner_bands, zero_count, error = future.result()
|
||||
if error is not None:
|
||||
raise RuntimeError(
|
||||
f"块 (y={y0_t}, x={x0_t}) 处理失败: {error}"
|
||||
)
|
||||
if inner_bands is not None:
|
||||
for b_idx, band_data in enumerate(inner_bands):
|
||||
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
|
||||
band_data, xoff=x0_t, yoff=y0_t
|
||||
)
|
||||
total_zero_pixels += zero_count
|
||||
if callback_progress:
|
||||
callback_progress(f"已写入块 {block_idx}/{total_blocks}")
|
||||
|
||||
if callback_progress:
|
||||
callback_progress(
|
||||
f"分块插值完成: 共处理 {total_zero_pixels} 个零像素 "
|
||||
f"({total_blocks} 块,方法 {method},workers={effective_workers})"
|
||||
)
|
||||
|
||||
return output_path, None
|
||||
finally:
|
||||
out_dataset = None
|
||||
|
||||
result_stack = np.dstack(interpolated_bands)
|
||||
return output_path, result_stack
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
@ -899,7 +899,7 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
|
||||
if __name__ == '__main__':
|
||||
# 在这里直接设置参数
|
||||
imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"# BIL格式影像文件路径
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
output_path = r"E:\code\WQ\封装\test/yangdian_output.csv" # CSV格式输出文件路径
|
||||
|
||||
radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均
|
||||
|
||||
@ -806,8 +806,8 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
|
||||
if __name__ == '__main__':
|
||||
# 在这里直接设置参数
|
||||
imgpath = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq" # BIL格式影像文件路径
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
output_path = r"E:\code\WQ\封装\work_dir\5_training_spectra/yangdian_output.csv" # CSV格式输出文件路径
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
output_path = r"E:\code\WQ\封装\work_dir\6_Spectral_Feature_Extraction/yangdian_output.csv" # CSV格式输出文件路径
|
||||
|
||||
radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均
|
||||
flare_path = r"E:\code\WQ\封装\work_dir\2_Glint_Detection\severe_glint_area.dat" # 耀斑掩膜文件路径(可选,None表示不使用)
|
||||
|
||||
@ -315,7 +315,7 @@ def main():
|
||||
|
||||
# 示例1: 使用所有回归方法分析光谱指数
|
||||
print("\n1. 光谱指数与叶绿素a的回归分析:")
|
||||
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\water_quality_results.csv")
|
||||
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\water_quality_results.csv")
|
||||
spectral_indices = ['Al10SABI','Am092Bsub']
|
||||
|
||||
results1 = analyzer.batch_single_variable_regression(
|
||||
@ -323,7 +323,7 @@ def main():
|
||||
x_columns=spectral_indices,
|
||||
y_column='Chlorophyll',
|
||||
methods='all',
|
||||
output_file=r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\spectral_indices_regression.csv'
|
||||
output_file=r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\spectral_indices_regression.csv'
|
||||
)
|
||||
|
||||
# # 示例2: 使用特定方法分析反射率波段
|
||||
@ -343,7 +343,7 @@ def main():
|
||||
best_models = analyzer.get_best_models_summary()
|
||||
if not best_models.empty:
|
||||
print(best_models[['x_variable', 'regression_method', 'r_squared', 'equation']].to_string(index=False))
|
||||
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\best_models_summary.csv', index=False)
|
||||
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\best_models_summary.csv', index=False)
|
||||
print("\n最佳模型汇总已保存到 'best_models_summary.csv'")
|
||||
#
|
||||
# def advanced_usage_example():
|
||||
|
||||
@ -246,7 +246,7 @@ def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, outp
|
||||
|
||||
if __name__ == "__main__":
|
||||
algorithm= "chl_a"
|
||||
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\8_non_empirical_models\SS\SS_chl_a.json"
|
||||
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\8_non_empirical_models\SS\SS_chl_a.json"
|
||||
coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv"
|
||||
output_path= r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv"
|
||||
wave_radius=5.0
|
||||
|
||||
@ -98,7 +98,7 @@ PIPELINE_STEPS: List[StepSpec] = [
|
||||
step_id="step4", method_name="step5_process_csv",
|
||||
requires=["csv_path"], produces=["processed_csv_path"],
|
||||
required_input_files=["csv_path"],
|
||||
output_file="{work_dir}/4_processed_data/processed_data.csv",
|
||||
output_file="{work_dir}/5_Data_Cleaning/processed_data.csv",
|
||||
description="CSV 异常值清洗",
|
||||
),
|
||||
StepSpec(
|
||||
@ -111,21 +111,21 @@ PIPELINE_STEPS: List[StepSpec] = [
|
||||
},
|
||||
skip_when_missing=False,
|
||||
required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"],
|
||||
output_file="{work_dir}/5_training_spectra/training_spectra.csv",
|
||||
output_file="{work_dir}/6_Spectral_Feature_Extraction/training_spectra.csv",
|
||||
description="实测样本点光谱提取",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step7", method_name="step7_calc_indices",
|
||||
requires=["training_csv_path"], produces=["indices_path", "trad_indices_dir"],
|
||||
required_input_files=["training_csv_path"],
|
||||
output_file="{work_dir}/6_water_quality_indices/training_spectra_indices.csv",
|
||||
output_file="{work_dir}/7_Water_Quality_Indices/training_spectra_indices.csv",
|
||||
description="水质参数指数计算(双轨输出:A轨宽表 + B轨单文件)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step8", method_name="step8_train_ml",
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
required_input_files=["training_csv_path"],
|
||||
output_file="{work_dir}/7_Supervised_Model_Training/best_models.pkl",
|
||||
output_file="{work_dir}/8_Supervised_Model_Training/best_models.pkl",
|
||||
description="ML 建模(GridSearchCV / AutoML)",
|
||||
),
|
||||
StepSpec(
|
||||
@ -134,7 +134,7 @@ PIPELINE_STEPS: List[StepSpec] = [
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
parameter_map={"training_csv_path": "csv_path"},
|
||||
required_input_files=["training_csv_path"],
|
||||
output_file="{work_dir}/8_Regression_Modeling/non_empirical_models.pkl",
|
||||
output_file="{work_dir}/8_Non_Empirical_Regression/non_empirical_models.pkl",
|
||||
description="非经验统计回归",
|
||||
),
|
||||
StepSpec(
|
||||
|
||||
@ -328,7 +328,7 @@ def train_with_automl(
|
||||
split_method = split_methods[0]
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = "./7_Supervised_Model_Training_AutoML"
|
||||
output_dir = "./8_Supervised_Model_Training_AutoML"
|
||||
out_dir = Path(output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
preproc_dir = out_dir / preproc
|
||||
@ -519,7 +519,7 @@ if __name__ == "__main__":
|
||||
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
|
||||
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
|
||||
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
|
||||
p.add_argument("--out", default="./7_Supervised_Model_Training_AutoML")
|
||||
p.add_argument("--out", default="./8_Supervised_Model_Training_AutoML")
|
||||
args = p.parse_args()
|
||||
|
||||
# 智能推断 feature_start_column 类型
|
||||
|
||||
@ -21,7 +21,7 @@ class DataPreparationStep:
|
||||
@staticmethod
|
||||
def process_csv(
|
||||
csv_path: str,
|
||||
output_dir: Union[str, Path] = "./4_processed_data",
|
||||
output_dir: Union[str, Path] = "./5_Data_Cleaning",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""处理CSV文件(筛选剔除异常值)"""
|
||||
@ -61,7 +61,7 @@ class DataPreparationStep:
|
||||
boundary_path: Optional[str] = None,
|
||||
glint_mask_path: Optional[str] = None,
|
||||
water_mask_path: Optional[str] = None,
|
||||
output_dir: Union[str, Path] = "./5_training_spectra",
|
||||
output_dir: Union[str, Path] = "./6_Spectral_Feature_Extraction",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""根据采样点坐标在去耀斑影像中提取平均光谱"""
|
||||
@ -131,7 +131,7 @@ class DataPreparationStep:
|
||||
formula_names: Optional[List[str]] = None,
|
||||
output_file: Optional[str] = None,
|
||||
enabled: bool = True,
|
||||
output_dir: Union[str, Path] = "./6_water_quality_indices",
|
||||
output_dir: Union[str, Path] = "./7_Water_Quality_Indices",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> Optional[str]:
|
||||
"""根据训练光谱计算水质光谱指数(使用 band_math 方法)"""
|
||||
|
||||
@ -135,7 +135,7 @@ class ModelingStep:
|
||||
split_methods: Optional[List[str]] = None,
|
||||
cv_folds: int = 5,
|
||||
training_csv_path: Optional[str] = None,
|
||||
output_dir: Union[str, Path] = "./7_Supervised_Model_Training",
|
||||
output_dir: Union[str, Path] = "./8_Supervised_Model_Training",
|
||||
callback: Optional[Callable] = None,
|
||||
_report_generator=None,
|
||||
) -> str:
|
||||
@ -251,7 +251,7 @@ class ModelingStep:
|
||||
if output_dir is not None:
|
||||
non_empirical_dir = Path(output_dir)
|
||||
else:
|
||||
non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
|
||||
non_empirical_dir = Path.cwd() / "8_Non_Empirical_Regression"
|
||||
non_empirical_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if preprocessing_methods is None:
|
||||
@ -430,7 +430,7 @@ def _apply_preprocessing_internal(
|
||||
|
||||
save_path = None
|
||||
if preprocess_method == "SS":
|
||||
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training"
|
||||
models_dir = output_dir.parent.parent / "8_Supervised_Model_Training"
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_path = str(models_dir / "scaler_params.pkl")
|
||||
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
||||
|
||||
@ -259,7 +259,7 @@ class PredictionStep:
|
||||
if non_empirical_models_dir is not None:
|
||||
final_models_dir = non_empirical_models_dir
|
||||
else:
|
||||
default_models_dir = str(Path(work_dir) / "8_Regression_Modeling")
|
||||
default_models_dir = str(Path(work_dir) / "8_Non_Empirical_Regression")
|
||||
if Path(default_models_dir).exists():
|
||||
final_models_dir = default_models_dir
|
||||
else:
|
||||
|
||||
@ -138,11 +138,11 @@ class WaterQualityInversionPipeline:
|
||||
self.water_mask_dir = self.work_dir / "1_water_mask"
|
||||
self.glint_dir = self.work_dir / "2_Glint_Detection"
|
||||
self.deglint_dir = self.work_dir / "3_deglint"
|
||||
self.processed_data_dir = self.work_dir / "4_processed_data"
|
||||
self.training_spectra_dir = self.work_dir / "5_training_spectra"
|
||||
self.indices_dir = self.work_dir / "6_water_quality_indices"
|
||||
self.models_dir = self.work_dir / "7_Supervised_Model_Training"
|
||||
self.non_empirical_models_dir = self.work_dir / "8_Regression_Modeling"
|
||||
self.processed_data_dir = self.work_dir / "5_Data_Cleaning"
|
||||
self.training_spectra_dir = self.work_dir / "6_Spectral_Feature_Extraction"
|
||||
self.indices_dir = self.work_dir / "7_Water_Quality_Indices"
|
||||
self.models_dir = self.work_dir / "8_Supervised_Model_Training"
|
||||
self.non_empirical_models_dir = self.work_dir / "8_Non_Empirical_Regression"
|
||||
self.custom_regression_dir = self.work_dir / "13_Custom_Regression"
|
||||
self.sampling_dir = self.work_dir / "4_sampling"
|
||||
self.prediction_dir = self.work_dir / "11_12_13_predictions"
|
||||
@ -764,7 +764,7 @@ class WaterQualityInversionPipeline:
|
||||
if not spectrum_csv or not os.path.exists(spectrum_csv):
|
||||
# 回退:扫描 work_dir 下 step5 的产物目录,找第一个 .csv
|
||||
fallback_candidates = []
|
||||
step5_dir = os.path.join(self.work_dir, "5_Training_Spectra")
|
||||
step5_dir = os.path.join(self.work_dir, "6_Spectral_Feature_Extraction")
|
||||
if os.path.isdir(step5_dir):
|
||||
for f in sorted(os.listdir(step5_dir)):
|
||||
if f.lower().endswith('.csv'):
|
||||
@ -2023,10 +2023,10 @@ class WaterQualityInversionPipeline:
|
||||
# 应用预处理 - 使用spectral_Preprocessing模块
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
|
||||
# 为SS预处理提供scaler保存路径,保存在工作目录的7_Supervised_Model_Training中
|
||||
# 为SS预处理提供scaler保存路径,保存在工作目录的8_Supervised_Model_Training中
|
||||
save_path = None
|
||||
if preprocess_method == 'SS':
|
||||
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training" # 向上两级到工作目录
|
||||
models_dir = output_dir.parent.parent / "8_Supervised_Model_Training" # 向上两级到工作目录
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_path = str(models_dir / "scaler_params.pkl")
|
||||
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
||||
|
||||
Reference in New Issue
Block a user