From 027981e9a6ca74379e2288a2f7759078af7dc569 Mon Sep 17 00:00:00 2001 From: DXC Date: Tue, 16 Jun 2026 15:15:10 +0800 Subject: [PATCH] =?UTF-8?q?ContentMapper=20=E8=BE=B9=E7=95=8C=E8=AF=BB?= =?UTF-8?q?=E5=8F=96=E6=94=AF=E6=8C=81=E6=A0=85=E6=A0=BC=E6=B0=B4=E6=8E=A9?= =?UTF-8?q?=E8=86=9C=EF=BC=88.dat/.bsq/.tif/.tiff/.img=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/postprocessing/map.py | 78 ++++++++++++++++++++++++++--- tests/smoke_step10_path_override.py | 42 ++++++++++++++++ 2 files changed, 113 insertions(+), 7 deletions(-) diff --git a/src/postprocessing/map.py b/src/postprocessing/map.py index b542c8c..458b12c 100644 --- a/src/postprocessing/map.py +++ b/src/postprocessing/map.py @@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist from scipy.spatial import ConvexHull from shapely.geometry import Point, Polygon import rasterio -from rasterio.features import geometry_mask +from rasterio.features import geometry_mask, shapes from rasterio import windows from rasterio.warp import calculate_default_transform, reproject, Resampling try: @@ -785,18 +785,82 @@ class ContentMapper: return gdf def read_boundary_shapefile(self, shp_file): - """读取边界shapefile""" - print("正在读取边界文件...") - boundary = gpd.read_file(shp_file) + """读取边界/掩膜文件(同时支持矢量 .shp 与栅格 .dat/.bsq/.tif/.tiff)。 + + - .shp → gpd.read_file 读取矢量边界(保持原行为) + - .dat/.bsq/.tif/.tiff/.img → rasterio 读取栅格水掩膜 → rasterio.features.shapes + 矢量化成水体多边形 → gpd.GeoDataFrame 返回 + + 下游 create_interpolation_grid / create_content_map / visualize_raster + 始终接收 GeoDataFrame,无需任何改动。 + """ + print("正在读取边界/掩膜文件...") + suffix = Path(shp_file).suffix.lower() + if suffix in (".shp",): + boundary = gpd.read_file(shp_file) + elif suffix in (".dat", ".bsq", ".tif", ".tiff", ".img"): + boundary = self._raster_to_boundary_gdf(shp_file) + else: + raise ValueError( + f"不支持的边界/掩膜文件格式: {suffix}(仅支持 .shp / .dat / .bsq / .tif / .img)" + ) + + if len(boundary) == 0: + raise ValueError( + f"边界/掩膜 {shp_file} 矢量化后为空(栅格格式请确认 .dat 包含水体像元 > 0)" + ) # 确保边界文件使用目标投影坐标系 - if boundary.crs != self.output_crs: - print(f"正在转换边界文件坐标系到 {self.output_crs}...") + if boundary.crs is not None and boundary.crs != self.output_crs: + print(f"正在转换边界/掩膜坐标系到 {self.output_crs}...") boundary = boundary.to_crs(self.output_crs) - print(f"边界文件包含 {len(boundary)} 个要素") + print(f"边界/掩膜文件包含 {len(boundary)} 个要素") return boundary + def _raster_to_boundary_gdf(self, raster_path): + """把栅格二值水掩膜(.dat/.bsq/.tif/.tiff)矢量化成水体多边形 GeoDataFrame。 + + 修复 Step 11 接收 step1 产出 .dat 水掩膜的兼容性: + - rasterio.open 读 band 1(0=非水, 1/任意正数=水) + - rasterio.features.shapes 矢量化成多边形 + - 收集所有 val=1 的多边形 → gpd.GeoDataFrame + """ + try: + from shapely.geometry import shape as _shapely_shape + except ImportError as e: + raise ImportError( + "栅格掩膜矢量化需要 shapely(geopandas 自带)。原始错误: " + str(e) + ) + + with rasterio.open(raster_path) as src: + data = src.read(1) + transform = src.transform + crs = src.crs + + # 二值化:>0 视为水 + mask_uint8 = (data > 0).astype(np.uint8) + if int(mask_uint8.sum()) == 0: + raise ValueError( + f"栅格掩膜 {raster_path} 中无水体像元(>0),无法矢量化" + ) + + # 矢量化:shapes 返回 (geom_dict, value) 迭代器 + geoms = [] + for geom_dict, val in shapes(mask_uint8, mask=mask_uint8.astype(bool), transform=transform): + if int(val) == 1: + geoms.append(_shapely_shape(geom_dict)) + + if not geoms: + raise ValueError(f"栅格掩膜 {raster_path} 矢量化后无有效水体多边形") + + gdf = gpd.GeoDataFrame(geometry=geoms, crs=crs) + print( + f"栅格掩膜 {Path(raster_path).name} 矢量化完成: " + f"{int(mask_uint8.sum())} 个水体像元 → {len(gdf)} 个多边形 (CRS={crs})" + ) + return gdf + def _identify_edge_points(self, points_gdf): """ 识别边缘采样点(使用凸包方法) diff --git a/tests/smoke_step10_path_override.py b/tests/smoke_step10_path_override.py index 339bf51..05190b0 100644 --- a/tests/smoke_step10_path_override.py +++ b/tests/smoke_step10_path_override.py @@ -18,6 +18,7 @@ ROOT = Path(__file__).resolve().parents[1] PIPELINE_FILE = ROOT / "src" / "core" / "water_quality_inversion_pipeline_GUI.py" PANEL_FILE = ROOT / "src" / "gui" / "panels" / "step11_map_panel.py" RESOLVER_FILE = ROOT / "src" / "gui" / "panels" / "_step_path_resolver.py" +MAP_FILE = ROOT / "src" / "postprocessing" / "map.py" def test_step10_map_forced_override(): @@ -101,6 +102,46 @@ def test_panel_guard_does_not_overwrite_existing(): print("✅ step11 panel 4.5 段遵守 '非空值不覆盖' 原则") +def test_map_supports_raster_mask_formats(): + """验证 ContentMapper.read_boundary_shapefile 内部已支持栅格格式分发(.dat/.bsq/.tif/.tiff/.img)。 + + 之前 bug:4.5 段成功把 .dat 填入 boundary_file,但 ContentMapper 内部 + gpd.read_file(.dat) 直接报错。修复后 ContentMapper 内部按后缀分发: + - .shp → 矢量(保持原行为) + - .dat/.bsq/.tif/.tiff/.img → rasterio.features.shapes 矢量化成 GeoDataFrame + """ + text = MAP_FILE.read_text(encoding="utf-8") + # 找 def read_boundary_shapefile + m = re.search( + r"def read_boundary_shapefile\(self,[^\)]*\)[^\n]*:\n(.*?)(?=\n def |\nclass |\Z)", + text, re.DOTALL, + ) + assert m, "找不到 read_boundary_shapefile 方法" + body = m.group(1) + + # 关键标记:format dispatch + assert ".dat" in body, "read_boundary_shapefile 应支持 .dat 栅格" + assert ".bsq" in body, "read_boundary_shapefile 应支持 .bsq 栅格" + assert ".tif" in body, "read_boundary_shapefile 应支持 .tif 栅格" + assert "gpd.read_file(shp_file)" in body, ".shp 分支应保留 gpd.read_file 矢量读取" + assert "rasterio.features.shapes" in body or "from rasterio.features import" in text, \ + "栅格分支应使用 rasterio.features.shapes 矢量化" + # 验证 helper 方法存在 + assert "def _raster_to_boundary_gdf" in text, \ + "应新增 _raster_to_boundary_gdf helper 方法" + # 验证 helper 内有栅格读取 + 矢量化 + GeoDataFrame 返回 + helper_m = re.search( + r"def _raster_to_boundary_gdf\(self,[^\)]*\)[^\n]*:\n(.*?)(?=\n def |\nclass |\Z)", + text, re.DOTALL, + ) + assert helper_m, "找不到 _raster_to_boundary_gdf 方法" + helper = helper_m.group(1) + assert "rasterio.open" in helper, "helper 应调用 rasterio.open 读栅格" + assert "shapes(" in helper, "helper 应调用 rasterio.features.shapes 矢量化" + assert "gpd.GeoDataFrame" in helper, "helper 应返回 gpd.GeoDataFrame" + print("✅ ContentMapper.read_boundary_shapefile 支持 .shp/.dat/.bsq/.tif 多格式分发") + + if __name__ == "__main__": print("=" * 60) print("Smoke test: 彻底修复底层写入路径与掩膜联动") @@ -110,6 +151,7 @@ if __name__ == "__main__": test_step11_panel_calls_pipeline_get_step_output_dir() test_fallback_dir_table_water_mask() test_panel_guard_does_not_overwrite_existing() + test_map_supports_raster_mask_formats() print("=" * 60) print("全部通过 ✅") sys.exit(0)