Initial commit of WQ_GUI

This commit is contained in:
2026-04-08 15:25:08 +08:00
commit 91e36407ae
302 changed files with 40872 additions and 0 deletions

View File

@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@ -0,0 +1,327 @@
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import os
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False
def plot_individual_boxplots(csv_file_path, save_dir="boxplots"):
"""
为每个数据列单独绘制箱型图并保存
参数:
csv_file_path: CSV文件路径
save_dir: 保存图片的目录
"""
try:
# 读取CSV文件
df = pd.read_csv(csv_file_path)
# 获取第五列之后的数据列索引从0开始第五列索引为4
data_columns = df.iloc[:, 4:]
# 检查是否有数据列
if data_columns.empty:
print("错误CSV文件中没有足够的列至少需要5列")
return
# 创建保存目录
if not os.path.exists(save_dir):
os.makedirs(save_dir)
print(f"创建目录: {save_dir}")
# 为每个数据列单独绘制箱型图
for column in data_columns.columns:
# 移除空值
clean_data = data_columns[column].dropna()
if len(clean_data) == 0:
print(f"跳过列 '{column}': 没有有效数据")
continue
# 创建新图形
plt.figure(figsize=(8, 6))
# 绘制箱型图
box_plot = plt.boxplot([clean_data], labels=[column], patch_artist=True,
showfliers=False)
# 美化箱型图
box_plot['boxes'][0].set_facecolor('lightblue')
box_plot['boxes'][0].set_alpha(0.7)
# 添加散点
x_pos = np.random.normal(1, 0.04, size=len(clean_data))
plt.scatter(x_pos, clean_data, alpha=0.6, s=30, color='red',
edgecolors='black', linewidth=0.5, zorder=3)
# 设置标题和标签
plt.title(f'{column} - 箱型图', fontsize=14, fontweight='bold')
plt.xlabel('数据列', fontsize=12)
plt.ylabel('数值', fontsize=12)
# 添加统计信息到图上
stats_text = f'数据点数: {len(clean_data)}\n均值: {clean_data.mean():.2f}\n中位数: {clean_data.median():.2f}\n标准差: {clean_data.std():.2f}'
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
# 添加网格
plt.grid(True, alpha=0.3, linestyle='--')
# 调整布局
plt.tight_layout()
# 保存图片
safe_column_name = column.replace('/', '_').replace('\\', '_').replace(':', '_')
save_path = os.path.join(save_dir, f'{safe_column_name}_boxplot.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"已保存: {save_path}")
# 关闭图形以释放内存
plt.close()
print(f"\n所有箱型图已保存到目录: {save_dir}")
except FileNotFoundError:
print(f"错误:找不到文件 {csv_file_path}")
except Exception as e:
print(f"错误:{str(e)}")
def plot_individual_boxplots_seaborn(csv_file_path, save_dir="boxplots_seaborn"):
"""
使用seaborn为每个数据列单独绘制箱型图并保存
参数:
csv_file_path: CSV文件路径
save_dir: 保存图片的目录
"""
try:
# 读取CSV文件
df = pd.read_csv(csv_file_path)
# 获取第五列之后的数据列
data_columns = df.iloc[:, 4:]
if data_columns.empty:
print("错误CSV文件中没有足够的列至少需要5列")
return
# 创建保存目录
if not os.path.exists(save_dir):
os.makedirs(save_dir)
print(f"创建目录: {save_dir}")
# 为每个数据列单独绘制箱型图
for column in data_columns.columns:
# 移除空值
clean_data = data_columns[column].dropna()
if len(clean_data) == 0:
print(f"跳过列 '{column}': 没有有效数据")
continue
# 创建新图形
plt.figure(figsize=(8, 6))
# 创建数据框用于seaborn
plot_data = pd.DataFrame({
'列名': [column] * len(clean_data),
'数值': clean_data
})
# 使用seaborn绘制箱型图和散点
sns.boxplot(data=plot_data, x='列名', y='数值', palette='Set2')
sns.stripplot(data=plot_data, x='列名', y='数值',
color='red', alpha=0.6, size=5, jitter=True)
# 设置标题和标签
plt.title(f'{column} - 箱型图 (Seaborn)', fontsize=14, fontweight='bold')
plt.xlabel('数据列', fontsize=12)
plt.ylabel('数值', fontsize=12)
# 添加统计信息
stats_text = f'数据点数: {len(clean_data)}\n均值: {clean_data.mean():.2f}\n中位数: {clean_data.median():.2f}\n标准差: {clean_data.std():.2f}'
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))
# 添加网格
plt.grid(True, alpha=0.3, linestyle='--')
# 调整布局
plt.tight_layout()
# 保存图片
safe_column_name = column.replace('/', '_').replace('\\', '_').replace(':', '_')
save_path = os.path.join(save_dir, f'{safe_column_name}_boxplot_seaborn.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"已保存: {save_path}")
# 关闭图形以释放内存
plt.close()
print(f"\n所有箱型图已保存到目录: {save_dir}")
except Exception as e:
print(f"错误:{str(e)}")
def plot_boxplot_with_scatter(csv_file_path):
"""
读取CSV文件并绘制第五列之后数据列的箱型图同时标注散点
参数:
csv_file_path: CSV文件路径
"""
try:
# 读取CSV文件
df = pd.read_csv(csv_file_path)
# 获取第五列之后的数据列索引从0开始第五列索引为4
data_columns = df.iloc[:, 4:] # 从第五列开始的所有列
# 检查是否有数据列
if data_columns.empty:
print("错误CSV文件中没有足够的列至少需要5列")
return
# 设置图形大小
plt.figure(figsize=(12, 8))
# 准备数据用于绘制箱型图
box_data = []
labels = []
for column in data_columns.columns:
# 移除空值
clean_data = data_columns[column].dropna()
if len(clean_data) > 0:
box_data.append(clean_data)
labels.append(column)
# 绘制箱型图
box_plot = plt.boxplot(box_data, labels=labels, patch_artist=True,
showfliers=False) # 不显示异常值点,因为我们要自己绘制散点
# 美化箱型图
colors = plt.cm.Set3(np.linspace(0, 1, len(box_data)))
for patch, color in zip(box_plot['boxes'], colors):
patch.set_facecolor(color)
patch.set_alpha(0.7)
# 在每个箱型图上添加散点
for i, data in enumerate(box_data):
# 为每个数据点添加一些随机的x轴偏移避免重叠
x_pos = np.random.normal(i + 1, 0.04, size=len(data))
# 绘制散点
plt.scatter(x_pos, data, alpha=0.6, s=20, color='red',
edgecolors='black', linewidth=0.5, zorder=3)
# 设置标题和标签
plt.title('数据列箱型图(带散点标注)', fontsize=16, fontweight='bold')
plt.xlabel('数据列', fontsize=12)
plt.ylabel('数值', fontsize=12)
# 旋转x轴标签以避免重叠
plt.xticks(rotation=45, ha='right')
# 添加网格
plt.grid(True, alpha=0.3, linestyle='--')
# 调整布局
plt.tight_layout()
# 显示图形
plt.show()
# 打印统计信息
print(f"成功绘制了 {len(labels)} 个数据列的箱型图")
print("数据列名称:", labels)
# 显示每列的基本统计信息
print("\n各列基本统计信息:")
for column in labels:
data = data_columns[column].dropna()
print(f"{column}: 数据点数={len(data)}, 均值={data.mean():.2f}, 中位数={data.median():.2f}")
except FileNotFoundError:
print(f"错误:找不到文件 {csv_file_path}")
except Exception as e:
print(f"错误:{str(e)}")
def plot_boxplot_with_seaborn(csv_file_path):
"""
使用seaborn绘制更美观的箱型图可选方法
参数:
csv_file_path: CSV文件路径
"""
try:
# 读取CSV文件
df = pd.read_csv(csv_file_path)
# 获取第五列之后的数据列
data_columns = df.iloc[:, 4:]
if data_columns.empty:
print("错误CSV文件中没有足够的列至少需要5列")
return
# 将数据转换为长格式用于seaborn
melted_data = pd.melt(data_columns, var_name='列名', value_name='数值')
melted_data = melted_data.dropna() # 移除空值
# 设置图形大小
plt.figure(figsize=(12, 8))
# 使用seaborn绘制箱型图和散点
sns.boxplot(data=melted_data, x='列名', y='数值', palette='Set3')
sns.stripplot(data=melted_data, x='列名', y='数值',
color='red', alpha=0.6, size=4, jitter=True)
# 设置标题和标签
plt.title('数据列箱型图Seaborn版本', fontsize=16, fontweight='bold')
plt.xlabel('数据列', fontsize=12)
plt.ylabel('数值', fontsize=12)
# 旋转x轴标签
plt.xticks(rotation=45, ha='right')
# 添加网格
plt.grid(True, alpha=0.3, linestyle='--')
# 调整布局
plt.tight_layout()
# 显示图形
plt.show()
except Exception as e:
print(f"错误:{str(e)}")
# 主程序
if __name__ == "__main__":
# 请修改为您的CSV文件路径
csv_file_path = r"E:\code\WQ\yaobao925\output.csv" # 替换为您的CSV文件路径
print("请选择绘图方法:")
print("1. 使用matplotlib绘制所有列在一张图")
print("2. 使用seaborn绘制所有列在一张图")
print("3. 分别绘制每列并保存matplotlib版本")
print("4. 分别绘制每列并保存seaborn版本")
choice = input("请输入选择1-4").strip()
if choice == "1":
plot_boxplot_with_scatter(csv_file_path)
elif choice == "2":
plot_boxplot_with_seaborn(csv_file_path)
elif choice == "3":
plot_individual_boxplots(csv_file_path)
elif choice == "4":
plot_individual_boxplots_seaborn(csv_file_path)
else:
print("默认使用分别绘制并保存seaborn版本...")
plot_individual_boxplots_seaborn(csv_file_path)

View File

@ -0,0 +1,545 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
飞行轨迹可视化模块 - 在多架次GPS数据上绘制飞行轨迹
支持功能:
1. 读取多个.gps文件每个文件代表一个架次
2. 在高光谱假彩色影像上绘制飞行轨迹
3. 不同架次使用不同颜色
4. 图例显示架次起始到结束的时间段
5. 添加指北针、比例尺
6. 保存为PNG图像
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Optional, Tuple, List, Dict, Union
import warnings
from matplotlib.patches import FancyArrowPatch
import matplotlib.patheffects as path_effects
from datetime import datetime
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches
# 性能优化配置
plt.rcParams['agg.path.chunksize'] = 10000
plt.rcParams['path.simplify'] = True
plt.rcParams['path.simplify_threshold'] = 0.1
# 导入GDAL用于影像读写
try:
from osgeo import gdal, osr
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
print("警告: GDAL未安装地理坐标转换功能可能无法正常工作")
class FlightPathVisualizer:
"""飞行轨迹可视化类 - 在高光谱假彩色影像上绘制多架次飞行轨迹"""
# 预定义颜色方案(不同架次使用不同颜色)
FLIGHT_COLORS = [
'#FF0000', # 红色
'#00FF00', # 绿色
'#0000FF', # 蓝色
'#FF00FF', # 紫色
'#00FFFF', # 青色
'#FFFF00', # 黄色
'#FF8000', # 橙色
'#8000FF', # 紫罗兰
'#0080FF', # 天蓝
'#FF0080', # 粉红
]
def __init__(self, output_dir: str = "./flight_paths"):
"""
初始化飞行轨迹可视化器
Args:
output_dir: 输出目录,用于保存生成的轨迹图
"""
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.size'] = 12
warnings.filterwarnings('ignore')
def create_flight_path_map(self,
gps_folder: str,
hyperspectral_path: str,
output_filename: Optional[str] = None,
rgb_bands: Optional[List[int]] = None,
line_width: int = 2,
show_north_arrow: bool = True,
show_scale_bar: bool = True,
dpi: int = 300) -> str:
"""
创建飞行轨迹地图:在高光谱假彩色影像上绘制多架次飞行轨迹
Args:
gps_folder: GPS文件夹路径包含多个.gps文件
hyperspectral_path: 高光谱影像文件路径 (.dat, .bsq, .tif等)
output_filename: 输出文件名如果为None则自动生成
rgb_bands: 用于RGB合成的三个波段索引 [R, G, B]默认为None自动选择(650,550,460nm)
line_width: 轨迹线宽
show_north_arrow: 是否显示指北针
show_scale_bar: 是否显示比例尺
dpi: 输出图像分辨率
Returns:
生成的地图文件路径
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法处理地理坐标转换")
print(f"正在生成飞行轨迹地图...")
# 读取高光谱影像
hyperspectral_img, geotransform, projection, width, height = self._read_hyperspectral(
hyperspectral_path, rgb_bands)
# 读取所有GPS文件
flight_data = self._read_gps_files(gps_folder)
if not flight_data:
raise ValueError(f"未在 {gps_folder} 中找到有效的.gps文件")
# 将GPS坐标转换为像素坐标
flight_pixels = self._convert_flights_to_pixels(
flight_data, geotransform, width, height, projection)
# 创建地图
if output_filename is None:
folder_name = Path(gps_folder).name
hs_name = Path(hyperspectral_path).stem
output_filename = f"{hs_name}_{folder_name}_flight_paths.png"
output_path = self.output_dir / output_filename
self._create_map_visualization(
hyperspectral_img, flight_pixels, flight_data,
str(output_path), line_width,
show_north_arrow, show_scale_bar, dpi,
geotransform, width, height
)
print(f"飞行轨迹地图已保存: {output_path}")
return str(output_path)
def _read_hyperspectral(self, hyperspectral_path: str,
rgb_bands: Optional[List[int]] = None) -> Tuple[np.ndarray, tuple, str, int, int]:
"""读取高光谱影像 - 使用650/550/460nm波长"""
dataset = gdal.Open(hyperspectral_path)
if dataset is None:
raise ValueError(f"无法打开高光谱影像: {hyperspectral_path}")
width = dataset.RasterXSize
height = dataset.RasterYSize
band_count = dataset.RasterCount
# 确定要读取的波段 - 使用指定波长 650, 550, 460nm
if rgb_bands is None:
if band_count >= 3:
try:
from src.utils.util import find_band_number
rgb_bands = [
find_band_number(650.0, hyperspectral_path), # Red ~650nm
find_band_number(550.0, hyperspectral_path), # Green ~550nm
find_band_number(460.0, hyperspectral_path) # Blue ~460nm
]
print(f" 根据波长选择RGB波段: R={rgb_bands[0]}, G={rgb_bands[1]}, B={rgb_bands[2]}")
except Exception as e:
print(f" 波长查找失败 ({e}),使用默认索引")
rgb_bands = [min(band_count-1, int(band_count*0.25)),
min(band_count-1, int(band_count*0.15)),
min(band_count-1, int(band_count*0.05))]
else:
rgb_bands = [0, 0, 0]
# 读取RGB波段
rgb_data = []
for band_idx in rgb_bands:
band = dataset.GetRasterBand(band_idx + 1)
band_data = band.ReadAsArray().astype(np.float32)
rgb_data.append(band_data)
# 堆叠为RGB图像
if len(rgb_data) == 3:
image_array = np.stack(rgb_data, axis=2)
else:
image_array = np.stack([rgb_data[0]]*3, axis=2)
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
dataset = None
print(f" 读取影像: {width}x{height}x{image_array.shape[2]} (RGB)")
if projection:
proj_type = "投影坐标系" if "PROJCS" in projection else "地理坐标系"
print(f" 影像投影: {proj_type}")
return image_array, geotransform, projection, width, height
def _read_gps_files(self, gps_folder: str) -> Dict[str, pd.DataFrame]:
"""
读取文件夹中的所有.gps文件
文件格式: 日期、时间、三个姿态角、经度、纬度、高程
列: [date, time, pitch, roll, yaw, longitude, latitude, altitude]
"""
gps_folder_path = Path(gps_folder)
if not gps_folder_path.exists():
raise FileNotFoundError(f"GPS文件夹不存在: {gps_folder}")
gps_files = list(gps_folder_path.glob("*.gps"))
if not gps_files:
print(f"警告: 在 {gps_folder} 中未找到.gps文件")
return {}
print(f"找到 {len(gps_files)} 个GPS文件")
flight_data = {}
for gps_file in gps_files:
try:
# 读取GPS文件制表符分隔
df = pd.read_csv(gps_file, sep='\t', header=None,
names=['date', 'time', 'pitch', 'roll', 'yaw',
'longitude', 'latitude', 'altitude'])
# 确保数值类型
df['longitude'] = pd.to_numeric(df['longitude'], errors='coerce')
df['latitude'] = pd.to_numeric(df['latitude'], errors='coerce')
df['altitude'] = pd.to_numeric(df['altitude'], errors='coerce')
# 删除无效坐标
df = df.dropna(subset=['longitude', 'latitude'])
if len(df) > 0:
flight_data[gps_file.stem] = df
print(f" ✓ 读取 {gps_file.name}: {len(df)} 个轨迹点")
# 显示时间范围
start_time = df.iloc[0]['time']
end_time = df.iloc[-1]['time']
print(f" 时间范围: {start_time} - {end_time}")
else:
print(f"{gps_file.name}: 无有效数据")
except Exception as e:
print(f" ✗ 读取 {gps_file.name} 失败: {e}")
return flight_data
def _convert_flights_to_pixels(self, flight_data: Dict[str, pd.DataFrame],
geotransform: tuple, width: int, height: int,
projection: str = "") -> Dict[str, List[Tuple[float, float]]]:
"""将所有飞行轨迹的地理坐标转换为像素坐标"""
if geotransform is None:
print("警告: 无地理变换信息,无法转换坐标")
return {}
gt = geotransform
flight_pixels = {}
# 检查是否需要投影转换
needs_transform = projection and ("PROJCS" in projection or "GEOGCS" in projection)
transform = None
if needs_transform and GDAL_AVAILABLE:
try:
src_srs = osr.SpatialReference()
src_srs.ImportFromEPSG(4326) # WGS84
dst_srs = osr.SpatialReference()
dst_srs.ImportFromWkt(projection)
transform = osr.CoordinateTransformation(src_srs, dst_srs)
print(" ✓ 已创建WGS84到影像投影的坐标转换")
except Exception as e:
print(f" ⚠ 坐标转换创建失败: {e}")
for flight_name, df in flight_data.items():
pixel_coords = []
for _, row in df.iterrows():
lon = float(row['longitude'])
lat = float(row['latitude'])
if transform is not None:
try:
proj_x, proj_y, _ = transform.TransformPoint(lat, lon)
x = (proj_x - gt[0]) / gt[1]
y = (proj_y - gt[3]) / gt[5]
except:
x = (lon - gt[0]) / gt[1]
y = (lat - gt[3]) / gt[5]
else:
x = (lon - gt[0]) / gt[1]
y = (lat - gt[3]) / gt[5]
# 限制在图像范围内
x = max(0, min(x, width - 1))
y = max(0, min(y, height - 1))
pixel_coords.append((x, y))
flight_pixels[flight_name] = pixel_coords
print(f" 已转换 {len(flight_pixels)} 个架次的坐标")
return flight_pixels
def _create_false_color_image(self, image_array: np.ndarray) -> np.ndarray:
"""创建假彩色RGB图像 - 应用线性拉伸"""
if image_array.shape[2] != 3:
if len(image_array.shape) == 2 or image_array.shape[2] == 1:
if len(image_array.shape) == 2:
image_array = np.stack([image_array]*3, axis=2)
else:
image_array = np.repeat(image_array, 3, axis=2)
def simple_linear_stretch(data, min_percent=1, max_percent=99):
valid_data = data[np.isfinite(data)]
if len(valid_data) == 0:
return np.zeros_like(data, dtype=np.float32)
p_low = np.percentile(valid_data, min_percent)
p_high = np.percentile(valid_data, max_percent)
if p_high - p_low < 1e-8:
data_min = valid_data.min()
data_max = valid_data.max()
if data_max > data_min:
stretched = (data - data_min) / (data_max - data_min)
else:
stretched = np.zeros_like(data, dtype=np.float32)
else:
stretched = (data - p_low) / (p_high - p_low)
stretched = np.clip(stretched, 0.0, 1.0)
return stretched
r_stretched = simple_linear_stretch(image_array[:, :, 0])
g_stretched = simple_linear_stretch(image_array[:, :, 1])
b_stretched = simple_linear_stretch(image_array[:, :, 2])
rgb_image = np.stack([r_stretched, g_stretched, b_stretched], axis=2)
rgb_image = np.nan_to_num(rgb_image, nan=0.0)
rgb_image = np.clip(rgb_image, 0.0, 1.0)
# Gamma校正增加亮度
gamma = 0.85
rgb_image = np.power(rgb_image, gamma)
# 映射到0-255
rgb_image = (rgb_image * 255).astype(np.uint8)
return rgb_image
def _create_map_visualization(self, image_array: np.ndarray,
flight_pixels: Dict[str, List[Tuple[float, float]]],
flight_data: Dict[str, pd.DataFrame],
output_path: str,
line_width: int,
show_north_arrow: bool,
show_scale_bar: bool,
dpi: int,
geotransform: tuple,
width: int,
height: int):
"""创建地图可视化"""
figsize = (14, 10)
fig, ax = plt.subplots(figsize=figsize, dpi=150)
# 处理背景图像
rgb_image = self._create_false_color_image(image_array)
ax.imshow(rgb_image, interpolation='bilinear')
# 绘制飞行轨迹 - 不同架次不同颜色
legend_elements = []
for idx, (flight_name, pixel_coords) in enumerate(flight_pixels.items()):
if len(pixel_coords) < 2:
continue
# 选择颜色
color = self.FLIGHT_COLORS[idx % len(self.FLIGHT_COLORS)]
# 提取x,y坐标
x_coords = [p[0] for p in pixel_coords]
y_coords = [p[1] for p in pixel_coords]
# 绘制轨迹线
ax.plot(x_coords, y_coords, color=color, linewidth=line_width,
alpha=0.8, solid_capstyle='round')
# 标记起点和终点
ax.plot(x_coords[0], y_coords[0], 'o', color=color, markersize=8,
markeredgecolor='white', markeredgewidth=1)
ax.plot(x_coords[-1], y_coords[-1], 's', color=color, markersize=8,
markeredgecolor='white', markeredgewidth=1)
# 获取时间范围用于图例
df = flight_data[flight_name]
start_time = df.iloc[0]['time']
end_time = df.iloc[-1]['time']
# 创建图例元素
legend_label = f"{flight_name}: {start_time} - {end_time}"
legend_elements.append(
mpatches.Patch(color=color, label=legend_label)
)
# 添加图例
if legend_elements:
ax.legend(handles=legend_elements, loc='lower right',
frameon=True, facecolor='white', edgecolor='gray',
fontsize=9, title='飞行轨迹 (起点→终点)')
# 添加指北针
if show_north_arrow:
self._add_north_arrow(ax, width, height)
# 添加比例尺
if show_scale_bar and geotransform is not None:
self._add_scale_bar(ax, geotransform, width, height)
# 设置标题
ax.set_title('多架次飞行轨迹图', fontsize=16, fontweight='bold', pad=20)
# 隐藏坐标轴刻度
ax.set_xticks([])
ax.set_yticks([])
# 添加网格
ax.grid(True, alpha=0.2, linestyle='--')
plt.tight_layout()
plt.savefig(output_path, dpi=dpi, bbox_inches='tight', pad_inches=0.1, facecolor='white')
plt.close(fig)
def _add_north_arrow(self, ax, width: int, height: int):
"""添加指北针"""
arrow_x = width * 0.92
arrow_y = height * 0.88
arrow = FancyArrowPatch((arrow_x, arrow_y), (arrow_x, arrow_y - height*0.08),
color='black', linewidth=3, arrowstyle='->', mutation_scale=20)
ax.add_patch(arrow)
ax.text(arrow_x, arrow_y - height*0.1, 'N', fontsize=14, fontweight='bold',
color='black', ha='center', va='center',
path_effects=[path_effects.withStroke(linewidth=3, foreground='white')])
def _add_scale_bar(self, ax, geotransform: tuple, width: int, height: int):
"""添加比例尺"""
if geotransform is None:
return
pixel_size_x = abs(geotransform[1])
image_width_meters = width * pixel_size_x
scale_length_m = image_width_meters / 4
scale_options = [1000, 500, 200, 100, 50, 20, 10, 5, 2, 1]
scale_meters = next((s for s in scale_options if s <= scale_length_m), 1)
scale_pixels = int(scale_meters / pixel_size_x)
bar_x = width * 0.08
bar_y = height * 0.92
ax.plot([bar_x, bar_x + scale_pixels], [bar_y, bar_y], color='black', linewidth=4)
ax.plot([bar_x, bar_x], [bar_y, bar_y + 8], color='black', linewidth=2)
ax.plot([bar_x + scale_pixels, bar_x + scale_pixels], [bar_y, bar_y + 8], color='black', linewidth=2)
ax.text(bar_x + scale_pixels/2, bar_y + 15, f'{scale_meters} m',
fontsize=11, ha='center', va='bottom', fontweight='bold',
bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))
def batch_create_maps(self, gps_folder: str,
hyperspectral_folder: str,
output_subdir: str = "flight_paths") -> Dict[str, str]:
"""
批量创建飞行轨迹地图
Args:
gps_folder: 包含多个子文件夹(每个子文件夹包含.gps文件的文件夹
hyperspectral_folder: 包含高光谱影像的文件夹
output_subdir: 输出子目录
Returns:
生成的地图文件路径字典
"""
gps_folder_path = Path(gps_folder)
hs_folder_path = Path(hyperspectral_folder)
if not gps_folder_path.exists():
raise FileNotFoundError(f"GPS文件夹不存在: {gps_folder}")
if not hs_folder_path.exists():
raise FileNotFoundError(f"高光谱文件夹不存在: {hyperspectral_folder}")
output_dir = self.output_dir / output_subdir
output_dir.mkdir(parents=True, exist_ok=True)
map_paths = {}
# 查找所有包含.gps文件的子文件夹
gps_subfolders = [d for d in gps_folder_path.iterdir() if d.is_dir() and list(d.glob("*.gps"))]
# 查找高光谱影像
hs_files = []
for ext in ['*.dat', '*.bsq', '*.tif', '*.tiff']:
hs_files.extend(list(hs_folder_path.glob(ext)))
if not hs_files:
print(f"警告: 在 {hyperspectral_folder} 中未找到高光谱影像")
return map_paths
print(f"找到 {len(gps_subfolders)} 个GPS子文件夹和 {len(hs_files)} 个高光谱影像")
# 简单匹配使用第一个高光谱影像与所有GPS文件夹组合
hs_file = hs_files[0]
for gps_subfolder in gps_subfolders:
try:
output_filename = f"{hs_file.stem}_{gps_subfolder.name}_flight_paths.png"
map_path = self.create_flight_path_map(
gps_folder=str(gps_subfolder),
hyperspectral_path=str(hs_file),
output_filename=output_filename,
dpi=200
)
map_paths[gps_subfolder.name] = map_path
print(f"✓ 生成: {gps_subfolder.name}")
except Exception as e:
print(f"✗ 处理 {gps_subfolder.name} 失败: {e}")
print(f"批量生成完成,共生成 {len(map_paths)} 个飞行轨迹图")
return map_paths
# 测试代码
if __name__ == "__main__":
print("FlightPathVisualizer类已创建")
print("用法示例:")
print(" visualizer = FlightPathVisualizer(output_dir='./flight_maps')")
print(" map_path = visualizer.create_flight_path_map(")
print(" gps_folder='./gps_data',")
print(" hyperspectral_path='./hyperspectral.dat'")
print(" )")
visualizer = FlightPathVisualizer(output_dir=r"E:\code\WQ\pipeline_result\work_dir\9_visualization\flight_maps")
# 生成飞行轨迹图
map_path = visualizer.create_flight_path_map(
gps_folder=r"D:\BaiduNetdiskDownload\20250902\gps", # GPS文件夹路径
hyperspectral_path=r"E:\code\WQ\pipeline_result\work_dir\3_deglint\deglint_goodman.bsq", # 高光谱影像路径
output_filename="flight_paths.png",
line_width=2,
dpi=300
)

2186
src/postprocessing/map.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,184 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from pathlib import Path
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False
def load_and_plot_spectrum_by_parameters():
"""
加载数据并为每个水质参数绘制光谱曲线图
"""
try:
# 数据文件路径
data_file = Path(r"E:\code\WQ\yaobao925\spectral.csv")
if not data_file.exists():
print(f"错误:数据文件不存在 - {data_file}")
return
# 读取数据
print("正在加载数据...")
data = pd.read_csv(data_file)
print(f"数据形状: {data.shape}")
print(f"列名: {list(data.columns[:15])}...") # 显示前15个列名
# 找到光谱数据的起始列(通常是数字列名)
spectrum_start_idx = None
for i, col in enumerate(data.columns):
try:
float(col)
spectrum_start_idx = i
break
except ValueError:
continue
if spectrum_start_idx is None:
print("错误:未找到光谱数据列")
return
print(f"光谱数据从第 {spectrum_start_idx + 1} 列开始")
# 分离水质参数和光谱数据
water_quality_data = data.iloc[:, :spectrum_start_idx]
spectrum_data = data.iloc[:, spectrum_start_idx:]
# 获取波长信息
try:
# 尝试直接转换为浮点数
wavelengths = spectrum_data.columns.astype(float)
except ValueError:
# 如果包含字母,提取数字部分
import re
wavelengths = []
for col in spectrum_data.columns:
# 提取数字部分
numbers = re.findall(r'\d+\.?\d*', str(col))
if numbers:
wavelengths.append(float(numbers[0]))
else:
# 如果没有数字,使用列索引
wavelengths.append(float(len(wavelengths)))
wavelengths = np.array(wavelengths)
print(f"波长范围: {wavelengths.min():.1f} - {wavelengths.max():.1f} nm")
print(f"光谱数据形状: {spectrum_data.shape}")
print(f"水质参数: {list(water_quality_data.columns)}")
# 过滤波长范围到374-1011nm
wavelength_mask = (wavelengths >= 374) & (wavelengths <= 1011)
filtered_wavelengths = wavelengths[wavelength_mask]
filtered_spectrum_data = spectrum_data.iloc[:, wavelength_mask]
print(f"过滤后波长范围: {filtered_wavelengths.min():.1f} - {filtered_wavelengths.max():.1f} nm")
print(f"过滤后光谱数据形状: {filtered_spectrum_data.shape}")
# 创建输出目录
output_dir = Path(r'E:\code\WQ\yaobao925\plot')
output_dir.mkdir(exist_ok=True)
# 为每个水质参数绘制光谱图
for param_idx, parameter_name in enumerate(water_quality_data.columns):
print(f"\n[{param_idx+1}/{len(water_quality_data.columns)}] 处理参数: {parameter_name}")
# 获取当前参数的数据
parameter_values = water_quality_data[parameter_name]
# 过滤掉空值
valid_mask = ~parameter_values.isna()
if valid_mask.sum() == 0:
print(f"参数 '{parameter_name}' 没有有效数据,跳过")
continue
valid_param_values = parameter_values[valid_mask]
valid_spectrum_data = filtered_spectrum_data[valid_mask]
print(f"有效样本数: {len(valid_param_values)}")
# 创建图形和轴
fig, ax = plt.subplots(figsize=(12, 8))
# 归一化参数值到[0,1]范围,用于颜色映射
param_min = valid_param_values.min()
param_max = valid_param_values.max()
if param_max == param_min:
# 如果所有值相同,使用中等颜色
normalized_values = np.full(len(valid_param_values), 0.5)
else:
normalized_values = ((valid_param_values - param_min) / (param_max - param_min)).values
# 创建蓝红颜色映射(蓝色到红色)
colormap = plt.cm.coolwarm # 蓝色(低值)到红色(高值)
# 绘制每条光谱曲线
for i, (idx, spectrum) in enumerate(valid_spectrum_data.iterrows()):
# 处理光谱数据中的空值
spectrum_values = pd.Series(spectrum.values).fillna(0).values
# 根据参数值确定颜色
color = colormap(normalized_values[i])
alpha = 0.6 if len(valid_param_values) > 50 else 0.8 # 样本多时降低透明度
ax.plot(filtered_wavelengths, spectrum_values, color=color, alpha=alpha, linewidth=0.8)
# 设置图形属性
ax.set_xlabel('波长 (nm)', fontsize=12)
ax.set_ylabel('光谱强度', fontsize=12)
ax.set_title(f'{parameter_name} 光谱曲线图\n参数范围: {param_min:.4f} - {param_max:.4f}',
fontsize=14, fontweight='bold')
# 设置坐标轴范围限制在374-1011nm
ax.set_xlim(374, 1011)
# 添加网格
ax.grid(True, alpha=0.3)
# 创建颜色条
sm = plt.cm.ScalarMappable(cmap=colormap,
norm=plt.Normalize(vmin=param_min, vmax=param_max))
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, shrink=0.8)
cbar.set_label(f'{parameter_name} 数值', rotation=270, labelpad=20, fontsize=12)
# 添加统计信息文本框
stats_text = f'样本数: {len(valid_param_values)}\n'
stats_text += f'均值: {valid_param_values.mean():.4f}\n'
stats_text += f'标准差: {valid_param_values.std():.4f}'
ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
verticalalignment='top', bbox=dict(boxstyle='round',
facecolor='wheat', alpha=0.8), fontsize=10)
# 优化布局
plt.tight_layout()
# 保存图片
# 清理参数名称,用于文件名
safe_param_name = "".join(c for c in parameter_name if c.isalnum() or c in ('-', '_', '.')).rstrip()
output_file = output_dir / f"{safe_param_name}_spectrum.png"
plt.savefig(output_file, dpi=300, bbox_inches='tight')
plt.close() # 关闭图形释放内存
print(f"图片已保存到: {output_file}")
print(f"\n{'='*80}")
print(f"所有光谱图绘制完成!")
print(f"输出目录: {output_dir}")
print(f"{'='*80}")
except Exception as e:
print(f"处理过程中出现错误: {str(e)}")
import traceback
traceback.print_exc()
def main():
"""主函数"""
load_and_plot_spectrum_by_parameters()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,636 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
采样点地图生成模块 - 在高光谱假彩色影像上标注采样点
支持功能:
1. 读取高光谱影像并生成假彩色RGB图像
2. 读取CSV文件中的采样点坐标前两列为纬度、经度
3. 在影像上标注红色采样点
4. 添加指北针、图例和比例尺
5. 支持地理坐标系转换
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Optional, Tuple, List, Dict, Union
import warnings
from matplotlib.patches import FancyArrowPatch
import matplotlib.patheffects as path_effects
# 性能优化配置
plt.rcParams['agg.path.chunksize'] = 10000 # 提高矢量渲染性能
plt.rcParams['path.simplify'] = True
plt.rcParams['path.simplify_threshold'] = 0.1
# 导入GDAL用于影像读写
try:
from osgeo import gdal, osr
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
print("警告: GDAL未安装地理坐标转换功能可能无法正常工作")
class SamplingPointMap:
"""采样点地图生成类 - 在高光谱假彩色影像上标注采样点"""
def __init__(self, output_dir: str = "./point_maps", fast_mode: bool = False):
"""
初始化采样点地图生成器
Args:
output_dir: 输出目录,用于保存生成的地图
fast_mode: 是否启用快速模式(降低质量换取速度)
"""
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.fast_mode = fast_mode
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.size'] = 12
# 性能优化设置
if fast_mode:
plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 150
warnings.filterwarnings('ignore', category=UserWarning)
else:
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300
warnings.filterwarnings('ignore')
def create_sampling_point_map(self,
hyperspectral_path: str,
csv_path: str,
output_filename: Optional[str] = None,
rgb_bands: Optional[List[int]] = None,
point_color: str = 'red',
point_size: int = 80,
point_alpha: float = 0.8,
show_north_arrow: bool = True,
show_scale_bar: bool = True,
show_legend: bool = True,
dpi: int = None,
downsample: bool = False) -> str:
"""
创建采样点地图:在高光谱假彩色影像上标注采样点
Args:
hyperspectral_path: 高光谱影像文件路径 (.dat, .bsq, .tif等)
csv_path: 采样点CSV文件路径前两列为纬度、经度
output_filename: 输出文件名如果为None则自动生成
rgb_bands: 用于RGB合成的三个波段索引 [R, G, B]默认为None自动选择
point_color: 采样点颜色
point_size: 采样点大小
point_alpha: 采样点透明度
show_north_arrow: 是否显示指北针
show_scale_bar: 是否显示比例尺
show_legend: 是否显示图例
dpi: 输出图像分辨率None时使用fast_mode设置
downsample: 是否对图像进行下采样以加快速度(大影像推荐启用)
Returns:
生成的地图文件路径
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法处理地理坐标转换")
print(f"正在生成采样点地图...{' (快速模式)' if self.fast_mode else ''}")
# 读取高光谱影像 - 优化仅读取需要的RGB波段
hyperspectral_img, geotransform, projection, width, height, sample_factor = self._read_hyperspectral(
hyperspectral_path, rgb_bands, downsample)
# 读取采样点
sampling_points = self._read_sampling_points(csv_path)
# 生成假彩色图像 - 应用线性拉伸
rgb_image = self._create_false_color_image(hyperspectral_img)
# 将地理坐标转换为像素坐标 - 支持投影系转换和下采样
pixel_coords = self._geo_to_pixel(sampling_points, geotransform, width, height, projection, sample_factor)
# 创建地图
if output_filename is None:
csv_name = Path(csv_path).stem
hs_name = Path(hyperspectral_path).stem
output_filename = f"{hs_name}_{csv_name}_sampling_map.png"
output_path = self.output_dir / output_filename
# 使用更优化的绘图设置
if dpi is None:
dpi = 150 if self.fast_mode else 200
self._create_map_visualization(
rgb_image, pixel_coords, sampling_points,
str(output_path), point_color, point_size, point_alpha,
show_north_arrow, show_scale_bar, show_legend, dpi,
geotransform, width, height, downsample, projection, sample_factor
)
print(f"采样点地图已保存: {output_path}")
return str(output_path)
def _read_hyperspectral(self, hyperspectral_path: str,
rgb_bands: Optional[List[int]] = None,
downsample: bool = False) -> Tuple[np.ndarray, tuple, str, int, int]:
"""优化版:读取高光谱影像 - 仅读取需要的RGB波段"""
dataset = gdal.Open(hyperspectral_path)
if dataset is None:
raise ValueError(f"无法打开高光谱影像: {hyperspectral_path}")
width = dataset.RasterXSize
height = dataset.RasterYSize
band_count = dataset.RasterCount
# 确定要读取的波段 - 优先使用指定波长 (650nm, 550nm, 460nm)
if rgb_bands is None:
if band_count >= 3:
try:
# 使用find_band_number根据波长查找RGB波段
from src.utils.util import find_band_number
rgb_bands = [
find_band_number(650.0, hyperspectral_path), # Red ~650nm
find_band_number(550.0, hyperspectral_path), # Green ~550nm
find_band_number(460.0, hyperspectral_path) # Blue ~460nm
]
print(f" 根据波长选择RGB波段: R={rgb_bands[0]}, G={rgb_bands[1]}, B={rgb_bands[2]}")
except Exception as e:
print(f" 波长查找失败 ({e}),使用默认索引")
# 回退到基于索引的选择
rgb_bands = [min(band_count-1, int(band_count*0.25)),
min(band_count-1, int(band_count*0.15)),
min(band_count-1, int(band_count*0.05))]
else:
rgb_bands = [0, 0, 0]
# 下采样控制 - 用户反馈下采样读取会导致像素值全为0
if downsample and (width > 2000 or height > 2000):
print(f" ⚠ 下采样暂被禁用会导致像素值全0使用原始分辨率: {width}x{height}")
sample_factor = 1
target_width = width
target_height = height
else:
sample_factor = 1
target_width = width
target_height = height
# 只读取需要的RGB波段性能关键优化
rgb_data = []
for band_idx in rgb_bands:
band = dataset.GetRasterBand(band_idx + 1)
# 直接使用完整分辨率读取避免下采样导致像素值为0的问题
band_data = band.ReadAsArray().astype(np.float32)
rgb_data.append(band_data)
# 堆叠为RGB图像 (height, width, 3)
if len(rgb_data) == 3:
image_array = np.stack(rgb_data, axis=2)
else:
# 如果只有1个波段复制为RGB
image_array = np.stack([rgb_data[0]]*3, axis=2)
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
# 释放数据集
dataset = None
# 更新尺寸信息
final_width = target_width if sample_factor > 1 else width
final_height = target_height if sample_factor > 1 else height
print(f" 读取影像: {final_width}x{final_height}x{image_array.shape[2]} (RGB)")
if projection:
proj_type = "投影坐标系" if "PROJCS" in projection else "地理坐标系"
print(f" 影像投影: {proj_type}")
if sample_factor > 1:
print(f" 下采样因子: {sample_factor}")
return image_array, geotransform, projection, final_width, final_height, sample_factor
def _read_sampling_points(self, csv_path: str) -> pd.DataFrame:
"""读取采样点CSV文件"""
if not Path(csv_path).exists():
raise FileNotFoundError(f"CSV文件不存在: {csv_path}")
df = pd.read_csv(csv_path)
# 检查前两列是否为纬度和经度
if len(df.columns) < 2:
raise ValueError("CSV文件至少需要两列纬度、经度")
# 假设前两列是纬度和经度
lat_col = df.columns[0]
lon_col = df.columns[1]
# 重命名列
df = df.rename(columns={lat_col: 'latitude', lon_col: 'longitude'})
# 确保数值类型
df['latitude'] = pd.to_numeric(df['latitude'], errors='coerce')
df['longitude'] = pd.to_numeric(df['longitude'], errors='coerce')
# 删除无效的坐标
df = df.dropna(subset=['latitude', 'longitude'])
print(f"读取到 {len(df)} 个采样点")
return df
def _create_false_color_image(self, image_array: np.ndarray,
rgb_bands: Optional[List[int]] = None) -> np.ndarray:
"""创建假彩色RGB图像 - 应用线性拉伸和Gamma校正"""
# 由于_read_hyperspectral已返回RGB图像这里仅进行最终处理
if image_array.shape[2] != 3:
# 确保是3通道
if len(image_array.shape) == 2 or image_array.shape[2] == 1:
if len(image_array.shape) == 2:
image_array = np.stack([image_array]*3, axis=2)
else:
image_array = np.repeat(image_array, 3, axis=2)
print(f" 处理前图像范围: R[{image_array[:,:,0].min():.3f}-{image_array[:,:,0].max():.3f}], "
f"G[{image_array[:,:,1].min():.3f}-{image_array[:,:,1].max():.3f}], "
f"B[{image_array[:,:,2].min():.3f}-{image_array[:,:,2].max():.3f}]")
# 增强型线性拉伸 - 解决图像太暗的问题
def simple_linear_stretch(data, min_percent=1, max_percent=99):
"""增强对比度的线性拉伸"""
valid_data = data[np.isfinite(data)]
if len(valid_data) == 0:
return np.zeros_like(data, dtype=np.float32)
# 计算百分位数,使用更激进的拉伸 (1%-99%)
p_low = np.percentile(valid_data, min_percent)
p_high = np.percentile(valid_data, max_percent)
if p_high - p_low < 1e-8:
# 如果数据范围太小,使用最小最大值归一化
data_min = valid_data.min()
data_max = valid_data.max()
if data_max > data_min:
stretched = (data - data_min) / (data_max - data_min)
else:
stretched = np.zeros_like(data, dtype=np.float32)
else:
stretched = (data - p_low) / (p_high - p_low)
# 允许轻微过饱和以增加对比度
stretched = np.clip(stretched, 0.0, 1.05)
stretched = np.clip(stretched, 0.0, 1.0) # 最终确保在[0,1]
return stretched
# 对每个通道进行拉伸
r_stretched = simple_linear_stretch(image_array[:, :, 0])
g_stretched = simple_linear_stretch(image_array[:, :, 1])
b_stretched = simple_linear_stretch(image_array[:, :, 2])
# 合成为RGB图像
rgb_image = np.stack([r_stretched, g_stretched, b_stretched], axis=2)
rgb_image = np.nan_to_num(rgb_image, nan=0.0)
# 最终确保范围在[0,1],并轻微增强对比度
rgb_image = np.clip(rgb_image, 0.0, 1.0)
# 可选Gamma校正增加亮度解决太暗问题
gamma = 1 # <1会增加亮度
rgb_image = np.power(rgb_image, gamma)
# 映射到0-255范围uint8这样imshow显示效果更好
rgb_image = (rgb_image * 255).astype(np.uint8)
print(f" 处理后图像范围: [0-255] (Gamma={gamma})")
return rgb_image
def _geo_to_pixel(self, sampling_points: pd.DataFrame,
geotransform: tuple, width: int, height: int,
projection: str = "", sample_factor: int = 1) -> List[Tuple[float, float]]:
"""
使用GDAL进行地理坐标到像素坐标的投影变换 - 支持下采样
原始点位坐标格式: 41.66054612 124.2208338 (WGS84地理坐标: 纬度,经度)
高光谱影像通常使用UTM或其他投影坐标系
当图像下采样时sample_factor > 1需要相应缩放坐标
"""
if geotransform is None or len(sampling_points) == 0:
# 如果没有地理变换信息,使用图像中心
return [(width/2, height/2) for _ in range(len(sampling_points))]
pixel_coords = []
gt = geotransform
# 检查是否需要投影转换
needs_transform = False
if projection and ("PROJCS" in projection or "GEOGCS" in projection):
needs_transform = True
print(f" 检测到影像投影: {projection[:80]}...")
# 创建坐标转换对象WGS84 -> 影像投影)
transform = None
if needs_transform and GDAL_AVAILABLE:
try:
# 源坐标系: WGS84 (EPSG:4326)
src_srs = osr.SpatialReference()
src_srs.ImportFromEPSG(4326) # WGS84
# 目标坐标系: 影像的投影
dst_srs = osr.SpatialReference()
dst_srs.ImportFromWkt(projection)
# 创建坐标转换
transform = osr.CoordinateTransformation(src_srs, dst_srs)
print(" ✓ 已创建WGS84到影像投影的坐标转换")
except Exception as e:
print(f" ⚠ 坐标转换创建失败: {e},使用简化变换")
transform = None
for _, row in sampling_points.iterrows():
lon = float(row['longitude']) # 经度 (WGS84)
lat = float(row['latitude']) # 纬度 (WGS84)
if transform is not None:
# 使用GDAL进行投影转换: (经度, 纬度) -> (投影X, 投影Y)
try:
proj_x, proj_y, _ = transform.TransformPoint(lat, lon)
# 再转换为像素坐标
x = (proj_x - gt[0]) / gt[1]
y = (proj_y - gt[3]) / gt[5]
except Exception as e:
# 转换失败时回退到直接计算
x = (lon - gt[0]) / gt[1]
y = (lat - gt[3]) / gt[5]
else:
# 直接使用仿射变换(坐标系一致的情况)
x = (lon - gt[0]) / gt[1]
y = (lat - gt[3]) / gt[5]
# 如果图像进行了下采样,需要相应缩放坐标
if sample_factor > 1:
x = x / sample_factor
y = y / sample_factor
# 限制在图像范围内(使用下采样后的尺寸)
x = max(0, min(x, width - 1))
y = max(0, min(y, height - 1))
pixel_coords.append((x, y))
if transform is not None:
print(f" ✓ 使用GDAL投影变换处理 {len(pixel_coords)} 个采样点")
else:
print(f" 使用直接仿射变换处理 {len(pixel_coords)} 个采样点")
return pixel_coords
def _create_map_visualization(self, rgb_image: np.ndarray,
pixel_coords: List[Tuple[float, float]],
sampling_points: pd.DataFrame,
output_path: str,
point_color: str,
point_size: int,
point_alpha: float,
show_north_arrow: bool,
show_scale_bar: bool,
show_legend: bool,
dpi: int,
geotransform: tuple,
width: int,
height: int,
downsample: bool = False,
projection: str = "",
sample_factor: int = 1):
"""创建地图可视化 - 优化版"""
# 使用更小的figure尺寸加快渲染
figsize = (10, 8) if self.fast_mode or downsample else (12, 10)
fig, ax = plt.subplots(figsize=figsize, dpi=100 if self.fast_mode else 150)
# 显示假彩色图像 - 现在已经是0-255的uint8格式
print(f" 最终图像数据范围: [{rgb_image.min()}, {rgb_image.max()}] (uint8)")
ax.imshow(rgb_image, interpolation='nearest' if self.fast_mode else 'bilinear')
# 绘制采样点 - 优化使用scatter代替循环plot
if pixel_coords:
x_coords = [p[0] for p in pixel_coords]
y_coords = [p[1] for p in pixel_coords]
ax.scatter(x_coords, y_coords, c=point_color, s=point_size,
alpha=point_alpha, edgecolors='white', linewidth=1.5)
# 添加指北针
if show_north_arrow:
self._add_north_arrow(ax, width, height, position='bottom-left', direction='down')
# 添加比例尺
if show_scale_bar and geotransform is not None:
self._add_scale_bar(ax, geotransform, width, height)
# 添加图例
if show_legend:
legend_text = f'采样点 (n={len(sampling_points)})'
ax.plot([], [], 'o', color=point_color, markersize=8, label=legend_text)
ax.legend(loc='lower right', frameon=True, facecolor='white', edgecolor='gray')
# 设置标题和标签
ax.set_title('高光谱影像采样点分布图', fontsize=16, fontweight='bold', pad=20)
# 隐藏坐标轴刻度
ax.set_xticks([])
ax.set_yticks([])
# 添加网格
ax.grid(True, alpha=0.2, linestyle='--')
plt.tight_layout()
# 保存参数 - 避免传递不兼容的参数
save_kwargs = {
'dpi': dpi,
'bbox_inches': 'tight',
'pad_inches': 0.05,
'facecolor': 'white'
}
# 仅添加matplotlib支持的参数
if self.fast_mode:
save_kwargs['dpi'] = min(dpi, 180) # 快速模式降低DPI
plt.savefig(output_path, **save_kwargs)
plt.close(fig)
def _add_north_arrow(self, ax, width: int, height: int, position='top-right', direction='down',
size=0.08, color='white', n_color='white', outline_color='black'):
"""
添加指北针,可配置位置、方向、大小、颜色。
参数:
ax: matplotlib Axes对象
width, height: 图像宽高(用于相对定位)
position: 'top-left', 'top-right', 'bottom-left', 'bottom-right'
direction: 'up', 'down', 'left', 'right' 箭头指向
size: 箭头长度相对于高度的比例0.05~0.12
color: 箭头颜色
n_color: 'N' 文字颜色
outline_color: 文字描边颜色
"""
# 位置映射(偏移系数)
pos_map = {
'top-left': (0.08, 0.88),
'top-right': (0.92, 0.88),
'bottom-left': (0.08, 0.12),
'bottom-right': (0.92, 0.12),
}
arrow_x_ratio, arrow_y_ratio = pos_map.get(position, (0.92, 0.88))
arrow_x = width * arrow_x_ratio
arrow_y = height * arrow_y_ratio
# 方向映射(箭头终点偏移)
direction_map = {
'up': (0, +size),
'down': (0, -size),
'left': (-size, 0),
'right': (+size, 0),
}
dx, dy = direction_map.get(direction, (0, -size))
end_x = arrow_x + dx * width # 注意dx是比例乘以宽度/高度保持比例一致
end_y = arrow_y + dy * height
# 箭头绘制
arrow = FancyArrowPatch((arrow_x, arrow_y), (end_x, end_y),
color=color, linewidth=3,
arrowstyle='->', mutation_scale=20)
ax.add_patch(arrow)
# N 文字位置:在箭头尾部或头部?通常放在箭头指向的反方向末端
# 这里放在箭头尾部向外偏移一点(便于阅读)
# 偏移系数根据方向决定
offset_scale = 0.02 # 偏移量比例
if direction == 'up':
text_x = arrow_x
text_y = arrow_y - height * offset_scale # 放在箭头下方
elif direction == 'down':
text_x = arrow_x
text_y = arrow_y + height * offset_scale # 放在箭头上方
elif direction == 'left':
text_x = arrow_x + width * offset_scale
text_y = arrow_y
else: # right
text_x = arrow_x - width * offset_scale
text_y = arrow_y
ax.text(text_x, text_y, 'N', fontsize=14, fontweight='bold',
color=n_color, ha='center', va='center',
path_effects=[path_effects.withStroke(linewidth=3, foreground=outline_color)])
def _add_scale_bar(self, ax, geotransform: tuple, width: int, height: int):
"""添加比例尺"""
if geotransform is None:
return
# 计算图像实际宽度(米)
pixel_size_x = abs(geotransform[1])
image_width_meters = width * pixel_size_x
# 选择合适的比例尺长度图像宽度的1/4
scale_length_m = image_width_meters / 4
scale_length_pixels = width / 4
# 找到合适的刻度
scale_options = [1000, 500, 200, 100, 50, 20, 10, 5, 2, 1]
scale_meters = next((s for s in scale_options if s <= scale_length_m), 1)
scale_pixels = int(scale_meters / pixel_size_x)
# 在左下角添加比例尺
bar_x = width * 0.08
bar_y = height * 0.92
# 绘制比例尺线
ax.plot([bar_x, bar_x + scale_pixels], [bar_y, bar_y], color='white', linewidth=4)
# 添加刻度线
ax.plot([bar_x, bar_x], [bar_y, bar_y + 8], color='white', linewidth=2)
ax.plot([bar_x + scale_pixels, bar_x + scale_pixels], [bar_y, bar_y + 8], color='white', linewidth=2)
# 添加文字
ax.text(bar_x + scale_pixels/2, bar_y , f'{scale_meters} m',
fontsize=11, ha='center', va='bottom', fontweight='bold',
bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))
def batch_create_maps(self, hyperspectral_path: str,
csv_folder: str,
output_subdir: str = "sampling_maps",
fast_mode: bool = True) -> Dict[str, str]:
"""
批量创建采样点地图
Args:
hyperspectral_path: 高光谱影像路径
csv_folder: 包含多个CSV文件的文件夹
output_subdir: 输出子目录
Returns:
生成的地图文件路径字典
"""
csv_folder_path = Path(csv_folder)
if not csv_folder_path.exists():
raise FileNotFoundError(f"CSV文件夹不存在: {csv_folder}")
# 创建输出目录
output_dir = self.output_dir / output_subdir
output_dir.mkdir(parents=True, exist_ok=True)
map_paths = {}
# 查找所有CSV文件
csv_files = list(csv_folder_path.glob("*.csv"))
print(f"找到 {len(csv_files)} 个CSV文件开始批量生成采样点地图... (快速模式: {fast_mode})")
for csv_file in csv_files:
try:
output_filename = f"{Path(hyperspectral_path).stem}_{csv_file.stem}_sampling_map.png"
map_path = self.create_sampling_point_map(
hyperspectral_path=hyperspectral_path,
csv_path=str(csv_file),
output_filename=output_filename,
downsample=True, # 批量模式默认下采样
dpi=120 if fast_mode else 200
)
map_paths[csv_file.name] = map_path
print(f"✓ 生成: {csv_file.name}")
except Exception as e:
print(f"✗ 处理 {csv_file.name} 失败: {e}")
print(f"批量生成完成,共生成 {len(map_paths)} 个采样点地图")
return map_paths
# 测试代码
if __name__ == "__main__":
# 示例用法
map_generator = SamplingPointMap(output_dir="./point_maps")
# 测试代码已禁用,避免直接运行时出错
map_generator_fast = SamplingPointMap(output_dir="./point_maps", fast_mode=True)
map_path = map_generator_fast.create_sampling_point_map(
hyperspectral_path=r"D:\BaiduNetdiskDownload\yaobao\result3.bsq",
csv_path=r"E:\code\WQ\pipeline_result\work_dir\4_processed_data\processed_data.csv",
downsample=True,
dpi=150
)
print("测试代码已注释请通过GUI或手动调用使用。")
print("SamplingPointMap类已创建可以用于生成带采样点的地图。")
print("性能优化功能:")
print(" - fast_mode=True: 快速模式 (推荐用于预览)")
print(" - downsample=True: 对大影像下采样 (推荐用于>2000x2000影像)")
print(" - 使用: SamplingPointMap(fast_mode=True).create_sampling_point_map(...)")

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff