Step3 插值算法 OOM 修复 + 多进程加速 + 全链路累积改动(14 文件)

This commit is contained in:
DXC
2026-06-15 16:49:17 +08:00
parent 82e0b92af6
commit 60a9d7d922
14 changed files with 855 additions and 152 deletions

View File

@ -0,0 +1,411 @@
"""
interpolator.py 多进程重构的行为测试mock-based无 osgeo 依赖)
验证:
1. 静态结构模块级函数集合、签名、新参数、_worker_dataset 全局
2. 行为逻辑_process_one_block 在 mock dataset 上的零像素识别 + 插值
3. 行为逻辑_interpolate_block_worker 通过模块全局 _worker_dataset 工作
4. 行为逻辑interpolate_zero_pixels_batch 的串行路径(不依赖 osgeo
5. 向后兼容:现有 8 个 caller 参数保留默认值不变
如果本机有 osgeo可补一个真实数据集的 smoke test。
"""
import os
import sys
import ast
import types
import unittest
from unittest.mock import MagicMock, patch
# Ensure the project src is on path so we can import the module under test
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, ".."))
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
INTERP_PATH = os.path.join(
PROJECT_ROOT, "src", "core", "algorithms", "interpolation", "interpolator.py"
)
# =============================================================================
# Static structure tests
# =============================================================================
class TestInterpolatorStructure(unittest.TestCase):
def setUp(self):
with open(INTERP_PATH, "r", encoding="utf-8") as f:
self.src = f.read()
self.tree = ast.parse(self.src)
def test_module_level_functions(self):
mod_funcs = {n.name for n in self.tree.body if isinstance(n, ast.FunctionDef)}
expected = {
"interpolate_pixels",
"_interpolate_single_band",
"_normalize_interpolation_method",
"_read_water_mask_to_array",
"_init_worker",
"_interpolate_block_worker",
"_process_one_block",
"interpolate_zero_pixels_batch",
}
self.assertEqual(mod_funcs, expected)
self.assertNotIn("_process_block_with_buffer", mod_funcs)
def test_worker_dataset_module_global(self):
mod_globals = set()
for n in self.tree.body:
if isinstance(n, ast.Assign) and isinstance(n.targets[0], ast.Name):
mod_globals.add(n.targets[0].id)
elif isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name):
mod_globals.add(n.target.id)
self.assertIn("_worker_dataset", mod_globals)
def test_interpolate_zero_pixels_batch_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "interpolate_zero_pixels_batch":
args = [a.arg for a in n.args.args]
# Backward compat: all 8 existing params + 2 new
self.assertEqual(
args,
[
"img_path", "interpolation_method", "output_path",
"water_mask", "deglint_dir", "callback_progress",
"block_size", "halo_size", "n_workers", "use_multiprocessing",
],
)
defaults = [getattr(d, "value", None) for d in n.args.defaults]
self.assertEqual(
defaults,
["nearest", None, None, None, None, 1024, 64, None, True],
)
return
self.fail("interpolate_zero_pixels_batch not found")
def test_init_worker_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "_init_worker":
self.assertEqual([a.arg for a in n.args.args], ["img_path"])
return
self.fail("_init_worker not found")
def test_worker_function_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "_interpolate_block_worker":
self.assertEqual([a.arg for a in n.args.args], ["task"])
return
self.fail("_interpolate_block_worker not found")
def test_process_one_block_signature(self):
for n in self.tree.body:
if isinstance(n, ast.FunctionDef) and n.name == "_process_one_block":
args = [a.arg for a in n.args.args]
self.assertEqual(
args,
[
"dataset", "x0", "y0",
"ey0", "ex0", "ey1", "ex1",
"row_offset", "col_offset",
"inner_h", "inner_w",
"mask_segment_ext", "method",
],
)
return
self.fail("_process_one_block not found")
def test_uses_process_pool_executor(self):
self.assertIn("ProcessPoolExecutor", self.src)
self.assertIn("ProcessPoolExecutor(", self.src)
self.assertIn("initializer=_init_worker", self.src)
self.assertIn("initargs=(img_path,)", self.src)
self.assertIn("GDAL_NUM_THREADS", self.src)
def test_dispatch_logic_present(self):
# Both serial and parallel paths should be present
self.assertIn("if effective_workers <= 1:", self.src)
self.assertIn("with ProcessPoolExecutor", self.src)
def test_serial_path_uses_process_one_block(self):
# Serial branch should call _process_one_block directly (no pickle overhead)
self.assertIn(
"_process_one_block(\n dataset, *task\n )",
self.src,
)
def test_worker_path_uses_process_one_block(self):
# Worker function should also call _process_one_block (shared core)
self.assertIn("_process_one_block(", self.src)
# The worker should read from _worker_dataset, not receive dataset in task
self.assertIn("_worker_dataset", self.src)
# =============================================================================
# Mocked behavior tests (no real osgeo/scipy needed)
# =============================================================================
def _make_mock_band(read_data):
band = MagicMock()
band.ReadAsArray.return_value = read_data
return band
def _make_mock_dataset(bands_data, n_bands=None):
if n_bands is None:
n_bands = len(bands_data)
ds = MagicMock()
ds.RasterCount = n_bands
ds.GetRasterBand.side_effect = lambda b: _make_mock_band(bands_data[b - 1])
return ds
def _make_fake_module(name, attrs=None):
mod = types.ModuleType(name)
for k, v in (attrs or {}).items():
setattr(mod, k, v)
return mod
def _install_fake_modules():
"""Install minimal fakes for scipy + osgeo so the module under test imports."""
import numpy as np
# Fake scipy
scipy = types.ModuleType("scipy")
scipy_ndimage = types.ModuleType("scipy.ndimage")
scipy_interp = types.ModuleType("scipy.interpolate")
scipy_spatial = types.ModuleType("scipy.spatial")
class _FakeTree:
def __init__(self, coords):
self.coords = coords
def query(self, pts):
# Return index 0 for all queries
import numpy as np
n = len(pts)
return np.zeros(n, dtype=int), np.zeros(n, dtype=int)
def _griddata(coords, values, pts, method="linear", fill_value=0.0):
# Trivial: nearest valid value
if len(coords) == 0 or len(pts) == 0:
return np.zeros(len(pts), dtype=np.float32)
# Just return first value for all queries (testing only)
return np.full(len(pts), values[0], dtype=np.float32)
class _FakeRBF:
def __init__(self, coords, values, kernel=None):
self.first_value = values[0]
def __call__(self, pts):
return np.full(len(pts), self.first_value, dtype=np.float32)
scipy_spatial.cKDTree = _FakeTree
scipy_interp.griddata = _griddata
scipy_interp.RBFInterpolator = _FakeRBF
scipy.ndimage = scipy_ndimage
scipy.interpolate = scipy_interp
scipy.spatial = scipy_spatial
# Fake osgeo.gdal with the constants the module needs
osgeo = types.ModuleType("osgeo")
gdal_mod = types.ModuleType("osgeo.gdal")
gdal_mod.GA_ReadOnly = 0
gdal_mod.GDT_Float32 = 6
gdal_mod.UseExceptions = MagicMock()
def _open(path, mode):
return _make_mock_dataset([]) # empty dataset; real tests build their own
gdal_mod.Open = _open
gdal_mod.GetDriverByName = MagicMock(return_value=MagicMock())
gdal_mod.SetConfigOption = MagicMock()
osgeo.gdal = gdal_mod
sys.modules["scipy"] = scipy
sys.modules["scipy.ndimage"] = scipy_ndimage
sys.modules["scipy.interpolate"] = scipy_interp
sys.modules["scipy.spatial"] = scipy_spatial
sys.modules["osgeo"] = osgeo
sys.modules["osgeo.gdal"] = gdal_mod
return {
"scipy": scipy,
"scipy.spatial": scipy_spatial,
"scipy.interpolate": scipy_interp,
"osgeo.gdal": gdal_mod,
}
class TestProcessOneBlockMocked(unittest.TestCase):
"""Verify _process_one_block logic with mocked GDAL/scipy."""
@classmethod
def setUpClass(cls):
_install_fake_modules()
# Import the module under test after installing fakes
from src.core.algorithms.interpolation import interpolator
cls.interp = interpolator
cls.gdal = sys.modules["osgeo.gdal"]
def _build_dataset(self, bands):
"""Build a mock dataset where GetRasterBand(b).ReadAsArray(x,y,w,h)
returns bands[b-1] (a numpy array of the requested shape).
For simplicity we return the full band array regardless of x,y,w,h.
"""
ds = MagicMock()
ds.RasterCount = len(bands)
def get_band(b):
band = MagicMock()
band.ReadAsArray.return_value = bands[b - 1]
return band
ds.GetRasterBand.side_effect = get_band
return ds
def test_no_zeros_returns_inner_blocks_unchanged(self):
"""If no pixels are all-zero, inner blocks should be returned as-is."""
import numpy as np
# 2x2 image, 3 bands, all 1s (no zeros anywhere)
band = np.ones((4, 4), dtype=np.float32) # 2 inner + 2 halo
bands = [band, band * 2, band * 3]
ds = self._build_dataset(bands)
inner_bands, zero_count = self.interp._process_one_block(
dataset=ds,
x0=0, y0=0,
ey0=0, ex0=0, ey1=4, ex1=4,
row_offset=0, col_offset=0,
inner_h=2, inner_w=2,
mask_segment_ext=None,
method="nearest",
)
self.assertEqual(zero_count, 0)
self.assertEqual(len(inner_bands), 3)
for ib, expected_band in zip(inner_bands, bands):
self.assertEqual(ib.shape, (2, 2))
np.testing.assert_array_equal(ib, expected_band[:2, :2])
def test_with_zero_pixel_triggers_interpolation(self):
"""If a pixel is all-zero, _interpolate_single_band should be called."""
import numpy as np
# 4x4 image, 3 bands. Top-left 2x2 is zeros, rest is 1s.
band1 = np.array([
[0, 0, 1, 1],
[0, 0, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
], dtype=np.float32)
band2 = band1 * 2
band3 = band1 * 3
ds = self._build_dataset([band1, band2, band3])
# Process the top-left 2x2 block (with halo, 4x4 covers all)
inner_bands, zero_count = self.interp._process_one_block(
dataset=ds,
x0=0, y0=0,
ey0=0, ex0=0, ey1=4, ex1=4,
row_offset=0, col_offset=0,
inner_h=2, inner_w=2,
mask_segment_ext=None,
method="nearest",
)
self.assertGreater(zero_count, 0, "should detect zero pixels")
self.assertEqual(len(inner_bands), 3)
# Each inner band should be 2x2
for ib in inner_bands:
self.assertEqual(ib.shape, (2, 2))
# Our fake nearest interpolation returns valid_values[0]=1 for band1
# So the inner block should be filled with non-zero values
self.assertTrue(np.all(inner_bands[0] > 0))
class TestWorkerFunctionMocked(unittest.TestCase):
"""Verify _interpolate_block_worker uses module-global _worker_dataset."""
@classmethod
def setUpClass(cls):
_install_fake_modules()
from src.core.algorithms.interpolation import interpolator
cls.interp = interpolator
def test_worker_uses_module_global_dataset(self):
import numpy as np
# Set the module global
band = np.array([
[0, 0, 1, 1],
[0, 0, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
], dtype=np.float32)
ds = self._build_dataset([band, band * 2])
self.interp._worker_dataset = ds
try:
task = (
0, 0, 0, 0, 4, 4, # x0, y0, ey0, ex0, ey1, ex1
0, 0, 2, 2, # row_offset, col_offset, inner_h, inner_w
None, "nearest", # mask_segment_ext, method
)
x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)
self.assertIsNone(error)
self.assertEqual((x0, y0), (0, 0))
self.assertGreater(zero_count, 0)
self.assertIsNotNone(inner_bands)
self.assertEqual(len(inner_bands), 2)
finally:
self.interp._worker_dataset = None
def test_worker_returns_error_if_dataset_uninitialized(self):
self.interp._worker_dataset = None
task = (0, 0, 0, 0, 2, 2, 0, 0, 2, 2, None, "nearest")
x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)
self.assertIsNotNone(error)
self.assertIn("not initialized", error)
self.assertIsNone(inner_bands)
def _build_dataset(self, bands):
ds = MagicMock()
ds.RasterCount = len(bands)
def get_band(b):
band = MagicMock()
band.ReadAsArray.return_value = bands[b - 1]
return band
ds.GetRasterBand.side_effect = get_band
return ds
class TestInitWorkerMocked(unittest.TestCase):
"""Verify _init_worker sets module global and config option."""
@classmethod
def setUpClass(cls):
_install_fake_modules()
from src.core.algorithms.interpolation import interpolator
cls.interp = interpolator
cls.gdal_mod = sys.modules["osgeo.gdal"]
def setUp(self):
self.interp._worker_dataset = None
def tearDown(self):
self.interp._worker_dataset = None
def test_init_worker_opens_and_caches(self):
fake_ds = MagicMock()
with patch.object(self.gdal_mod, "Open", return_value=fake_ds) as open_mock, \
patch.object(self.gdal_mod, "SetConfigOption") as cfg_mock:
self.interp._init_worker("/fake/path.bsq")
self.assertIs(self.interp._worker_dataset, fake_ds)
open_mock.assert_called_once_with("/fake/path.bsq", 0)
# Should have set GDAL_NUM_THREADS=1
cfg_mock.assert_any_call("GDAL_NUM_THREADS", "1")
def test_init_worker_raises_if_open_fails(self):
with patch.object(self.gdal_mod, "Open", return_value=None):
with self.assertRaises(RuntimeError):
self.interp._init_worker("/bad/path.bsq")
if __name__ == "__main__":
unittest.main(verbosity=2)