#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Sampling-point map generator: mark sampling points on a hyperspectral
false-colour image.

Features:
1. Read a hyperspectral raster and build a false-colour RGB composite.
2. Read sampling-point coordinates from a CSV file (the first two columns are
   latitude and longitude, WGS84 decimal degrees).
3. Mark the sampling points on the image.
4. Add a north arrow, a legend and a scale bar.
5. Reproject WGS84 coordinates into the raster's coordinate system (GDAL/OSR).
"""

import argparse
import math
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import FancyArrowPatch

# Rendering performance settings (process-wide matplotlib configuration).
plt.rcParams['agg.path.chunksize'] = 10000  # split long paths into chunks for the Agg renderer
plt.rcParams['path.simplify'] = True
plt.rcParams['path.simplify_threshold'] = 0.1

# GDAL is required for raster I/O and coordinate transformation.
try:
    from osgeo import gdal, osr
    GDAL_AVAILABLE = True
except ImportError:
    GDAL_AVAILABLE = False
    print("警告: GDAL未安装,地理坐标转换功能可能无法正常工作")


class SamplingPointMap:
    """Generate maps that mark sampling points on a hyperspectral false-colour image."""

    def __init__(self, output_dir: str = "./point_maps", fast_mode: bool = False):
        """
        Initialise the map generator.

        Args:
            output_dir: Directory where generated maps are saved (created if missing).
            fast_mode: Trade quality for speed (lower DPI, nearest-neighbour display).

        Note:
            This mutates the global matplotlib ``rcParams`` and the process-wide
            ``warnings`` filters, so it also affects other code running in the
            same process.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.fast_mode = fast_mode

        # Font fallback list that can render the Chinese title/legend text.
        plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
        plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly with CJK fonts
        plt.rcParams['font.size'] = 12

        if fast_mode:
            plt.rcParams['figure.dpi'] = 150
            plt.rcParams['savefig.dpi'] = 150
            warnings.filterwarnings('ignore', category=UserWarning)
        else:
            plt.rcParams['figure.dpi'] = 300
            plt.rcParams['savefig.dpi'] = 300
            # NOTE(review): silences *all* warnings process-wide.
            warnings.filterwarnings('ignore')

    def create_sampling_point_map(self, hyperspectral_path: str, csv_path: str,
                                  output_filename: Optional[str] = None,
                                  rgb_bands: Optional[List[int]] = None,
                                  point_color: str = 'red',
                                  point_size: int = 80,
                                  point_alpha: float = 0.8,
                                  show_north_arrow: bool = True,
                                  show_scale_bar: bool = True,
                                  show_legend: bool = True,
                                  dpi: Optional[int] = None,
                                  downsample: bool = False) -> str:
        """
        Create a sampling-point map on top of a hyperspectral false-colour image.

        Args:
            hyperspectral_path: Hyperspectral raster path (.dat, .bsq, .tif, ...).
            csv_path: Sampling-point CSV (first two columns: latitude, longitude).
            output_filename: Output file name, relative to ``output_dir``; may
                include sub-directories. Auto-generated when None.
            rgb_bands: 0-based band indices [R, G, B]; auto-selected when None.
            point_color: Marker colour.
            point_size: Marker size (matplotlib ``scatter`` ``s``).
            point_alpha: Marker opacity.
            show_north_arrow: Draw a north arrow.
            show_scale_bar: Draw a scale bar.
            show_legend: Draw a legend with the point count.
            dpi: Output resolution; None uses 150 (fast mode) or 200.
            downsample: Accepted for compatibility; downsampling is currently
                disabled and the raster is always read at full resolution.

        Returns:
            Path of the generated map file.

        Raises:
            ImportError: if GDAL is not installed.
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装,无法处理地理坐标转换")

        print(f"正在生成采样点地图...{' (快速模式)' if self.fast_mode else ''}")

        # Only the three RGB bands are read from the raster.
        hyperspectral_img, geotransform, projection, width, height, sample_factor = \
            self._read_hyperspectral(hyperspectral_path, rgb_bands, downsample)

        sampling_points = self._read_sampling_points(csv_path)
        rgb_image = self._create_false_color_image(hyperspectral_img)
        pixel_coords = self._geo_to_pixel(sampling_points, geotransform, width, height,
                                          projection, sample_factor)

        if output_filename is None:
            output_filename = (f"{Path(hyperspectral_path).stem}_"
                               f"{Path(csv_path).stem}_sampling_map.png")
        output_path = self.output_dir / output_filename
        # output_filename may carry a sub-directory (batch_create_maps uses one).
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if dpi is None:
            dpi = 150 if self.fast_mode else 200

        self._create_map_visualization(
            rgb_image, pixel_coords, sampling_points, str(output_path),
            point_color, point_size, point_alpha,
            show_north_arrow, show_scale_bar, show_legend, dpi,
            geotransform, width, height, downsample, projection, sample_factor
        )

        print(f"采样点地图已保存: {output_path}")
        return str(output_path)

    @staticmethod
    def _default_rgb_bands(hyperspectral_path: str, band_count: int) -> List[int]:
        """Choose 0-based [R, G, B] band indices, preferring ~650/550/460 nm."""
        if band_count < 3:
            return [0, 0, 0]
        try:
            # Project helper that maps a wavelength to a band index.
            from src.utils.util import find_band_number
            bands = [
                find_band_number(650.0, hyperspectral_path),  # red
                find_band_number(550.0, hyperspectral_path),  # green
                find_band_number(460.0, hyperspectral_path),  # blue
            ]
            print(f" 根据波长选择RGB波段: R={bands[0]}, G={bands[1]}, B={bands[2]}")
            return bands
        except Exception as e:
            print(f" 波长查找失败 ({e}),使用默认索引")
            # Fallback: fixed fractions of the band count.
            return [min(band_count - 1, int(band_count * 0.25)),
                    min(band_count - 1, int(band_count * 0.15)),
                    min(band_count - 1, int(band_count * 0.05))]

    def _read_hyperspectral(self, hyperspectral_path: str,
                            rgb_bands: Optional[List[int]] = None,
                            downsample: bool = False
                            ) -> Tuple[np.ndarray, tuple, str, int, int, int]:
        """
        Read only the bands needed for the RGB composite.

        Args:
            hyperspectral_path: Any raster GDAL can open.
            rgb_bands: 0-based band indices; chosen by ``_default_rgb_bands``
                when None. Three indices give an RGB stack; any other count
                replicates the first band into all three channels.
            downsample: Only triggers an informational message; downsampled
                reads are disabled, so ``sample_factor`` is always 1.

        Returns:
            ``(image, geotransform, projection_wkt, width, height, sample_factor)``
            where ``image`` is float32 with shape (height, width, 3). Pixels equal
            to a band's NoData value are set to NaN so they do not skew the stretch.

        Raises:
            ValueError: if the raster cannot be opened or a band index is invalid.
        """
        dataset = gdal.Open(hyperspectral_path)
        if dataset is None:
            raise ValueError(f"无法打开高光谱影像: {hyperspectral_path}")

        try:
            width = dataset.RasterXSize
            height = dataset.RasterYSize
            band_count = dataset.RasterCount

            if rgb_bands is None:
                rgb_bands = self._default_rgb_bands(hyperspectral_path, band_count)

            band_indices = [int(b) for b in rgb_bands]
            if not band_indices:
                raise ValueError("rgb_bands不能为空")
            invalid = [b for b in band_indices if not 0 <= b < band_count]
            if invalid:
                raise ValueError(f"波段索引超出范围 (0-{band_count - 1}): {invalid}")

            # Downsampled reads were reported to return all-zero pixels, so they
            # stay disabled; the flag only produces a notice for large rasters.
            if downsample and (width > 2000 or height > 2000):
                print(f" ⚠ 下采样暂被禁用(会导致像素值全0),使用原始分辨率: {width}x{height}")
            sample_factor = 1

            channels = []
            for band_idx in band_indices:
                band = dataset.GetRasterBand(band_idx + 1)  # GDAL band numbers are 1-based
                raw = band.ReadAsArray()
                data = raw.astype(np.float32)
                nodata = band.GetNoDataValue()
                if nodata is not None:
                    # Compare in the native dtype before the float32 cast.
                    data[raw == nodata] = np.nan
                channels.append(data)

            if len(channels) == 3:
                image_array = np.stack(channels, axis=2)
            else:
                image_array = np.stack([channels[0]] * 3, axis=2)

            geotransform = dataset.GetGeoTransform()
            projection = dataset.GetProjection()
        finally:
            dataset = None  # release the GDAL dataset handle, even on error

        print(f" 读取影像: {width}x{height}x{image_array.shape[2]} (RGB)")
        if projection:
            proj_type = "投影坐标系" if "PROJCS" in projection else "地理坐标系"
            print(f" 影像投影: {proj_type}")

        return image_array, geotransform, projection, width, height, sample_factor

    def _read_sampling_points(self, csv_path: str) -> pd.DataFrame:
        """
        Read sampling points from a CSV file.

        The first column is taken as latitude and the second as longitude; they
        are renamed to ``latitude``/``longitude`` and coerced to numbers. Rows
        whose coordinates are missing or non-numeric are dropped; other columns
        are kept unchanged.

        Raises:
            FileNotFoundError: if ``csv_path`` does not exist.
            ValueError: if the CSV has fewer than two columns.
        """
        if not Path(csv_path).exists():
            raise FileNotFoundError(f"CSV文件不存在: {csv_path}")

        df = pd.read_csv(csv_path)
        if len(df.columns) < 2:
            raise ValueError("CSV文件至少需要两列(纬度、经度)")

        lat_col, lon_col = df.columns[0], df.columns[1]
        df = df.rename(columns={lat_col: 'latitude', lon_col: 'longitude'})
        df['latitude'] = pd.to_numeric(df['latitude'], errors='coerce')
        df['longitude'] = pd.to_numeric(df['longitude'], errors='coerce')
        df = df.dropna(subset=['latitude', 'longitude'])

        # |latitude| > 90 is impossible and usually means the columns are (lon, lat).
        if (df['latitude'].abs() > 90).any():
            print(" ⚠ 部分纬度值超出[-90, 90],请确认CSV前两列顺序为(纬度, 经度)")

        print(f"读取到 {len(df)} 个采样点")
        return df

    def _create_false_color_image(self, image_array: np.ndarray,
                                  rgb_bands: Optional[List[int]] = None) -> np.ndarray:
        """
        Convert an RGB stack to a displayable uint8 image.

        Each channel is linearly stretched between its 1st and 99th percentiles
        (computed over finite values only; falls back to min/max when that
        range is degenerate) and clipped to [0, 1]. NaN pixels become 0.

        Args:
            image_array: (H, W, 3) array; (H, W) or (H, W, 1) inputs are
                replicated to three channels. Extra channels beyond 3 are ignored.
            rgb_bands: Unused; kept for backward compatibility.

        Returns:
            uint8 array of shape (H, W, 3).

        Raises:
            ValueError: if the input has 2 channels.
        """
        if image_array.ndim == 2:
            image_array = np.stack([image_array] * 3, axis=2)
        elif image_array.shape[2] == 1:
            image_array = np.repeat(image_array, 3, axis=2)
        elif image_array.shape[2] < 3:
            raise ValueError(f"图像需要1或3个通道, 实际为 {image_array.shape[2]}")

        def finite_range(channel: np.ndarray) -> Tuple[float, float]:
            """Min/max over finite values (NaN when there are none)."""
            finite = channel[np.isfinite(channel)]
            return (finite.min(), finite.max()) if finite.size else (np.nan, np.nan)

        (r_lo, r_hi), (g_lo, g_hi), (b_lo, b_hi) = (
            finite_range(image_array[:, :, i]) for i in range(3))
        print(f" 处理前图像范围: R[{r_lo:.3f}-{r_hi:.3f}], "
              f"G[{g_lo:.3f}-{g_hi:.3f}], "
              f"B[{b_lo:.3f}-{b_hi:.3f}]")

        def linear_stretch(data: np.ndarray, min_percent: float = 1,
                           max_percent: float = 99) -> np.ndarray:
            """Percentile linear stretch to [0, 1] (brightens dark imagery)."""
            valid = data[np.isfinite(data)]
            if valid.size == 0:
                return np.zeros_like(data, dtype=np.float32)
            p_low, p_high = np.percentile(valid, [min_percent, max_percent])
            if p_high - p_low < 1e-8:
                # Nearly constant percentiles: fall back to min/max normalisation.
                lo, hi = valid.min(), valid.max()
                if hi > lo:
                    stretched = (data - lo) / (hi - lo)
                else:
                    stretched = np.zeros_like(data, dtype=np.float32)
            else:
                stretched = (data - p_low) / (p_high - p_low)
            return np.clip(stretched, 0.0, 1.0)

        rgb_image = np.stack([linear_stretch(image_array[:, :, i]) for i in range(3)], axis=2)
        rgb_image = np.nan_to_num(rgb_image, nan=0.0)
        rgb_image = np.clip(rgb_image, 0.0, 1.0)

        # Gamma correction hook: values < 1 brighten the image; 1 is a no-op.
        gamma = 1
        if gamma != 1:
            rgb_image = np.power(rgb_image, gamma)

        # uint8 in [0, 255] is what imshow displays without further scaling.
        rgb_image = (rgb_image * 255).astype(np.uint8)
        print(f" 处理后图像范围: [0-255] (Gamma={gamma})")
        return rgb_image

    @staticmethod
    def _create_wgs84_transform(projection: str):
        """
        Build an OSR transformation from WGS84 (lon, lat) to the raster CRS.

        Returns None when there is no projection, GDAL is missing, or the
        transformation cannot be created (callers then use the geotransform
        directly on lon/lat).
        """
        if not (GDAL_AVAILABLE and projection and projection.strip()):
            return None
        print(f" 检测到影像投影: {projection[:80]}...")
        try:
            src_srs = osr.SpatialReference()
            src_srs.ImportFromEPSG(4326)  # WGS84
            dst_srs = osr.SpatialReference()
            if dst_srs.ImportFromWkt(projection) != 0:
                raise ValueError("无法解析影像投影WKT")
            # GDAL >= 3 defaults to authority axis order (lat, lon for EPSG:4326
            # and for geographic WKTs that declare it). Force x=lon/easting,
            # y=lat/northing on both sides so inputs and outputs match the
            # geotransform convention on every GDAL version.
            if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
                src_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
                dst_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
            transform = osr.CoordinateTransformation(src_srs, dst_srs)
            print(" ✓ 已创建WGS84到影像投影的坐标转换")
            return transform
        except Exception as e:
            print(f" ⚠ 坐标转换创建失败: {e},使用简化变换")
            return None

    @staticmethod
    def _map_to_pixel(geotransform: tuple, map_x: float, map_y: float) -> Tuple[float, float]:
        """
        Invert a GDAL affine geotransform: map coordinates -> (column, row).

        Supports rotated/sheared geotransforms (non-zero gt[2]/gt[4]); for a
        north-up raster this reduces to ((X - gt0) / gt1, (Y - gt3) / gt5).

        Raises:
            ValueError: if the geotransform is singular.
        """
        gt = geotransform
        det = gt[1] * gt[5] - gt[2] * gt[4]
        if det == 0:
            raise ValueError(f"地理变换不可逆: {gt}")
        dx = map_x - gt[0]
        dy = map_y - gt[3]
        col = (gt[5] * dx - gt[2] * dy) / det
        row = (gt[1] * dy - gt[4] * dx) / det
        return col, row

    def _geo_to_pixel(self, sampling_points: pd.DataFrame, geotransform: tuple,
                      width: int, height: int, projection: str = "",
                      sample_factor: int = 1) -> List[Tuple[float, float]]:
        """
        Convert WGS84 sampling points to pixel (column, row) coordinates.

        Points arrive as WGS84 latitude/longitude (e.g. 41.66054612, 124.2208338).
        When the raster has a projection they are reprojected into it with OSR;
        otherwise lon/lat are fed straight into the inverse geotransform.
        Coordinates are divided by ``sample_factor`` (> 1 only for downsampled
        images) and clamped to the image extent; a warning reports how many
        points fell outside the raster.

        Returns:
            One (x, y) tuple per row of ``sampling_points``; all at the image
            centre when ``geotransform`` is None.
        """
        if geotransform is None or len(sampling_points) == 0:
            return [(width / 2, height / 2) for _ in range(len(sampling_points))]

        transform = self._create_wgs84_transform(projection)

        pixel_coords = []
        out_of_bounds = 0
        lats = sampling_points['latitude'].astype(float)
        lons = sampling_points['longitude'].astype(float)
        for lat, lon in zip(lats, lons):
            map_x, map_y = lon, lat
            if transform is not None:
                try:
                    # Traditional GIS order: (lon, lat) -> (easting/x, northing/y).
                    map_x, map_y, _ = transform.TransformPoint(lon, lat)
                except Exception:
                    map_x, map_y = lon, lat  # fall back to the untransformed coordinates

            x, y = self._map_to_pixel(geotransform, map_x, map_y)

            if sample_factor > 1:
                x /= sample_factor
                y /= sample_factor

            if not (0 <= x <= width and 0 <= y <= height):
                out_of_bounds += 1
            x = max(0, min(x, width - 1))
            y = max(0, min(y, height - 1))
            pixel_coords.append((x, y))

        if out_of_bounds:
            print(f" ⚠ {out_of_bounds} 个采样点位于影像范围外,已被限制到影像边缘")
        if transform is not None:
            print(f" ✓ 使用GDAL投影变换处理 {len(pixel_coords)} 个采样点")
        else:
            print(f" 使用直接仿射变换处理 {len(pixel_coords)} 个采样点")
        return pixel_coords

    def _create_map_visualization(self, rgb_image: np.ndarray,
                                  pixel_coords: List[Tuple[float, float]],
                                  sampling_points: pd.DataFrame, output_path: str,
                                  point_color: str, point_size: int, point_alpha: float,
                                  show_north_arrow: bool, show_scale_bar: bool,
                                  show_legend: bool, dpi: int, geotransform: tuple,
                                  width: int, height: int, downsample: bool = False,
                                  projection: str = "", sample_factor: int = 1):
        """
        Render the image, sampling points and map furniture, then save to ``output_path``.

        The figure is always closed, even if rendering or saving fails.
        ``sample_factor`` is accepted for interface compatibility and unused.
        """
        figsize = (10, 8) if self.fast_mode or downsample else (12, 10)
        fig, ax = plt.subplots(figsize=figsize, dpi=100 if self.fast_mode else 150)
        try:
            print(f" 最终图像数据范围: [{rgb_image.min()}, {rgb_image.max()}] (uint8)")
            ax.imshow(rgb_image, interpolation='nearest' if self.fast_mode else 'bilinear')

            if pixel_coords:
                xs, ys = zip(*pixel_coords)
                ax.scatter(xs, ys, c=point_color, s=point_size, alpha=point_alpha,
                           edgecolors='white', linewidth=1.5)

            if show_north_arrow:
                # In data coordinates; on the y-inverted imshow axes this draws
                # near the top-left corner with the arrow pointing up the screen.
                self._add_north_arrow(ax, width, height, position='bottom-left', direction='down')

            if show_scale_bar and geotransform is not None:
                self._add_scale_bar(ax, geotransform, width, height, projection)

            if show_legend:
                legend_text = f'采样点 (n={len(sampling_points)})'
                # Empty proxy artist so the legend shows a single marker entry.
                ax.plot([], [], 'o', color=point_color, markersize=8, label=legend_text)
                ax.legend(loc='lower right', frameon=True, facecolor='white', edgecolor='gray')

            ax.set_title('高光谱影像采样点分布图', fontsize=16, fontweight='bold', pad=20)
            ax.set_xticks([])
            ax.set_yticks([])
            fig.tight_layout()

            save_kwargs = {
                'dpi': min(dpi, 180) if self.fast_mode else dpi,  # fast mode caps the DPI
                'bbox_inches': 'tight',
                'pad_inches': 0.05,
                'facecolor': 'white',
            }
            fig.savefig(output_path, **save_kwargs)
        finally:
            plt.close(fig)

    def _add_north_arrow(self, ax, width: int, height: int, position='top-right',
                         direction='down', size=0.08, color='white',
                         n_color='white', outline_color='black'):
        """
        Draw a north arrow with an 'N' label.

        ``position`` and ``direction`` are interpreted in data (pixel)
        coordinates, where y grows with the row index. On an ``imshow`` axes
        with the default ``origin='upper'`` that axis is drawn top-down, so
        'bottom-*' positions appear near the top edge and ``direction='down'``
        points up the screen (north for a north-up image).

        Args:
            ax: Target matplotlib Axes.
            width, height: Image size in pixels, used for relative placement.
            position: 'top-left', 'top-right', 'bottom-left' or 'bottom-right';
                unknown values fall back to 'top-right'.
            direction: 'up', 'down', 'left' or 'right'; unknown values fall back to 'down'.
            size: Arrow length as a fraction of the image height (vertical
                arrows) or width (horizontal arrows).
            color: Arrow colour.
            n_color: Colour of the 'N' label.
            outline_color: Stroke colour drawn around the 'N' label.
        """
        # Anchor (arrow tail) as fractions of width/height.
        pos_map = {
            'top-left': (0.08, 0.88),
            'top-right': (0.92, 0.88),
            'bottom-left': (0.08, 0.12),
            'bottom-right': (0.92, 0.12),
        }
        fx, fy = pos_map.get(position, (0.92, 0.88))
        arrow_x = width * fx
        arrow_y = height * fy

        # Head offset relative to the tail, as fractions of width/height.
        direction_map = {
            'up': (0, +size),
            'down': (0, -size),
            'left': (-size, 0),
            'right': (+size, 0),
        }
        if direction not in direction_map:
            direction = 'down'
        dx, dy = direction_map[direction]
        end_x = arrow_x + dx * width
        end_y = arrow_y + dy * height

        ax.add_patch(FancyArrowPatch((arrow_x, arrow_y), (end_x, end_y),
                                     color=color, linewidth=3,
                                     arrowstyle='->', mutation_scale=20))

        # Place 'N' just behind the arrow tail, offset opposite to the arrow direction.
        offset = 0.02
        label_offsets = {
            'up': (0, -offset),
            'down': (0, +offset),
            'left': (+offset, 0),
            'right': (-offset, 0),
        }
        ox, oy = label_offsets[direction]
        text_x = arrow_x + ox * width
        text_y = arrow_y + oy * height

        ax.text(text_x, text_y, 'N', fontsize=14, fontweight='bold', color=n_color,
                ha='center', va='center',
                path_effects=[path_effects.withStroke(linewidth=3, foreground=outline_color)])

    @staticmethod
    def _meters_per_map_unit(geotransform: tuple, projection: str,
                             width: int, height: int) -> float:
        """
        Approximate metres per horizontal map unit of the raster CRS.

        Geographic CRS: length of one degree of longitude at the image-centre
        latitude (spherical approximation). Projected CRS: the CRS linear unit.
        Unknown or missing CRS: 1.0 (assumes metres).
        """
        if not (GDAL_AVAILABLE and projection):
            return 1.0
        srs = osr.SpatialReference()
        try:
            if srs.ImportFromWkt(projection) != 0:
                return 1.0
        except Exception:
            return 1.0
        if srs.IsGeographic():
            gt = geotransform
            center_lat = gt[3] + gt[4] * width / 2.0 + gt[5] * height / 2.0
            return 111320.0 * math.cos(math.radians(center_lat))
        if srs.IsProjected():
            return srs.GetLinearUnits() or 1.0
        return 1.0

    def _add_scale_bar(self, ax, geotransform: tuple, width: int, height: int,
                       projection: str = ""):
        """
        Draw a scale bar near the lower-left of the displayed image.

        The bar length is the largest "nice" distance not exceeding a quarter
        of the image width. Map units are converted to metres using
        ``projection`` (degrees for geographic CRSs); without a projection the
        units are assumed to be metres.
        """
        if geotransform is None:
            return

        # Ground size of one pixel step along the image columns.
        pixel_size = math.hypot(geotransform[1], geotransform[4])
        meters_per_pixel = pixel_size * self._meters_per_map_unit(
            geotransform, projection, width, height)
        if not (meters_per_pixel > 0 and math.isfinite(meters_per_pixel)):
            return

        target_m = width * meters_per_pixel / 4
        scale_options = [100000, 50000, 20000, 10000, 5000, 2000,
                         1000, 500, 200, 100, 50, 20, 10, 5, 2, 1]
        scale_meters = next((s for s in scale_options if s <= target_m), 1)
        scale_pixels = int(scale_meters / meters_per_pixel)
        label = f'{scale_meters // 1000} km' if scale_meters >= 1000 else f'{scale_meters} m'

        # Data coordinates: y = 0.92*height is near the bottom on the imshow axes.
        bar_x = width * 0.08
        bar_y = height * 0.92

        ax.plot([bar_x, bar_x + scale_pixels], [bar_y, bar_y], color='white', linewidth=4)
        ax.plot([bar_x, bar_x], [bar_y, bar_y + 8], color='white', linewidth=2)
        ax.plot([bar_x + scale_pixels, bar_x + scale_pixels], [bar_y, bar_y + 8],
                color='white', linewidth=2)
        ax.text(bar_x + scale_pixels / 2, bar_y, label,
                fontsize=11, ha='center', va='bottom', fontweight='bold',
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))

    def batch_create_maps(self, hyperspectral_path: str, csv_folder: str,
                          output_subdir: str = "sampling_maps",
                          fast_mode: bool = True) -> Dict[str, str]:
        """
        Create one sampling-point map per CSV file in a folder.

        Maps are written to ``output_dir / output_subdir``. Failures for
        individual CSV files are reported and skipped.

        Args:
            hyperspectral_path: Hyperspectral raster path.
            csv_folder: Folder containing ``*.csv`` files (processed in sorted order).
            output_subdir: Sub-directory of ``output_dir`` for the maps.
            fast_mode: Selects the output DPI (120 vs 200); does not change
                this instance's ``fast_mode``.

        Returns:
            Mapping of CSV file name -> generated map path.

        Raises:
            FileNotFoundError: if ``csv_folder`` does not exist.
        """
        csv_folder_path = Path(csv_folder)
        if not csv_folder_path.exists():
            raise FileNotFoundError(f"CSV文件夹不存在: {csv_folder}")

        (self.output_dir / output_subdir).mkdir(parents=True, exist_ok=True)

        map_paths = {}
        csv_files = sorted(csv_folder_path.glob("*.csv"))
        print(f"找到 {len(csv_files)} 个CSV文件,开始批量生成采样点地图... (快速模式: {fast_mode})")

        for csv_file in csv_files:
            try:
                # Relative to output_dir, so the map lands inside output_subdir.
                output_filename = str(Path(output_subdir) /
                                      f"{Path(hyperspectral_path).stem}_{csv_file.stem}_sampling_map.png")
                map_path = self.create_sampling_point_map(
                    hyperspectral_path=hyperspectral_path,
                    csv_path=str(csv_file),
                    output_filename=output_filename,
                    downsample=True,  # batch mode requests downsampling by default
                    dpi=120 if fast_mode else 200
                )
                map_paths[csv_file.name] = map_path
                print(f"✓ 生成: {csv_file.name}")
            except Exception as e:
                print(f"✗ 处理 {csv_file.name} 失败: {e}")

        print(f"批量生成完成,共生成 {len(map_paths)} 个采样点地图")
        return map_paths


def _main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point: build one map, or print usage hints when no paths are given."""
    parser = argparse.ArgumentParser(description="在高光谱假彩色影像上标注采样点")
    parser.add_argument("hyperspectral_path", nargs="?", help="高光谱影像文件路径")
    parser.add_argument("csv_path", nargs="?", help="采样点CSV文件路径(前两列为纬度、经度)")
    parser.add_argument("--output-dir", default="./point_maps", help="输出目录")
    parser.add_argument("--fast", action="store_true", help="启用快速模式")
    parser.add_argument("--downsample", action="store_true", help="对大影像下采样")
    parser.add_argument("--dpi", type=int, default=None, help="输出图像分辨率")
    args = parser.parse_args(argv)

    if not (args.hyperspectral_path and args.csv_path):
        print("SamplingPointMap类已创建,可以用于生成带采样点的地图。")
        print("性能优化功能:")
        print("  - fast_mode=True: 快速模式 (推荐用于预览)")
        print("  - downsample=True: 对大影像下采样 (推荐用于>2000x2000影像)")
        print("  - 使用: SamplingPointMap(fast_mode=True).create_sampling_point_map(...)")
        print("命令行用法: python <本脚本> <影像路径> <CSV路径> [--fast] [--downsample] [--dpi N]")
        return 0

    generator = SamplingPointMap(output_dir=args.output_dir, fast_mode=args.fast)
    generator.create_sampling_point_map(
        hyperspectral_path=args.hyperspectral_path,
        csv_path=args.csv_path,
        downsample=args.downsample,
        dpi=args.dpi,
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(_main())