"""Behavior tests for the multiprocessing refactor of interpolator.py.

Mock-based; no real ``osgeo`` dependency is required. Verifies:

1. Static structure: the set of module-level functions, their signatures,
   the new parameters, and the ``_worker_dataset`` module global.
2. Behavior: ``_process_one_block`` zero-pixel detection and interpolation
   on a mock dataset.
3. Behavior: ``_interpolate_block_worker`` works through the module global
   ``_worker_dataset``.
4. Behavior: the serial path of ``interpolate_zero_pixels_batch`` (does not
   depend on osgeo).
5. Backward compatibility: the 8 existing caller parameters keep their
   default values.

If osgeo is available locally, a smoke test on a real dataset can be added.
"""

import os
import sys
import ast
import types
import unittest
from unittest.mock import MagicMock, patch

# Ensure the project src is on path so we can import the module under test
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

INTERP_PATH = os.path.join(
    PROJECT_ROOT, "src", "core", "algorithms", "interpolation", "interpolator.py"
)

# Fake modules created by _install_fake_modules, keyed by dotted module name.
# Filled on the first call and reused afterwards so every test class sees the
# very same module objects.
_FAKE_MODULES = {}


# =============================================================================
# Static structure tests
# =============================================================================

class TestInterpolatorStructure(unittest.TestCase):
    """Static (AST / source text) checks on interpolator.py."""

    @classmethod
    def setUpClass(cls):
        # The source never changes during the run: read and parse it once
        # instead of once per test method.
        with open(INTERP_PATH, "r", encoding="utf-8") as f:
            cls.src = f.read()
        cls.tree = ast.parse(cls.src)

    def _find_function(self, name):
        """Return the module-level ``FunctionDef`` called *name*.

        Fails the current test with ``"<name> not found"`` when the module
        defines no such function.
        """
        for node in self.tree.body:
            if isinstance(node, ast.FunctionDef) and node.name == name:
                return node
        self.fail("{} not found".format(name))

    def test_module_level_functions(self):
        mod_funcs = {
            n.name for n in self.tree.body if isinstance(n, ast.FunctionDef)
        }
        expected = {
            "interpolate_pixels",
            "_interpolate_single_band",
            "_normalize_interpolation_method",
            "_read_water_mask_to_array",
            "_init_worker",
            "_interpolate_block_worker",
            "_process_one_block",
            "interpolate_zero_pixels_batch",
        }
        self.assertEqual(mod_funcs, expected)
        self.assertNotIn("_process_block_with_buffer", mod_funcs)

    def test_worker_dataset_module_global(self):
        mod_globals = set()
        for n in self.tree.body:
            if isinstance(n, ast.Assign) and isinstance(n.targets[0], ast.Name):
                mod_globals.add(n.targets[0].id)
            elif isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name):
                mod_globals.add(n.target.id)
        self.assertIn("_worker_dataset", mod_globals)

    def test_interpolate_zero_pixels_batch_signature(self):
        node = self._find_function("interpolate_zero_pixels_batch")
        args = [a.arg for a in node.args.args]
        # Backward compat: all 8 existing params + 2 new
        self.assertEqual(
            args,
            [
                "img_path",
                "interpolation_method",
                "output_path",
                "water_mask",
                "deglint_dir",
                "callback_progress",
                "block_size",
                "halo_size",
                "n_workers",
                "use_multiprocessing",
            ],
        )
        # Defaults are expected to be literal constants; anything without a
        # ``value`` attribute is reported as None.
        defaults = [getattr(d, "value", None) for d in node.args.defaults]
        self.assertEqual(
            defaults,
            ["nearest", None, None, None, None, 1024, 64, None, True],
        )

    def test_init_worker_signature(self):
        node = self._find_function("_init_worker")
        self.assertEqual([a.arg for a in node.args.args], ["img_path"])

    def test_worker_function_signature(self):
        node = self._find_function("_interpolate_block_worker")
        self.assertEqual([a.arg for a in node.args.args], ["task"])

    def test_process_one_block_signature(self):
        node = self._find_function("_process_one_block")
        args = [a.arg for a in node.args.args]
        self.assertEqual(
            args,
            [
                "dataset",
                "x0",
                "y0",
                "ey0",
                "ex0",
                "ey1",
                "ex1",
                "row_offset",
                "col_offset",
                "inner_h",
                "inner_w",
                "mask_segment_ext",
                "method",
            ],
        )

    def test_uses_process_pool_executor(self):
        self.assertIn("ProcessPoolExecutor", self.src)
        self.assertIn("ProcessPoolExecutor(", self.src)
        self.assertIn("initializer=_init_worker", self.src)
        self.assertIn("initargs=(img_path,)", self.src)
        self.assertIn("GDAL_NUM_THREADS", self.src)

    def test_dispatch_logic_present(self):
        # Both serial and parallel paths should be present
        self.assertIn("if effective_workers <= 1:", self.src)
        self.assertIn("with ProcessPoolExecutor", self.src)

    def test_serial_path_uses_process_one_block(self):
        # Serial branch should call _process_one_block directly (no pickle
        # overhead). Match the call shape rather than one exact indentation so
        # that reformatting the module does not break this test.
        self.assertRegex(
            self.src,
            r"_process_one_block\(\s*dataset,\s*\*task\s*\)",
        )

    def test_worker_path_uses_process_one_block(self):
        # Worker function should also call _process_one_block (shared core)
        self.assertIn("_process_one_block(", self.src)
        # The worker should read from _worker_dataset, not receive dataset in task
        self.assertIn("_worker_dataset", self.src)


