Files
WQ_GUI/tests/test_interpolator_refactor.py

412 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

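"""
Behavior tests for the multiprocessing refactor of interpolator.py
(mock-based, no osgeo dependency).

Verifies:
1. Static structure: the set of module-level functions, their signatures,
   the new parameters, and the _worker_dataset module global.
2. Behavior: _process_one_block detects zero pixels and interpolates them
   on a mock dataset.
3. Behavior: _interpolate_block_worker works through the module global
   _worker_dataset.
4. Behavior: the serial path of interpolate_zero_pixels_batch (does not
   depend on osgeo).
5. Backward compatibility: the 8 existing caller parameters keep their
   default values unchanged.

If osgeo is available locally, a smoke test against a real dataset can be
added.
"""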
import os
import sys
import ast
import types
import unittest
from unittest.mock import MagicMock, patch
# Put the project root (the parent of this tests directory) on sys.path so
# the module under test can be imported by its package path.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.abspath(os.path.join(THIS_DIR, ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Source file of the interpolator; the static-structure tests read and
# parse this file instead of importing it.
INTERP_PATH = os.path.join(
    PROJECT_ROOT,
    "src",
    "core",
    "algorithms",
    "interpolation",
    "interpolator.py",
)
# =============================================================================
# Static structure tests
# =============================================================================
class TestInterpolatorStructure(unittest.TestCase):
    """Static checks on interpolator.py, done on its source text and AST.

    The module under test is read and parsed, never imported, so these
    tests need neither osgeo nor scipy.
    """

    def setUp(self):
        # Read and parse the interpolator source once per test.
        with open(INTERP_PATH, "r", encoding="utf-8") as f:
            self.src = f.read()
        self.tree = ast.parse(self.src)

    def _module_function(self, name):
        """Return the module-level ``def`` node named *name*.

        Fails the current test with "<name> not found" when the parsed
        module has no such top-level function.
        """
        for node in self.tree.body:
            if isinstance(node, ast.FunctionDef) and node.name == name:
                return node
        self.fail("{} not found".format(name))

    def _positional_arg_names(self, name):
        """Return the positional parameter names of module-level function *name*."""
        return [a.arg for a in self._module_function(name).args.args]

    def test_module_level_functions(self):
        """The module defines exactly the expected top-level functions."""
        mod_funcs = {n.name for n in self.tree.body if isinstance(n, ast.FunctionDef)}
        expected = {
            "interpolate_pixels",
            "_interpolate_single_band",
            "_normalize_interpolation_method",
            "_read_water_mask_to_array",
            "_init_worker",
            "_interpolate_block_worker",
            "_process_one_block",
            "interpolate_zero_pixels_batch",
        }
        self.assertEqual(mod_funcs, expected)
        self.assertNotIn("_process_block_with_buffer", mod_funcs)

    def test_worker_dataset_module_global(self):
        """``_worker_dataset`` is assigned (plain or annotated) at module level."""
        mod_globals = set()
        for node in self.tree.body:
            if isinstance(node, ast.Assign):
                # Only the first target is inspected, as in ``a = b = ...``.
                target = node.targets[0]
            elif isinstance(node, ast.AnnAssign):
                target = node.target
            else:
                continue
            if isinstance(target, ast.Name):
                mod_globals.add(target.id)
        self.assertIn("_worker_dataset", mod_globals)

    def test_interpolate_zero_pixels_batch_signature(self):
        """Parameter names and literal defaults of the public entry point."""
        func = self._module_function("interpolate_zero_pixels_batch")
        # Backward compat: all 8 existing params + 2 new
        self.assertEqual(
            [a.arg for a in func.args.args],
            [
                "img_path", "interpolation_method", "output_path",
                "water_mask", "deglint_dir", "callback_progress",
                "block_size", "halo_size", "n_workers", "use_multiprocessing",
            ],
        )
        # Default nodes without a ``value`` attribute are reported as None.
        defaults = [getattr(d, "value", None) for d in func.args.defaults]
        self.assertEqual(
            defaults,
            ["nearest", None, None, None, None, 1024, 64, None, True],
        )

    def test_init_worker_signature(self):
        """``_init_worker`` takes only ``img_path``."""
        self.assertEqual(self._positional_arg_names("_init_worker"), ["img_path"])

    def test_worker_function_signature(self):
        """``_interpolate_block_worker`` takes a single ``task`` argument."""
        self.assertEqual(
            self._positional_arg_names("_interpolate_block_worker"), ["task"]
        )

    def test_process_one_block_signature(self):
        """``_process_one_block`` keeps its full positional parameter list."""
        self.assertEqual(
            self._positional_arg_names("_process_one_block"),
            [
                "dataset", "x0", "y0",
                "ey0", "ex0", "ey1", "ex1",
                "row_offset", "col_offset",
                "inner_h", "inner_w",
                "mask_segment_ext", "method",
            ],
        )

    def test_uses_process_pool_executor(self):
        """The source wires a ProcessPoolExecutor with the worker initializer."""
        self.assertIn("ProcessPoolExecutor", self.src)
        self.assertIn("ProcessPoolExecutor(", self.src)
        self.assertIn("initializer=_init_worker", self.src)
        self.assertIn("initargs=(img_path,)", self.src)
        self.assertIn("GDAL_NUM_THREADS", self.src)

    def test_dispatch_logic_present(self):
        """Both the serial and the parallel dispatch branches are present."""
        # Both serial and parallel paths should be present
        self.assertIn("if effective_workers <= 1:", self.src)
        self.assertIn("with ProcessPoolExecutor", self.src)

    def test_serial_path_uses_process_one_block(self):
        """The serial branch calls ``_process_one_block(dataset, *task)`` directly."""
        # Serial branch should call _process_one_block directly (no pickle
        # overhead). Match the call with any whitespace/line breaks and an
        # optional trailing comma so re-indenting or re-formatting the
        # interpolator source does not break this check.
        self.assertRegex(
            self.src,
            r"_process_one_block\(\s*dataset,\s*\*task\s*,?\s*\)",
        )

    def test_worker_path_uses_process_one_block(self):
        """The worker calls ``_process_one_block`` and reads ``_worker_dataset``."""
        worker = self._module_function("_interpolate_block_worker")
        # Inspect only the worker's own body: a whole-file substring search
        # would already be satisfied by the ``def _process_one_block(`` line
        # and by the module-level ``_worker_dataset`` assignment.
        called = set()
        referenced = set()
        for node in ast.walk(worker):
            if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                called.add(node.func.id)
            if isinstance(node, ast.Name):
                referenced.add(node.id)
            elif isinstance(node, ast.Global):
                referenced.update(node.names)
        # Worker function should also call _process_one_block (shared core)
        self.assertIn("_process_one_block", called)
        # The worker should read from _worker_dataset, not receive dataset in task
        self.assertIn("_worker_dataset", referenced)
# =============================================================================
# Mocked behavior tests (no real osgeo/scipy needed)
# =============================================================================
def _make_mock_band(read_data):
    """Return a mock raster band whose ``ReadAsArray`` always yields *read_data*."""
    # Dotted keyword configures the child mock's return value at construction.
    return MagicMock(**{"ReadAsArray.return_value": read_data})
def _make_mock_dataset(bands_data, n_bands=None):
    """Return a mock dataset backed by the per-band payloads in *bands_data*.

    ``RasterCount`` is *n_bands* when given, otherwise ``len(bands_data)``.
    ``GetRasterBand(b)`` builds a fresh mock band for the 1-based index *b*.
    """
    def _band_for(b):
        # Band numbers are 1-based; the payload list is 0-based.
        return _make_mock_band(bands_data[b - 1])

    dataset = MagicMock()
    dataset.RasterCount = len(bands_data) if n_bands is None else n_bands
    dataset.GetRasterBand.side_effect = _band_for
    return dataset
def _make_fake_module(name, attrs=None):
    """Create an empty module object called *name* and set *attrs* on it.

    *attrs* is an optional mapping of attribute name to value; ``None`` or
    an empty mapping leaves the module without extra attributes.
    """
    fake = types.ModuleType(name)
    if attrs:
        for attr_name, attr_value in attrs.items():
            setattr(fake, attr_name, attr_value)
    return fake
def _install_fake_modules():
    """Install minimal fakes for scipy + osgeo so the module under test imports.

    The fakes are registered in ``sys.modules`` under ``scipy``,
    ``scipy.ndimage``, ``scipy.interpolate``, ``scipy.spatial``, ``osgeo``
    and ``osgeo.gdal``, replacing whatever was registered under those names.

    The call is idempotent: when all six entries are already fakes created
    by this function, they are reused instead of rebuilt. Rebuilding them
    would leave any module that imported the earlier fakes holding objects
    different from the ones in ``sys.modules``, so patching via
    ``sys.modules["osgeo.gdal"]`` would no longer reach that module.

    Returns:
        dict with keys ``"scipy"``, ``"scipy.spatial"``,
        ``"scipy.interpolate"`` and ``"osgeo.gdal"`` mapping to the
        installed fake module objects.
    """
    import numpy as np

    marker = "_is_interpolator_test_fake"
    fake_names = (
        "scipy", "scipy.ndimage", "scipy.interpolate", "scipy.spatial",
        "osgeo", "osgeo.gdal",
    )
    returned_names = ("scipy", "scipy.spatial", "scipy.interpolate", "osgeo.gdal")

    # Reuse a previous installation so module identity stays stable.
    current = {name: sys.modules.get(name) for name in fake_names}
    if all(getattr(mod, marker, False) is True for mod in current.values()):
        return {name: current[name] for name in returned_names}

    # Fake scipy
    scipy = types.ModuleType("scipy")
    scipy_ndimage = types.ModuleType("scipy.ndimage")
    scipy_interp = types.ModuleType("scipy.interpolate")
    scipy_spatial = types.ModuleType("scipy.spatial")

    class _FakeTree:
        """Stand-in for cKDTree: every query resolves to index 0."""

        def __init__(self, coords):
            self.coords = coords

        def query(self, pts):
            # Return index 0 for all queries
            n = len(pts)
            return np.zeros(n, dtype=int), np.zeros(n, dtype=int)

    def _griddata(coords, values, pts, method="linear", fill_value=0.0):
        # Trivial stand-in: zeros when there is nothing to interpolate from
        # or to, otherwise the first known value for every query point.
        if len(coords) == 0 or len(pts) == 0:
            return np.zeros(len(pts), dtype=np.float32)
        # Just return first value for all queries (testing only)
        return np.full(len(pts), values[0], dtype=np.float32)

    class _FakeRBF:
        """Stand-in for RBFInterpolator: always predicts the first value."""

        def __init__(self, coords, values, kernel=None):
            self.first_value = values[0]

        def __call__(self, pts):
            return np.full(len(pts), self.first_value, dtype=np.float32)

    scipy_spatial.cKDTree = _FakeTree
    scipy_interp.griddata = _griddata
    scipy_interp.RBFInterpolator = _FakeRBF
    scipy.ndimage = scipy_ndimage
    scipy.interpolate = scipy_interp
    scipy.spatial = scipy_spatial

    # Fake osgeo.gdal with the constants the module needs
    osgeo = types.ModuleType("osgeo")
    gdal_mod = types.ModuleType("osgeo.gdal")
    gdal_mod.GA_ReadOnly = 0
    gdal_mod.GDT_Float32 = 6
    gdal_mod.UseExceptions = MagicMock()

    def _open(path, mode):
        return _make_mock_dataset([])  # empty dataset; real tests build their own

    gdal_mod.Open = _open
    gdal_mod.GetDriverByName = MagicMock(return_value=MagicMock())
    gdal_mod.SetConfigOption = MagicMock()
    osgeo.gdal = gdal_mod

    fakes = {
        "scipy": scipy,
        "scipy.ndimage": scipy_ndimage,
        "scipy.interpolate": scipy_interp,
        "scipy.spatial": scipy_spatial,
        "osgeo": osgeo,
        "osgeo.gdal": gdal_mod,
    }
    for name, mod in fakes.items():
        # Tag each fake so a later call can recognise and reuse it.
        setattr(mod, marker, True)
        sys.modules[name] = mod
    return {name: fakes[name] for name in returned_names}
class TestProcessOneBlockMocked(unittest.TestCase):
    """Exercise ``_process_one_block`` against mocked GDAL/scipy."""

    @classmethod
    def setUpClass(cls):
        # The fakes have to be registered before the interpolator is imported.
        _install_fake_modules()
        from src.core.algorithms.interpolation import interpolator
        cls.interp = interpolator
        cls.gdal = sys.modules["osgeo.gdal"]

    def _build_dataset(self, bands):
        """Return a mock dataset serving the arrays in *bands*.

        ``GetRasterBand(b).ReadAsArray(...)`` returns ``bands[b - 1]`` in
        full, whatever window arguments are passed.
        """
        dataset = MagicMock()
        dataset.RasterCount = len(bands)
        dataset.GetRasterBand.side_effect = lambda b: MagicMock(
            **{"ReadAsArray.return_value": bands[b - 1]}
        )
        return dataset

    def test_no_zeros_returns_inner_blocks_unchanged(self):
        """With no all-zero pixels the inner blocks come back untouched."""
        import numpy as np

        # 4x4 arrays of non-zero values in 3 bands; the call below asks for
        # a 2x2 inner block out of that window.
        ones = np.ones((4, 4), dtype=np.float32)
        bands = [ones, ones * 2, ones * 3]
        dataset = self._build_dataset(bands)

        block_args = dict(
            x0=0, y0=0,
            ey0=0, ex0=0, ey1=4, ex1=4,
            row_offset=0, col_offset=0,
            inner_h=2, inner_w=2,
            mask_segment_ext=None,
            method="nearest",
        )
        inner_bands, zero_count = self.interp._process_one_block(
            dataset=dataset, **block_args
        )

        self.assertEqual(zero_count, 0)
        self.assertEqual(len(inner_bands), 3)
        for actual, source_band in zip(inner_bands, bands):
            self.assertEqual(actual.shape, (2, 2))
            np.testing.assert_array_equal(actual, source_band[:2, :2])

    def test_with_zero_pixel_triggers_interpolation(self):
        """All-zero pixels are counted and filled by the interpolation step."""
        import numpy as np

        # 4x4 window, 3 bands. The top-left 2x2 corner is zero in every
        # band; everything else is non-zero.
        base = np.array(
            [
                [0, 0, 1, 1],
                [0, 0, 1, 1],
                [1, 1, 1, 1],
                [1, 1, 1, 1],
            ],
            dtype=np.float32,
        )
        dataset = self._build_dataset([base, base * 2, base * 3])

        # Process the top-left 2x2 block; the 4x4 window covers everything.
        block_args = dict(
            x0=0, y0=0,
            ey0=0, ex0=0, ey1=4, ex1=4,
            row_offset=0, col_offset=0,
            inner_h=2, inner_w=2,
            mask_segment_ext=None,
            method="nearest",
        )
        inner_bands, zero_count = self.interp._process_one_block(
            dataset=dataset, **block_args
        )

        self.assertGreater(zero_count, 0, "should detect zero pixels")
        self.assertEqual(len(inner_bands), 3)
        # Each inner band should be 2x2
        for actual in inner_bands:
            self.assertEqual(actual.shape, (2, 2))
        # The fake nearest-neighbour lookup is expected to hand back a valid
        # (non-zero) value, so the first inner block must end up non-zero.
        self.assertTrue(np.all(inner_bands[0] > 0))
class TestWorkerFunctionMocked(unittest.TestCase):
    """Check that ``_interpolate_block_worker`` works off the module global dataset."""

    @classmethod
    def setUpClass(cls):
        # The fakes have to be registered before the interpolator is imported.
        _install_fake_modules()
        from src.core.algorithms.interpolation import interpolator
        cls.interp = interpolator

    def _build_dataset(self, bands):
        """Return a mock dataset whose band ``b`` reads back ``bands[b - 1]``."""
        dataset = MagicMock()
        dataset.RasterCount = len(bands)
        dataset.GetRasterBand.side_effect = lambda b: MagicMock(
            **{"ReadAsArray.return_value": bands[b - 1]}
        )
        return dataset

    def test_worker_uses_module_global_dataset(self):
        """With ``_worker_dataset`` set, the worker returns filled inner bands."""
        import numpy as np

        base = np.array(
            [
                [0, 0, 1, 1],
                [0, 0, 1, 1],
                [1, 1, 1, 1],
                [1, 1, 1, 1],
            ],
            dtype=np.float32,
        )
        # Set the module global the worker is expected to read.
        self.interp._worker_dataset = self._build_dataset([base, base * 2])
        try:
            window = (0, 0, 0, 0, 4, 4)        # x0, y0, ey0, ex0, ey1, ex1
            inner = (0, 0, 2, 2)               # row_offset, col_offset, inner_h, inner_w
            options = (None, "nearest")        # mask_segment_ext, method
            task = window + inner + options

            result = self.interp._interpolate_block_worker(task)
            x0, y0, inner_bands, zero_count, error = result

            self.assertIsNone(error)
            self.assertEqual((x0, y0), (0, 0))
            self.assertGreater(zero_count, 0)
            self.assertIsNotNone(inner_bands)
            self.assertEqual(len(inner_bands), 2)
        finally:
            # Always clear the global so later tests start uninitialized.
            self.interp._worker_dataset = None

    def test_worker_returns_error_if_dataset_uninitialized(self):
        """Without a dataset the worker reports an error instead of bands."""
        self.interp._worker_dataset = None
        task = (0, 0, 0, 0, 2, 2, 0, 0, 2, 2, None, "nearest")

        result = self.interp._interpolate_block_worker(task)
        x0, y0, inner_bands, zero_count, error = result

        self.assertIsNotNone(error)
        self.assertIn("not initialized", error)
        self.assertIsNone(inner_bands)
class TestInitWorkerMocked(unittest.TestCase):
    """Verify _init_worker sets module global and config option."""

    @classmethod
    def setUpClass(cls):
        _install_fake_modules()
        from src.core.algorithms.interpolation import interpolator
        cls.interp = interpolator
        # Patch the gdal object the interpolator itself holds when it exposes
        # one. If the interpolator was imported by an earlier test class, its
        # ``gdal`` may be a different object from the entry currently in
        # sys.modules, and patching the latter would not affect it.
        # NOTE(review): assumes the interpolator binds the name ``gdal`` at
        # module level; otherwise this falls back to sys.modules as before.
        cls.gdal_mod = getattr(interpolator, "gdal", sys.modules["osgeo.gdal"])

    def setUp(self):
        # Start every test with no cached worker dataset.
        self.interp._worker_dataset = None

    def tearDown(self):
        # Do not leak a cached dataset into other tests.
        self.interp._worker_dataset = None

    def test_init_worker_opens_and_caches(self):
        """A successful open is cached in ``_worker_dataset``."""
        fake_ds = MagicMock()
        with patch.object(self.gdal_mod, "Open", return_value=fake_ds) as open_mock:
            with patch.object(self.gdal_mod, "SetConfigOption") as cfg_mock:
                self.interp._init_worker("/fake/path.bsq")
        self.assertIs(self.interp._worker_dataset, fake_ds)
        # 0 is the value the fake gdal module assigns to GA_ReadOnly.
        open_mock.assert_called_once_with("/fake/path.bsq", 0)
        # Should have set GDAL_NUM_THREADS=1
        cfg_mock.assert_any_call("GDAL_NUM_THREADS", "1")

    def test_init_worker_raises_if_open_fails(self):
        """An ``Open`` that returns None makes ``_init_worker`` raise RuntimeError."""
        with patch.object(self.gdal_mod, "Open", return_value=None):
            with self.assertRaises(RuntimeError):
                self.interp._init_worker("/bad/path.bsq")
# Allow running this file directly; verbosity=2 lists each test by name.
if __name__ == "__main__":
    unittest.main(verbosity=2)