Step3 插值算法 OOM 修复 + 多进程加速 + 全链路累积改动(14 文件)
This commit is contained in:
411
tests/test_interpolator_refactor.py
Normal file
411
tests/test_interpolator_refactor.py
Normal file
@ -0,0 +1,411 @@
|
||||
"""
|
||||
interpolator.py 多进程重构的行为测试(mock-based,无 osgeo 依赖)
|
||||
|
||||
验证:
|
||||
1. 静态结构:模块级函数集合、签名、新参数、_worker_dataset 全局
|
||||
2. 行为逻辑:_process_one_block 在 mock dataset 上的零像素识别 + 插值
|
||||
3. 行为逻辑:_interpolate_block_worker 通过模块全局 _worker_dataset 工作
|
||||
4. 行为逻辑:interpolate_zero_pixels_batch 的串行路径(不依赖 osgeo)
|
||||
5. 向后兼容:现有 8 个 caller 参数保留默认值不变
|
||||
|
||||
如果本机有 osgeo,可补一个真实数据集的 smoke test。
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import ast
|
||||
import types
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Ensure the project src is on path so we can import the module under test
|
||||
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
|
||||
INTERP_PATH = os.path.join(
|
||||
PROJECT_ROOT, "src", "core", "algorithms", "interpolation", "interpolator.py"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Static structure tests
|
||||
# =============================================================================
|
||||
class TestInterpolatorStructure(unittest.TestCase):
|
||||
def setUp(self):
|
||||
with open(INTERP_PATH, "r", encoding="utf-8") as f:
|
||||
self.src = f.read()
|
||||
self.tree = ast.parse(self.src)
|
||||
|
||||
def test_module_level_functions(self):
|
||||
mod_funcs = {n.name for n in self.tree.body if isinstance(n, ast.FunctionDef)}
|
||||
expected = {
|
||||
"interpolate_pixels",
|
||||
"_interpolate_single_band",
|
||||
"_normalize_interpolation_method",
|
||||
"_read_water_mask_to_array",
|
||||
"_init_worker",
|
||||
"_interpolate_block_worker",
|
||||
"_process_one_block",
|
||||
"interpolate_zero_pixels_batch",
|
||||
}
|
||||
self.assertEqual(mod_funcs, expected)
|
||||
self.assertNotIn("_process_block_with_buffer", mod_funcs)
|
||||
|
||||
def test_worker_dataset_module_global(self):
|
||||
mod_globals = set()
|
||||
for n in self.tree.body:
|
||||
if isinstance(n, ast.Assign) and isinstance(n.targets[0], ast.Name):
|
||||
mod_globals.add(n.targets[0].id)
|
||||
elif isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name):
|
||||
mod_globals.add(n.target.id)
|
||||
self.assertIn("_worker_dataset", mod_globals)
|
||||
|
||||
def test_interpolate_zero_pixels_batch_signature(self):
|
||||
for n in self.tree.body:
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "interpolate_zero_pixels_batch":
|
||||
args = [a.arg for a in n.args.args]
|
||||
# Backward compat: all 8 existing params + 2 new
|
||||
self.assertEqual(
|
||||
args,
|
||||
[
|
||||
"img_path", "interpolation_method", "output_path",
|
||||
"water_mask", "deglint_dir", "callback_progress",
|
||||
"block_size", "halo_size", "n_workers", "use_multiprocessing",
|
||||
],
|
||||
)
|
||||
defaults = [getattr(d, "value", None) for d in n.args.defaults]
|
||||
self.assertEqual(
|
||||
defaults,
|
||||
["nearest", None, None, None, None, 1024, 64, None, True],
|
||||
)
|
||||
return
|
||||
self.fail("interpolate_zero_pixels_batch not found")
|
||||
|
||||
def test_init_worker_signature(self):
|
||||
for n in self.tree.body:
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "_init_worker":
|
||||
self.assertEqual([a.arg for a in n.args.args], ["img_path"])
|
||||
return
|
||||
self.fail("_init_worker not found")
|
||||
|
||||
def test_worker_function_signature(self):
|
||||
for n in self.tree.body:
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "_interpolate_block_worker":
|
||||
self.assertEqual([a.arg for a in n.args.args], ["task"])
|
||||
return
|
||||
self.fail("_interpolate_block_worker not found")
|
||||
|
||||
def test_process_one_block_signature(self):
|
||||
for n in self.tree.body:
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "_process_one_block":
|
||||
args = [a.arg for a in n.args.args]
|
||||
self.assertEqual(
|
||||
args,
|
||||
[
|
||||
"dataset", "x0", "y0",
|
||||
"ey0", "ex0", "ey1", "ex1",
|
||||
"row_offset", "col_offset",
|
||||
"inner_h", "inner_w",
|
||||
"mask_segment_ext", "method",
|
||||
],
|
||||
)
|
||||
return
|
||||
self.fail("_process_one_block not found")
|
||||
|
||||
def test_uses_process_pool_executor(self):
|
||||
self.assertIn("ProcessPoolExecutor", self.src)
|
||||
self.assertIn("ProcessPoolExecutor(", self.src)
|
||||
self.assertIn("initializer=_init_worker", self.src)
|
||||
self.assertIn("initargs=(img_path,)", self.src)
|
||||
self.assertIn("GDAL_NUM_THREADS", self.src)
|
||||
|
||||
def test_dispatch_logic_present(self):
|
||||
# Both serial and parallel paths should be present
|
||||
self.assertIn("if effective_workers <= 1:", self.src)
|
||||
self.assertIn("with ProcessPoolExecutor", self.src)
|
||||
|
||||
def test_serial_path_uses_process_one_block(self):
|
||||
# Serial branch should call _process_one_block directly (no pickle overhead)
|
||||
self.assertIn(
|
||||
"_process_one_block(\n dataset, *task\n )",
|
||||
self.src,
|
||||
)
|
||||
|
||||
def test_worker_path_uses_process_one_block(self):
|
||||
# Worker function should also call _process_one_block (shared core)
|
||||
self.assertIn("_process_one_block(", self.src)
|
||||
# The worker should read from _worker_dataset, not receive dataset in task
|
||||
self.assertIn("_worker_dataset", self.src)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mocked behavior tests (no real osgeo/scipy needed)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _make_mock_band(read_data):
|
||||
band = MagicMock()
|
||||
band.ReadAsArray.return_value = read_data
|
||||
return band
|
||||
|
||||
|
||||
def _make_mock_dataset(bands_data, n_bands=None):
|
||||
if n_bands is None:
|
||||
n_bands = len(bands_data)
|
||||
ds = MagicMock()
|
||||
ds.RasterCount = n_bands
|
||||
ds.GetRasterBand.side_effect = lambda b: _make_mock_band(bands_data[b - 1])
|
||||
return ds
|
||||
|
||||
|
||||
def _make_fake_module(name, attrs=None):
|
||||
mod = types.ModuleType(name)
|
||||
for k, v in (attrs or {}).items():
|
||||
setattr(mod, k, v)
|
||||
return mod
|
||||
|
||||
|
||||
def _install_fake_modules():
|
||||
"""Install minimal fakes for scipy + osgeo so the module under test imports."""
|
||||
import numpy as np
|
||||
|
||||
# Fake scipy
|
||||
scipy = types.ModuleType("scipy")
|
||||
scipy_ndimage = types.ModuleType("scipy.ndimage")
|
||||
scipy_interp = types.ModuleType("scipy.interpolate")
|
||||
scipy_spatial = types.ModuleType("scipy.spatial")
|
||||
|
||||
class _FakeTree:
|
||||
def __init__(self, coords):
|
||||
self.coords = coords
|
||||
def query(self, pts):
|
||||
# Return index 0 for all queries
|
||||
import numpy as np
|
||||
n = len(pts)
|
||||
return np.zeros(n, dtype=int), np.zeros(n, dtype=int)
|
||||
|
||||
def _griddata(coords, values, pts, method="linear", fill_value=0.0):
|
||||
# Trivial: nearest valid value
|
||||
if len(coords) == 0 or len(pts) == 0:
|
||||
return np.zeros(len(pts), dtype=np.float32)
|
||||
# Just return first value for all queries (testing only)
|
||||
return np.full(len(pts), values[0], dtype=np.float32)
|
||||
|
||||
class _FakeRBF:
|
||||
def __init__(self, coords, values, kernel=None):
|
||||
self.first_value = values[0]
|
||||
def __call__(self, pts):
|
||||
return np.full(len(pts), self.first_value, dtype=np.float32)
|
||||
|
||||
scipy_spatial.cKDTree = _FakeTree
|
||||
scipy_interp.griddata = _griddata
|
||||
scipy_interp.RBFInterpolator = _FakeRBF
|
||||
|
||||
scipy.ndimage = scipy_ndimage
|
||||
scipy.interpolate = scipy_interp
|
||||
scipy.spatial = scipy_spatial
|
||||
|
||||
# Fake osgeo.gdal with the constants the module needs
|
||||
osgeo = types.ModuleType("osgeo")
|
||||
gdal_mod = types.ModuleType("osgeo.gdal")
|
||||
gdal_mod.GA_ReadOnly = 0
|
||||
gdal_mod.GDT_Float32 = 6
|
||||
gdal_mod.UseExceptions = MagicMock()
|
||||
|
||||
def _open(path, mode):
|
||||
return _make_mock_dataset([]) # empty dataset; real tests build their own
|
||||
gdal_mod.Open = _open
|
||||
gdal_mod.GetDriverByName = MagicMock(return_value=MagicMock())
|
||||
gdal_mod.SetConfigOption = MagicMock()
|
||||
|
||||
osgeo.gdal = gdal_mod
|
||||
|
||||
sys.modules["scipy"] = scipy
|
||||
sys.modules["scipy.ndimage"] = scipy_ndimage
|
||||
sys.modules["scipy.interpolate"] = scipy_interp
|
||||
sys.modules["scipy.spatial"] = scipy_spatial
|
||||
sys.modules["osgeo"] = osgeo
|
||||
sys.modules["osgeo.gdal"] = gdal_mod
|
||||
|
||||
return {
|
||||
"scipy": scipy,
|
||||
"scipy.spatial": scipy_spatial,
|
||||
"scipy.interpolate": scipy_interp,
|
||||
"osgeo.gdal": gdal_mod,
|
||||
}
|
||||
|
||||
|
||||
class TestProcessOneBlockMocked(unittest.TestCase):
|
||||
"""Verify _process_one_block logic with mocked GDAL/scipy."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
_install_fake_modules()
|
||||
# Import the module under test after installing fakes
|
||||
from src.core.algorithms.interpolation import interpolator
|
||||
cls.interp = interpolator
|
||||
cls.gdal = sys.modules["osgeo.gdal"]
|
||||
|
||||
def _build_dataset(self, bands):
|
||||
"""Build a mock dataset where GetRasterBand(b).ReadAsArray(x,y,w,h)
|
||||
returns bands[b-1] (a numpy array of the requested shape).
|
||||
|
||||
For simplicity we return the full band array regardless of x,y,w,h.
|
||||
"""
|
||||
ds = MagicMock()
|
||||
ds.RasterCount = len(bands)
|
||||
|
||||
def get_band(b):
|
||||
band = MagicMock()
|
||||
band.ReadAsArray.return_value = bands[b - 1]
|
||||
return band
|
||||
ds.GetRasterBand.side_effect = get_band
|
||||
return ds
|
||||
|
||||
def test_no_zeros_returns_inner_blocks_unchanged(self):
|
||||
"""If no pixels are all-zero, inner blocks should be returned as-is."""
|
||||
import numpy as np
|
||||
# 2x2 image, 3 bands, all 1s (no zeros anywhere)
|
||||
band = np.ones((4, 4), dtype=np.float32) # 2 inner + 2 halo
|
||||
bands = [band, band * 2, band * 3]
|
||||
ds = self._build_dataset(bands)
|
||||
|
||||
inner_bands, zero_count = self.interp._process_one_block(
|
||||
dataset=ds,
|
||||
x0=0, y0=0,
|
||||
ey0=0, ex0=0, ey1=4, ex1=4,
|
||||
row_offset=0, col_offset=0,
|
||||
inner_h=2, inner_w=2,
|
||||
mask_segment_ext=None,
|
||||
method="nearest",
|
||||
)
|
||||
self.assertEqual(zero_count, 0)
|
||||
self.assertEqual(len(inner_bands), 3)
|
||||
for ib, expected_band in zip(inner_bands, bands):
|
||||
self.assertEqual(ib.shape, (2, 2))
|
||||
np.testing.assert_array_equal(ib, expected_band[:2, :2])
|
||||
|
||||
def test_with_zero_pixel_triggers_interpolation(self):
|
||||
"""If a pixel is all-zero, _interpolate_single_band should be called."""
|
||||
import numpy as np
|
||||
# 4x4 image, 3 bands. Top-left 2x2 is zeros, rest is 1s.
|
||||
band1 = np.array([
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 1, 1],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1],
|
||||
], dtype=np.float32)
|
||||
band2 = band1 * 2
|
||||
band3 = band1 * 3
|
||||
ds = self._build_dataset([band1, band2, band3])
|
||||
|
||||
# Process the top-left 2x2 block (with halo, 4x4 covers all)
|
||||
inner_bands, zero_count = self.interp._process_one_block(
|
||||
dataset=ds,
|
||||
x0=0, y0=0,
|
||||
ey0=0, ex0=0, ey1=4, ex1=4,
|
||||
row_offset=0, col_offset=0,
|
||||
inner_h=2, inner_w=2,
|
||||
mask_segment_ext=None,
|
||||
method="nearest",
|
||||
)
|
||||
self.assertGreater(zero_count, 0, "should detect zero pixels")
|
||||
self.assertEqual(len(inner_bands), 3)
|
||||
# Each inner band should be 2x2
|
||||
for ib in inner_bands:
|
||||
self.assertEqual(ib.shape, (2, 2))
|
||||
# Our fake nearest interpolation returns valid_values[0]=1 for band1
|
||||
# So the inner block should be filled with non-zero values
|
||||
self.assertTrue(np.all(inner_bands[0] > 0))
|
||||
|
||||
|
||||
class TestWorkerFunctionMocked(unittest.TestCase):
|
||||
"""Verify _interpolate_block_worker uses module-global _worker_dataset."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
_install_fake_modules()
|
||||
from src.core.algorithms.interpolation import interpolator
|
||||
cls.interp = interpolator
|
||||
|
||||
def test_worker_uses_module_global_dataset(self):
|
||||
import numpy as np
|
||||
# Set the module global
|
||||
band = np.array([
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 1, 1],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1],
|
||||
], dtype=np.float32)
|
||||
ds = self._build_dataset([band, band * 2])
|
||||
self.interp._worker_dataset = ds
|
||||
try:
|
||||
task = (
|
||||
0, 0, 0, 0, 4, 4, # x0, y0, ey0, ex0, ey1, ex1
|
||||
0, 0, 2, 2, # row_offset, col_offset, inner_h, inner_w
|
||||
None, "nearest", # mask_segment_ext, method
|
||||
)
|
||||
x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)
|
||||
self.assertIsNone(error)
|
||||
self.assertEqual((x0, y0), (0, 0))
|
||||
self.assertGreater(zero_count, 0)
|
||||
self.assertIsNotNone(inner_bands)
|
||||
self.assertEqual(len(inner_bands), 2)
|
||||
finally:
|
||||
self.interp._worker_dataset = None
|
||||
|
||||
def test_worker_returns_error_if_dataset_uninitialized(self):
|
||||
self.interp._worker_dataset = None
|
||||
task = (0, 0, 0, 0, 2, 2, 0, 0, 2, 2, None, "nearest")
|
||||
x0, y0, inner_bands, zero_count, error = self.interp._interpolate_block_worker(task)
|
||||
self.assertIsNotNone(error)
|
||||
self.assertIn("not initialized", error)
|
||||
self.assertIsNone(inner_bands)
|
||||
|
||||
def _build_dataset(self, bands):
|
||||
ds = MagicMock()
|
||||
ds.RasterCount = len(bands)
|
||||
def get_band(b):
|
||||
band = MagicMock()
|
||||
band.ReadAsArray.return_value = bands[b - 1]
|
||||
return band
|
||||
ds.GetRasterBand.side_effect = get_band
|
||||
return ds
|
||||
|
||||
|
||||
class TestInitWorkerMocked(unittest.TestCase):
|
||||
"""Verify _init_worker sets module global and config option."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
_install_fake_modules()
|
||||
from src.core.algorithms.interpolation import interpolator
|
||||
cls.interp = interpolator
|
||||
cls.gdal_mod = sys.modules["osgeo.gdal"]
|
||||
|
||||
def setUp(self):
|
||||
self.interp._worker_dataset = None
|
||||
|
||||
def tearDown(self):
|
||||
self.interp._worker_dataset = None
|
||||
|
||||
def test_init_worker_opens_and_caches(self):
|
||||
fake_ds = MagicMock()
|
||||
with patch.object(self.gdal_mod, "Open", return_value=fake_ds) as open_mock, \
|
||||
patch.object(self.gdal_mod, "SetConfigOption") as cfg_mock:
|
||||
self.interp._init_worker("/fake/path.bsq")
|
||||
self.assertIs(self.interp._worker_dataset, fake_ds)
|
||||
open_mock.assert_called_once_with("/fake/path.bsq", 0)
|
||||
# Should have set GDAL_NUM_THREADS=1
|
||||
cfg_mock.assert_any_call("GDAL_NUM_THREADS", "1")
|
||||
|
||||
def test_init_worker_raises_if_open_fails(self):
|
||||
with patch.object(self.gdal_mod, "Open", return_value=None):
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.interp._init_worker("/bad/path.bsq")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user