# =============================================================================
# Mocked behavior tests (no real osgeo/scipy needed)
# =============================================================================

def _make_mock_band(read_data):
    """Return a mock raster band whose ``ReadAsArray`` yields *read_data*."""
    band = MagicMock()
    band.ReadAsArray.return_value = read_data
    return band


def _make_mock_dataset(bands_data, n_bands=None):
    """Return a mock dataset exposing *bands_data* through ``GetRasterBand``.

    ``GetRasterBand(b)`` is 1-based and returns a fresh mock band for
    ``bands_data[b - 1]``; the band ignores any window arguments and always
    returns the full array. ``RasterCount`` is *n_bands* when given, otherwise
    ``len(bands_data)``.
    """
    if n_bands is None:
        n_bands = len(bands_data)
    ds = MagicMock()
    ds.RasterCount = n_bands
    ds.GetRasterBand.side_effect = lambda b: _make_mock_band(bands_data[b - 1])
    return ds


def _make_fake_module(name, attrs=None):
    """Create a module object called *name* with the given attributes set."""
    mod = types.ModuleType(name)
    for k, v in (attrs or {}).items():
        setattr(mod, k, v)
    return mod


def _build_fake_modules():
    """Build the fake scipy / osgeo module objects, keyed by dotted name."""
    import numpy as np

    # Fake scipy
    scipy = types.ModuleType("scipy")
    scipy_ndimage = types.ModuleType("scipy.ndimage")
    scipy_interp = types.ModuleType("scipy.interpolate")
    scipy_spatial = types.ModuleType("scipy.spatial")

    class _FakeTree:
        def __init__(self, coords):
            self.coords = coords

        def query(self, pts):
            # Return index 0 for all queries
            n = len(pts)
            return np.zeros(n, dtype=int), np.zeros(n, dtype=int)

    def _griddata(coords, values, pts, method="linear", fill_value=0.0):
        # Trivial: nearest valid value
        if len(coords) == 0 or len(pts) == 0:
            return np.zeros(len(pts), dtype=np.float32)
        # Just return first value for all queries (testing only)
        return np.full(len(pts), values[0], dtype=np.float32)

    class _FakeRBF:
        def __init__(self, coords, values, kernel=None):
            self.first_value = values[0]

        def __call__(self, pts):
            return np.full(len(pts), self.first_value, dtype=np.float32)

    scipy_spatial.cKDTree = _FakeTree
    scipy_interp.griddata = _griddata
    scipy_interp.RBFInterpolator = _FakeRBF
    scipy.ndimage = scipy_ndimage
    scipy.interpolate = scipy_interp
    scipy.spatial = scipy_spatial

    # Fake osgeo.gdal with the constants the module needs
    osgeo = types.ModuleType("osgeo")
    gdal_mod = types.ModuleType("osgeo.gdal")
    gdal_mod.GA_ReadOnly = 0
    gdal_mod.GDT_Float32 = 6
    gdal_mod.UseExceptions = MagicMock()

    def _open(path, mode):
        return _make_mock_dataset([])  # empty dataset; real tests build their own

    gdal_mod.Open = _open
    gdal_mod.GetDriverByName = MagicMock(return_value=MagicMock())
    gdal_mod.SetConfigOption = MagicMock()
    osgeo.gdal = gdal_mod

    return {
        "scipy": scipy,
        "scipy.ndimage": scipy_ndimage,
        "scipy.interpolate": scipy_interp,
        "scipy.spatial": scipy_spatial,
        "osgeo": osgeo,
        "osgeo.gdal": gdal_mod,
    }


