637 lines
26 KiB
Python
637 lines
26 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
采样点地图生成模块 - 在高光谱假彩色影像上标注采样点
|
||
|
||
支持功能:
|
||
1. 读取高光谱影像并生成假彩色RGB图像
|
||
2. 读取CSV文件中的采样点坐标(前两列为纬度、经度)
|
||
3. 在影像上标注红色采样点
|
||
4. 添加指北针、图例和比例尺
|
||
5. 支持地理坐标系转换
|
||
"""
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
from pathlib import Path
|
||
from typing import Optional, Tuple, List, Dict, Union
|
||
import warnings
|
||
from matplotlib.patches import FancyArrowPatch
|
||
import matplotlib.patheffects as path_effects
|
||
|
||
# 性能优化配置
|
||
plt.rcParams['agg.path.chunksize'] = 10000 # 提高矢量渲染性能
|
||
plt.rcParams['path.simplify'] = True
|
||
plt.rcParams['path.simplify_threshold'] = 0.1
|
||
|
||
# 导入GDAL用于影像读写
|
||
try:
|
||
from osgeo import gdal, osr
|
||
GDAL_AVAILABLE = True
|
||
except ImportError:
|
||
GDAL_AVAILABLE = False
|
||
print("警告: GDAL未安装,地理坐标转换功能可能无法正常工作")
|
||
|
||
|
||
class SamplingPointMap:
|
||
"""采样点地图生成类 - 在高光谱假彩色影像上标注采样点"""
|
||
|
||
def __init__(self, output_dir: str = "./point_maps", fast_mode: bool = False):
|
||
"""
|
||
初始化采样点地图生成器
|
||
|
||
Args:
|
||
output_dir: 输出目录,用于保存生成的地图
|
||
fast_mode: 是否启用快速模式(降低质量换取速度)
|
||
"""
|
||
self.output_dir = Path(output_dir)
|
||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||
self.fast_mode = fast_mode
|
||
|
||
# 设置中文字体
|
||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||
plt.rcParams['axes.unicode_minus'] = False
|
||
plt.rcParams['font.size'] = 12
|
||
|
||
# 性能优化设置
|
||
if fast_mode:
|
||
plt.rcParams['figure.dpi'] = 150
|
||
plt.rcParams['savefig.dpi'] = 150
|
||
warnings.filterwarnings('ignore', category=UserWarning)
|
||
else:
|
||
plt.rcParams['figure.dpi'] = 300
|
||
plt.rcParams['savefig.dpi'] = 300
|
||
|
||
warnings.filterwarnings('ignore')
|
||
|
||
def create_sampling_point_map(self,
|
||
hyperspectral_path: str,
|
||
csv_path: str,
|
||
output_filename: Optional[str] = None,
|
||
rgb_bands: Optional[List[int]] = None,
|
||
point_color: str = 'red',
|
||
point_size: int = 80,
|
||
point_alpha: float = 0.8,
|
||
show_north_arrow: bool = True,
|
||
show_scale_bar: bool = True,
|
||
show_legend: bool = True,
|
||
dpi: int = None,
|
||
downsample: bool = False) -> str:
|
||
"""
|
||
创建采样点地图:在高光谱假彩色影像上标注采样点
|
||
|
||
Args:
|
||
hyperspectral_path: 高光谱影像文件路径 (.dat, .bsq, .tif等)
|
||
csv_path: 采样点CSV文件路径(前两列为纬度、经度)
|
||
output_filename: 输出文件名(如果为None则自动生成)
|
||
rgb_bands: 用于RGB合成的三个波段索引 [R, G, B],默认为None自动选择
|
||
point_color: 采样点颜色
|
||
point_size: 采样点大小
|
||
point_alpha: 采样点透明度
|
||
show_north_arrow: 是否显示指北针
|
||
show_scale_bar: 是否显示比例尺
|
||
show_legend: 是否显示图例
|
||
dpi: 输出图像分辨率(None时使用fast_mode设置)
|
||
downsample: 是否对图像进行下采样以加快速度(大影像推荐启用)
|
||
|
||
Returns:
|
||
生成的地图文件路径
|
||
"""
|
||
if not GDAL_AVAILABLE:
|
||
raise ImportError("GDAL未安装,无法处理地理坐标转换")
|
||
|
||
print(f"正在生成采样点地图...{' (快速模式)' if self.fast_mode else ''}")
|
||
|
||
# 读取高光谱影像 - 优化:仅读取需要的RGB波段
|
||
hyperspectral_img, geotransform, projection, width, height, sample_factor = self._read_hyperspectral(
|
||
hyperspectral_path, rgb_bands, downsample)
|
||
|
||
# 读取采样点
|
||
sampling_points = self._read_sampling_points(csv_path)
|
||
|
||
# 生成假彩色图像 - 应用线性拉伸
|
||
rgb_image = self._create_false_color_image(hyperspectral_img)
|
||
|
||
# 将地理坐标转换为像素坐标 - 支持投影系转换和下采样
|
||
pixel_coords = self._geo_to_pixel(sampling_points, geotransform, width, height, projection, sample_factor)
|
||
|
||
# 创建地图
|
||
if output_filename is None:
|
||
csv_name = Path(csv_path).stem
|
||
hs_name = Path(hyperspectral_path).stem
|
||
output_filename = f"{hs_name}_{csv_name}_sampling_map.png"
|
||
|
||
output_path = self.output_dir / output_filename
|
||
|
||
# 使用更优化的绘图设置
|
||
if dpi is None:
|
||
dpi = 150 if self.fast_mode else 200
|
||
|
||
self._create_map_visualization(
|
||
rgb_image, pixel_coords, sampling_points,
|
||
str(output_path), point_color, point_size, point_alpha,
|
||
show_north_arrow, show_scale_bar, show_legend, dpi,
|
||
geotransform, width, height, downsample, projection, sample_factor
|
||
)
|
||
|
||
print(f"采样点地图已保存: {output_path}")
|
||
return str(output_path)
|
||
|
||
def _read_hyperspectral(self, hyperspectral_path: str,
|
||
rgb_bands: Optional[List[int]] = None,
|
||
downsample: bool = False) -> Tuple[np.ndarray, tuple, str, int, int]:
|
||
"""优化版:读取高光谱影像 - 仅读取需要的RGB波段"""
|
||
dataset = gdal.Open(hyperspectral_path)
|
||
if dataset is None:
|
||
raise ValueError(f"无法打开高光谱影像: {hyperspectral_path}")
|
||
|
||
width = dataset.RasterXSize
|
||
height = dataset.RasterYSize
|
||
band_count = dataset.RasterCount
|
||
|
||
# 确定要读取的波段 - 优先使用指定波长 (650nm, 550nm, 460nm)
|
||
if rgb_bands is None:
|
||
if band_count >= 3:
|
||
try:
|
||
# 使用find_band_number根据波长查找RGB波段
|
||
from src.utils.util import find_band_number
|
||
rgb_bands = [
|
||
find_band_number(650.0, hyperspectral_path), # Red ~650nm
|
||
find_band_number(550.0, hyperspectral_path), # Green ~550nm
|
||
find_band_number(460.0, hyperspectral_path) # Blue ~460nm
|
||
]
|
||
print(f" 根据波长选择RGB波段: R={rgb_bands[0]}, G={rgb_bands[1]}, B={rgb_bands[2]}")
|
||
except Exception as e:
|
||
print(f" 波长查找失败 ({e}),使用默认索引")
|
||
# 回退到基于索引的选择
|
||
rgb_bands = [min(band_count-1, int(band_count*0.25)),
|
||
min(band_count-1, int(band_count*0.15)),
|
||
min(band_count-1, int(band_count*0.05))]
|
||
else:
|
||
rgb_bands = [0, 0, 0]
|
||
|
||
# 下采样控制 - 用户反馈下采样读取会导致像素值全为0
|
||
if downsample and (width > 2000 or height > 2000):
|
||
print(f" ⚠ 下采样暂被禁用(会导致像素值全0),使用原始分辨率: {width}x{height}")
|
||
sample_factor = 1
|
||
target_width = width
|
||
target_height = height
|
||
else:
|
||
sample_factor = 1
|
||
target_width = width
|
||
target_height = height
|
||
|
||
# 只读取需要的RGB波段(性能关键优化)
|
||
rgb_data = []
|
||
for band_idx in rgb_bands:
|
||
band = dataset.GetRasterBand(band_idx + 1)
|
||
# 直接使用完整分辨率读取,避免下采样导致像素值为0的问题
|
||
band_data = band.ReadAsArray().astype(np.float32)
|
||
rgb_data.append(band_data)
|
||
|
||
# 堆叠为RGB图像 (height, width, 3)
|
||
if len(rgb_data) == 3:
|
||
image_array = np.stack(rgb_data, axis=2)
|
||
else:
|
||
# 如果只有1个波段,复制为RGB
|
||
image_array = np.stack([rgb_data[0]]*3, axis=2)
|
||
|
||
geotransform = dataset.GetGeoTransform()
|
||
projection = dataset.GetProjection()
|
||
|
||
# 释放数据集
|
||
dataset = None
|
||
|
||
# 更新尺寸信息
|
||
final_width = target_width if sample_factor > 1 else width
|
||
final_height = target_height if sample_factor > 1 else height
|
||
|
||
print(f" 读取影像: {final_width}x{final_height}x{image_array.shape[2]} (RGB)")
|
||
if projection:
|
||
proj_type = "投影坐标系" if "PROJCS" in projection else "地理坐标系"
|
||
print(f" 影像投影: {proj_type}")
|
||
if sample_factor > 1:
|
||
print(f" 下采样因子: {sample_factor}")
|
||
|
||
return image_array, geotransform, projection, final_width, final_height, sample_factor
|
||
|
||
def _read_sampling_points(self, csv_path: str) -> pd.DataFrame:
|
||
"""读取采样点CSV文件"""
|
||
if not Path(csv_path).exists():
|
||
raise FileNotFoundError(f"CSV文件不存在: {csv_path}")
|
||
|
||
df = pd.read_csv(csv_path)
|
||
|
||
# 检查前两列是否为纬度和经度
|
||
if len(df.columns) < 2:
|
||
raise ValueError("CSV文件至少需要两列(纬度、经度)")
|
||
|
||
# 假设前两列是纬度和经度
|
||
lat_col = df.columns[0]
|
||
lon_col = df.columns[1]
|
||
|
||
# 重命名列
|
||
df = df.rename(columns={lat_col: 'latitude', lon_col: 'longitude'})
|
||
|
||
# 确保数值类型
|
||
df['latitude'] = pd.to_numeric(df['latitude'], errors='coerce')
|
||
df['longitude'] = pd.to_numeric(df['longitude'], errors='coerce')
|
||
|
||
# 删除无效的坐标
|
||
df = df.dropna(subset=['latitude', 'longitude'])
|
||
|
||
print(f"读取到 {len(df)} 个采样点")
|
||
return df
|
||
|
||
def _create_false_color_image(self, image_array: np.ndarray,
|
||
rgb_bands: Optional[List[int]] = None) -> np.ndarray:
|
||
"""创建假彩色RGB图像 - 应用线性拉伸和Gamma校正"""
|
||
# 由于_read_hyperspectral已返回RGB图像,这里仅进行最终处理
|
||
if image_array.shape[2] != 3:
|
||
# 确保是3通道
|
||
if len(image_array.shape) == 2 or image_array.shape[2] == 1:
|
||
if len(image_array.shape) == 2:
|
||
image_array = np.stack([image_array]*3, axis=2)
|
||
else:
|
||
image_array = np.repeat(image_array, 3, axis=2)
|
||
|
||
print(f" 处理前图像范围: R[{image_array[:,:,0].min():.3f}-{image_array[:,:,0].max():.3f}], "
|
||
f"G[{image_array[:,:,1].min():.3f}-{image_array[:,:,1].max():.3f}], "
|
||
f"B[{image_array[:,:,2].min():.3f}-{image_array[:,:,2].max():.3f}]")
|
||
|
||
# 增强型线性拉伸 - 解决图像太暗的问题
|
||
def simple_linear_stretch(data, min_percent=1, max_percent=99):
|
||
"""增强对比度的线性拉伸"""
|
||
valid_data = data[np.isfinite(data)]
|
||
if len(valid_data) == 0:
|
||
return np.zeros_like(data, dtype=np.float32)
|
||
|
||
# 计算百分位数,使用更激进的拉伸 (1%-99%)
|
||
p_low = np.percentile(valid_data, min_percent)
|
||
p_high = np.percentile(valid_data, max_percent)
|
||
|
||
if p_high - p_low < 1e-8:
|
||
# 如果数据范围太小,使用最小最大值归一化
|
||
data_min = valid_data.min()
|
||
data_max = valid_data.max()
|
||
if data_max > data_min:
|
||
stretched = (data - data_min) / (data_max - data_min)
|
||
else:
|
||
stretched = np.zeros_like(data, dtype=np.float32)
|
||
else:
|
||
stretched = (data - p_low) / (p_high - p_low)
|
||
|
||
# 允许轻微过饱和以增加对比度
|
||
stretched = np.clip(stretched, 0.0, 1.05)
|
||
stretched = np.clip(stretched, 0.0, 1.0) # 最终确保在[0,1]
|
||
return stretched
|
||
|
||
# 对每个通道进行拉伸
|
||
r_stretched = simple_linear_stretch(image_array[:, :, 0])
|
||
g_stretched = simple_linear_stretch(image_array[:, :, 1])
|
||
b_stretched = simple_linear_stretch(image_array[:, :, 2])
|
||
|
||
# 合成为RGB图像
|
||
rgb_image = np.stack([r_stretched, g_stretched, b_stretched], axis=2)
|
||
rgb_image = np.nan_to_num(rgb_image, nan=0.0)
|
||
|
||
# 最终确保范围在[0,1],并轻微增强对比度
|
||
rgb_image = np.clip(rgb_image, 0.0, 1.0)
|
||
|
||
# 可选:Gamma校正增加亮度(解决太暗问题)
|
||
gamma = 1 # <1会增加亮度
|
||
rgb_image = np.power(rgb_image, gamma)
|
||
|
||
# 映射到0-255范围(uint8),这样imshow显示效果更好
|
||
rgb_image = (rgb_image * 255).astype(np.uint8)
|
||
|
||
print(f" 处理后图像范围: [0-255] (Gamma={gamma})")
|
||
|
||
return rgb_image
|
||
|
||
def _geo_to_pixel(self, sampling_points: pd.DataFrame,
|
||
geotransform: tuple, width: int, height: int,
|
||
projection: str = "", sample_factor: int = 1) -> List[Tuple[float, float]]:
|
||
"""
|
||
使用GDAL进行地理坐标到像素坐标的投影变换 - 支持下采样
|
||
|
||
原始点位坐标格式: 41.66054612 124.2208338 (WGS84地理坐标: 纬度,经度)
|
||
高光谱影像通常使用UTM或其他投影坐标系
|
||
当图像下采样时,sample_factor > 1,需要相应缩放坐标
|
||
"""
|
||
if geotransform is None or len(sampling_points) == 0:
|
||
# 如果没有地理变换信息,使用图像中心
|
||
return [(width/2, height/2) for _ in range(len(sampling_points))]
|
||
|
||
pixel_coords = []
|
||
gt = geotransform
|
||
|
||
# 检查是否需要投影转换
|
||
needs_transform = False
|
||
if projection and ("PROJCS" in projection or "GEOGCS" in projection):
|
||
needs_transform = True
|
||
print(f" 检测到影像投影: {projection[:80]}...")
|
||
|
||
# 创建坐标转换对象(WGS84 -> 影像投影)
|
||
transform = None
|
||
if needs_transform and GDAL_AVAILABLE:
|
||
try:
|
||
# 源坐标系: WGS84 (EPSG:4326)
|
||
src_srs = osr.SpatialReference()
|
||
src_srs.ImportFromEPSG(4326) # WGS84
|
||
|
||
# 目标坐标系: 影像的投影
|
||
dst_srs = osr.SpatialReference()
|
||
dst_srs.ImportFromWkt(projection)
|
||
|
||
# 创建坐标转换
|
||
transform = osr.CoordinateTransformation(src_srs, dst_srs)
|
||
print(" ✓ 已创建WGS84到影像投影的坐标转换")
|
||
except Exception as e:
|
||
print(f" ⚠ 坐标转换创建失败: {e},使用简化变换")
|
||
transform = None
|
||
|
||
for _, row in sampling_points.iterrows():
|
||
lon = float(row['longitude']) # 经度 (WGS84)
|
||
lat = float(row['latitude']) # 纬度 (WGS84)
|
||
|
||
if transform is not None:
|
||
# 使用GDAL进行投影转换: (经度, 纬度) -> (投影X, 投影Y)
|
||
try:
|
||
proj_x, proj_y, _ = transform.TransformPoint(lat, lon)
|
||
# 再转换为像素坐标
|
||
x = (proj_x - gt[0]) / gt[1]
|
||
y = (proj_y - gt[3]) / gt[5]
|
||
except Exception as e:
|
||
# 转换失败时回退到直接计算
|
||
x = (lon - gt[0]) / gt[1]
|
||
y = (lat - gt[3]) / gt[5]
|
||
else:
|
||
# 直接使用仿射变换(坐标系一致的情况)
|
||
x = (lon - gt[0]) / gt[1]
|
||
y = (lat - gt[3]) / gt[5]
|
||
|
||
# 如果图像进行了下采样,需要相应缩放坐标
|
||
if sample_factor > 1:
|
||
x = x / sample_factor
|
||
y = y / sample_factor
|
||
|
||
# 限制在图像范围内(使用下采样后的尺寸)
|
||
x = max(0, min(x, width - 1))
|
||
y = max(0, min(y, height - 1))
|
||
|
||
pixel_coords.append((x, y))
|
||
|
||
if transform is not None:
|
||
print(f" ✓ 使用GDAL投影变换处理 {len(pixel_coords)} 个采样点")
|
||
else:
|
||
print(f" 使用直接仿射变换处理 {len(pixel_coords)} 个采样点")
|
||
|
||
return pixel_coords
|
||
|
||
def _create_map_visualization(self, rgb_image: np.ndarray,
|
||
pixel_coords: List[Tuple[float, float]],
|
||
sampling_points: pd.DataFrame,
|
||
output_path: str,
|
||
point_color: str,
|
||
point_size: int,
|
||
point_alpha: float,
|
||
show_north_arrow: bool,
|
||
show_scale_bar: bool,
|
||
show_legend: bool,
|
||
dpi: int,
|
||
geotransform: tuple,
|
||
width: int,
|
||
height: int,
|
||
downsample: bool = False,
|
||
projection: str = "",
|
||
sample_factor: int = 1):
|
||
"""创建地图可视化 - 优化版"""
|
||
# 使用更小的figure尺寸加快渲染
|
||
figsize = (10, 8) if self.fast_mode or downsample else (12, 10)
|
||
fig, ax = plt.subplots(figsize=figsize, dpi=100 if self.fast_mode else 150)
|
||
|
||
# 显示假彩色图像 - 现在已经是0-255的uint8格式
|
||
print(f" 最终图像数据范围: [{rgb_image.min()}, {rgb_image.max()}] (uint8)")
|
||
ax.imshow(rgb_image, interpolation='nearest' if self.fast_mode else 'bilinear')
|
||
|
||
# 绘制采样点 - 优化:使用scatter代替循环plot
|
||
if pixel_coords:
|
||
x_coords = [p[0] for p in pixel_coords]
|
||
y_coords = [p[1] for p in pixel_coords]
|
||
ax.scatter(x_coords, y_coords, c=point_color, s=point_size,
|
||
alpha=point_alpha, edgecolors='white', linewidth=1.5)
|
||
|
||
# 添加指北针
|
||
if show_north_arrow:
|
||
self._add_north_arrow(ax, width, height, position='bottom-left', direction='down')
|
||
|
||
# 添加比例尺
|
||
if show_scale_bar and geotransform is not None:
|
||
self._add_scale_bar(ax, geotransform, width, height)
|
||
|
||
# 添加图例
|
||
if show_legend:
|
||
legend_text = f'采样点 (n={len(sampling_points)})'
|
||
ax.plot([], [], 'o', color=point_color, markersize=8, label=legend_text)
|
||
ax.legend(loc='lower right', frameon=True, facecolor='white', edgecolor='gray')
|
||
|
||
# 设置标题和标签
|
||
ax.set_title('高光谱影像采样点分布图', fontsize=16, fontweight='bold', pad=20)
|
||
|
||
|
||
# 隐藏坐标轴刻度
|
||
ax.set_xticks([])
|
||
ax.set_yticks([])
|
||
|
||
# 添加网格
|
||
ax.grid(True, alpha=0.2, linestyle='--')
|
||
|
||
plt.tight_layout()
|
||
|
||
# 保存参数 - 避免传递不兼容的参数
|
||
save_kwargs = {
|
||
'dpi': dpi,
|
||
'bbox_inches': 'tight',
|
||
'pad_inches': 0.05,
|
||
'facecolor': 'white'
|
||
}
|
||
|
||
# 仅添加matplotlib支持的参数
|
||
if self.fast_mode:
|
||
save_kwargs['dpi'] = min(dpi, 180) # 快速模式降低DPI
|
||
|
||
plt.savefig(output_path, **save_kwargs)
|
||
plt.close(fig)
|
||
|
||
def _add_north_arrow(self, ax, width: int, height: int, position='top-right', direction='down',
|
||
size=0.08, color='white', n_color='white', outline_color='black'):
|
||
"""
|
||
添加指北针,可配置位置、方向、大小、颜色。
|
||
|
||
参数:
|
||
ax: matplotlib Axes对象
|
||
width, height: 图像宽高(用于相对定位)
|
||
position: 'top-left', 'top-right', 'bottom-left', 'bottom-right'
|
||
direction: 'up', 'down', 'left', 'right' 箭头指向
|
||
size: 箭头长度相对于高度的比例(0.05~0.12)
|
||
color: 箭头颜色
|
||
n_color: 'N' 文字颜色
|
||
outline_color: 文字描边颜色
|
||
"""
|
||
# 位置映射(偏移系数)
|
||
pos_map = {
|
||
'top-left': (0.08, 0.88),
|
||
'top-right': (0.92, 0.88),
|
||
'bottom-left': (0.08, 0.12),
|
||
'bottom-right': (0.92, 0.12),
|
||
}
|
||
arrow_x_ratio, arrow_y_ratio = pos_map.get(position, (0.92, 0.88))
|
||
arrow_x = width * arrow_x_ratio
|
||
arrow_y = height * arrow_y_ratio
|
||
|
||
# 方向映射(箭头终点偏移)
|
||
direction_map = {
|
||
'up': (0, +size),
|
||
'down': (0, -size),
|
||
'left': (-size, 0),
|
||
'right': (+size, 0),
|
||
}
|
||
dx, dy = direction_map.get(direction, (0, -size))
|
||
end_x = arrow_x + dx * width # 注意:dx是比例,乘以宽度/高度保持比例一致
|
||
end_y = arrow_y + dy * height
|
||
|
||
# 箭头绘制
|
||
arrow = FancyArrowPatch((arrow_x, arrow_y), (end_x, end_y),
|
||
color=color, linewidth=3,
|
||
arrowstyle='->', mutation_scale=20)
|
||
ax.add_patch(arrow)
|
||
|
||
# N 文字位置:在箭头尾部或头部?通常放在箭头指向的反方向末端
|
||
# 这里放在箭头尾部向外偏移一点(便于阅读)
|
||
# 偏移系数根据方向决定
|
||
offset_scale = 0.02 # 偏移量比例
|
||
if direction == 'up':
|
||
text_x = arrow_x
|
||
text_y = arrow_y - height * offset_scale # 放在箭头下方
|
||
elif direction == 'down':
|
||
text_x = arrow_x
|
||
text_y = arrow_y + height * offset_scale # 放在箭头上方
|
||
elif direction == 'left':
|
||
text_x = arrow_x + width * offset_scale
|
||
text_y = arrow_y
|
||
else: # right
|
||
text_x = arrow_x - width * offset_scale
|
||
text_y = arrow_y
|
||
|
||
ax.text(text_x, text_y, 'N', fontsize=14, fontweight='bold',
|
||
color=n_color, ha='center', va='center',
|
||
path_effects=[path_effects.withStroke(linewidth=3, foreground=outline_color)])
|
||
|
||
def _add_scale_bar(self, ax, geotransform: tuple, width: int, height: int):
|
||
"""添加比例尺"""
|
||
if geotransform is None:
|
||
return
|
||
|
||
# 计算图像实际宽度(米)
|
||
pixel_size_x = abs(geotransform[1])
|
||
image_width_meters = width * pixel_size_x
|
||
|
||
# 选择合适的比例尺长度(图像宽度的1/4)
|
||
scale_length_m = image_width_meters / 4
|
||
scale_length_pixels = width / 4
|
||
|
||
# 找到合适的刻度
|
||
scale_options = [1000, 500, 200, 100, 50, 20, 10, 5, 2, 1]
|
||
scale_meters = next((s for s in scale_options if s <= scale_length_m), 1)
|
||
|
||
scale_pixels = int(scale_meters / pixel_size_x)
|
||
|
||
# 在左下角添加比例尺
|
||
bar_x = width * 0.08
|
||
bar_y = height * 0.92
|
||
|
||
# 绘制比例尺线
|
||
ax.plot([bar_x, bar_x + scale_pixels], [bar_y, bar_y], color='white', linewidth=4)
|
||
|
||
# 添加刻度线
|
||
ax.plot([bar_x, bar_x], [bar_y, bar_y + 8], color='white', linewidth=2)
|
||
ax.plot([bar_x + scale_pixels, bar_x + scale_pixels], [bar_y, bar_y + 8], color='white', linewidth=2)
|
||
|
||
# 添加文字
|
||
ax.text(bar_x + scale_pixels/2, bar_y , f'{scale_meters} m',
|
||
fontsize=11, ha='center', va='bottom', fontweight='bold',
|
||
bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))
|
||
|
||
def batch_create_maps(self, hyperspectral_path: str,
|
||
csv_folder: str,
|
||
output_subdir: str = "sampling_maps",
|
||
fast_mode: bool = True) -> Dict[str, str]:
|
||
"""
|
||
批量创建采样点地图
|
||
|
||
Args:
|
||
hyperspectral_path: 高光谱影像路径
|
||
csv_folder: 包含多个CSV文件的文件夹
|
||
output_subdir: 输出子目录
|
||
|
||
Returns:
|
||
生成的地图文件路径字典
|
||
"""
|
||
csv_folder_path = Path(csv_folder)
|
||
if not csv_folder_path.exists():
|
||
raise FileNotFoundError(f"CSV文件夹不存在: {csv_folder}")
|
||
|
||
# 创建输出目录
|
||
output_dir = self.output_dir / output_subdir
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
map_paths = {}
|
||
|
||
# 查找所有CSV文件
|
||
csv_files = list(csv_folder_path.glob("*.csv"))
|
||
|
||
print(f"找到 {len(csv_files)} 个CSV文件,开始批量生成采样点地图... (快速模式: {fast_mode})")
|
||
|
||
for csv_file in csv_files:
|
||
try:
|
||
output_filename = f"{Path(hyperspectral_path).stem}_{csv_file.stem}_sampling_map.png"
|
||
map_path = self.create_sampling_point_map(
|
||
hyperspectral_path=hyperspectral_path,
|
||
csv_path=str(csv_file),
|
||
output_filename=output_filename,
|
||
downsample=True, # 批量模式默认下采样
|
||
dpi=120 if fast_mode else 200
|
||
)
|
||
map_paths[csv_file.name] = map_path
|
||
print(f"✓ 生成: {csv_file.name}")
|
||
|
||
except Exception as e:
|
||
print(f"✗ 处理 {csv_file.name} 失败: {e}")
|
||
|
||
print(f"批量生成完成,共生成 {len(map_paths)} 个采样点地图")
|
||
return map_paths
|
||
|
||
|
||
# 测试代码
|
||
if __name__ == "__main__":
|
||
# 示例用法
|
||
map_generator = SamplingPointMap(output_dir="./point_maps")
|
||
|
||
# 测试代码已禁用,避免直接运行时出错
|
||
map_generator_fast = SamplingPointMap(output_dir="./point_maps", fast_mode=True)
|
||
map_path = map_generator_fast.create_sampling_point_map(
|
||
hyperspectral_path=r"D:\BaiduNetdiskDownload\yaobao\result3.bsq",
|
||
csv_path=r"E:\code\WQ\pipeline_result\work_dir\4_processed_data\processed_data.csv",
|
||
downsample=True,
|
||
dpi=150
|
||
)
|
||
print("测试代码已注释,请通过GUI或手动调用使用。")
|
||
|
||
print("SamplingPointMap类已创建,可以用于生成带采样点的地图。")
|
||
print("性能优化功能:")
|
||
print(" - fast_mode=True: 快速模式 (推荐用于预览)")
|
||
print(" - downsample=True: 对大影像下采样 (推荐用于>2000x2000影像)")
|
||
print(" - 使用: SamplingPointMap(fast_mode=True).create_sampling_point_map(...)")
|