def _install_fake_modules():
    """Install minimal fakes for scipy + osgeo so the module under test imports.

    The fakes are built once and reused on every later call. Each test class
    calls this from ``setUpClass``; if every call created new module objects,
    a class that patches ``sys.modules["osgeo.gdal"]`` could end up patching
    a different object than the one the (already imported) module under test
    is holding, and the patch would have no effect.

    Returns:
        dict mapping ``"scipy"``, ``"scipy.spatial"``, ``"scipy.interpolate"``
        and ``"osgeo.gdal"`` to the installed fake modules.
    """
    if not _FAKE_MODULES:
        _FAKE_MODULES.update(_build_fake_modules())

    # (Re)register on every call in case something removed or replaced the
    # entries since the previous installation.
    sys.modules.update(_FAKE_MODULES)

    return {
        "scipy": _FAKE_MODULES["scipy"],
        "scipy.spatial": _FAKE_MODULES["scipy.spatial"],
        "scipy.interpolate": _FAKE_MODULES["scipy.interpolate"],
        "osgeo.gdal": _FAKE_MODULES["osgeo.gdal"],
    }


class TestProcessOneBlockMocked(unittest.TestCase):
    """Verify _process_one_block logic with mocked GDAL/scipy."""

    @classmethod
    def setUpClass(cls):
        _install_fake_modules()
        # Import the module under test after installing fakes
        from src.core.algorithms.interpolation import interpolator

        cls.interp = interpolator
        cls.gdal = sys.modules["osgeo.gdal"]

    def _build_dataset(self, bands):
        """Build a mock dataset where GetRasterBand(b).ReadAsArray(x,y,w,h)
        returns bands[b-1] (a numpy array of the requested shape).

        For simplicity we return the full band array regardless of x,y,w,h.
        """
        return _make_mock_dataset(bands)

    def test_no_zeros_returns_inner_blocks_unchanged(self):
        """If no pixels are all-zero, inner blocks should be returned as-is."""
        import numpy as np

        # 4x4 extended block (2x2 inner region + halo), 3 bands, no zeros
        # anywhere.
        band = np.ones((4, 4), dtype=np.float32)
        bands = [band, band * 2, band * 3]
        ds = self._build_dataset(bands)

        inner_bands, zero_count = self.interp._process_one_block(
            dataset=ds,
            x0=0,
            y0=0,
            ey0=0,
            ex0=0,
            ey1=4,
            ex1=4,
            row_offset=0,
            col_offset=0,
            inner_h=2,
            inner_w=2,
            mask_segment_ext=None,
            method="nearest",
        )

        self.assertEqual(zero_count, 0)
        self.assertEqual(len(inner_bands), 3)
        for ib, expected_band in zip(inner_bands, bands):
            self.assertEqual(ib.shape, (2, 2))
            np.testing.assert_array_equal(ib, expected_band[:2, :2])

    def test_with_zero_pixel_triggers_interpolation(self):
        """If a pixel is all-zero, _interpolate_single_band should be called."""
        import numpy as np

        # 4x4 image, 3 bands. Top-left 2x2 is zeros, rest is 1s.
        band1 = np.array(
            [
                [0, 0, 1, 1],
                [0, 0, 1, 1],
                [1, 1, 1, 1],
                [1, 1, 1, 1],
            ],
            dtype=np.float32,
        )
        band2 = band1 * 2
        band3 = band1 * 3
        ds = self._build_dataset([band1, band2, band3])

        # Process the top-left 2x2 block (with halo, 4x4 covers all)
        inner_bands, zero_count = self.interp._process_one_block(
            dataset=ds,
            x0=0,
            y0=0,
            ey0=0,
            ex0=0,
            ey1=4,
            ex1=4,
            row_offset=0,
            col_offset=0,
            inner_h=2,
            inner_w=2,
            mask_segment_ext=None,
            method="nearest",
        )

        self.assertGreater(zero_count, 0, "should detect zero pixels")
        self.assertEqual(len(inner_bands), 3)
        # Each inner band should be 2x2
        for ib in inner_bands:
            self.assertEqual(ib.shape, (2, 2))
        # Our fake nearest interpolation returns valid_values[0]=1 for band1
        # So the inner block should be filled with non-zero values
        self.assertTrue(np.all(inner_bands[0] > 0))


class TestWorkerFunctionMocked(unittest.TestCase):
    """Verify _interpolate_block_worker uses module-global _worker_dataset."""

    @classmethod
    def setUpClass(cls):
        _install_fake_modules()
        from src.core.algorithms.interpolation import interpolator

        cls.interp = interpolator

    def _build_dataset(self, bands):
        """Build a mock dataset whose 1-based band b reads back bands[b-1]."""
        return _make_mock_dataset(bands)

    def test_worker_uses_module_global_dataset(self):
        import numpy as np

        # Set the module global
        band = np.array(
            [
                [0, 0, 1, 1],
                [0, 0, 1, 1],
                [1, 1, 1, 1],
                [1, 1, 1, 1],
            ],
            dtype=np.float32,
        )
        ds = self._build_dataset([band, band * 2])
        self.interp._worker_dataset = ds
        try:
            task = (
                0, 0, 0, 0, 4, 4,  # x0, y0, ey0, ex0, ey1, ex1
                0, 0, 2, 2,  # row_offset, col_offset, inner_h, inner_w
                None, "nearest",  # mask_segment_ext, method
            )
            x0, y0, inner_bands, zero_count, error = (
                self.interp._interpolate_block_worker(task)
            )
            self.assertIsNone(error)
            self.assertEqual((x0, y0), (0, 0))
            self.assertGreater(zero_count, 0)
            self.assertIsNotNone(inner_bands)
            self.assertEqual(len(inner_bands), 2)
        finally:
            # Never leak the mock dataset into other tests.
            self.interp._worker_dataset = None

    def test_worker_returns_error_if_dataset_uninitialized(self):
        self.interp._worker_dataset = None
        task = (0, 0, 0, 0, 2, 2, 0, 0, 2, 2, None, "nearest")
        x0, y0, inner_bands, zero_count, error = (
            self.interp._interpolate_block_worker(task)
        )
        self.assertIsNotNone(error)
        self.assertIn("not initialized", error)
        self.assertIsNone(inner_bands)


class TestInitWorkerMocked(unittest.TestCase):
    """Verify _init_worker sets module global and config option."""

    @classmethod
    def setUpClass(cls):
        _install_fake_modules()
        from src.core.algorithms.interpolation import interpolator

        cls.interp = interpolator
        cls.gdal_mod = sys.modules["osgeo.gdal"]

    def setUp(self):
        self.interp._worker_dataset = None

    def tearDown(self):
        self.interp._worker_dataset = None

    def test_init_worker_opens_and_caches(self):
        fake_ds = MagicMock()
        with patch.object(
            self.gdal_mod, "Open", return_value=fake_ds
        ) as open_mock, patch.object(
            self.gdal_mod, "SetConfigOption"
        ) as cfg_mock:
            self.interp._init_worker("/fake/path.bsq")

        self.assertIs(self.interp._worker_dataset, fake_ds)
        # Second argument is the fake gdal.GA_ReadOnly constant (0).
        open_mock.assert_called_once_with("/fake/path.bsq", 0)
        # Should have set GDAL_NUM_THREADS=1
        cfg_mock.assert_any_call("GDAL_NUM_THREADS", "1")

    def test_init_worker_raises_if_open_fails(self):
        with patch.object(self.gdal_mod, "Open", return_value=None):
            with self.assertRaises(RuntimeError):
                self.interp._init_worker("/bad/path.bsq")


if __name__ == "__main__":
    unittest.main(verbosity=2)