545 lines
21 KiB
Python
545 lines
21 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
飞行轨迹可视化模块 - 在多架次GPS数据上绘制飞行轨迹
|
||
|
||
支持功能:
|
||
1. 读取多个.gps文件(每个文件代表一个架次)
|
||
2. 在高光谱假彩色影像上绘制飞行轨迹
|
||
3. 不同架次使用不同颜色
|
||
4. 图例显示架次起始到结束的时间段
|
||
5. 添加指北针、比例尺
|
||
6. 保存为PNG图像
|
||
"""
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
from pathlib import Path
|
||
from typing import Optional, Tuple, List, Dict, Union
|
||
import warnings
|
||
from matplotlib.patches import FancyArrowPatch
|
||
import matplotlib.patheffects as path_effects
|
||
from datetime import datetime
|
||
from matplotlib.colors import ListedColormap
|
||
import matplotlib.patches as mpatches
|
||
|
||
# 性能优化配置
|
||
plt.rcParams['agg.path.chunksize'] = 10000
|
||
plt.rcParams['path.simplify'] = True
|
||
plt.rcParams['path.simplify_threshold'] = 0.1
|
||
|
||
# 导入GDAL用于影像读写
|
||
try:
|
||
from osgeo import gdal, osr
|
||
GDAL_AVAILABLE = True
|
||
except ImportError:
|
||
GDAL_AVAILABLE = False
|
||
print("警告: GDAL未安装,地理坐标转换功能可能无法正常工作")
|
||
|
||
|
||
class FlightPathVisualizer:
|
||
"""飞行轨迹可视化类 - 在高光谱假彩色影像上绘制多架次飞行轨迹"""
|
||
|
||
# 预定义颜色方案(不同架次使用不同颜色)
|
||
FLIGHT_COLORS = [
|
||
'#FF0000', # 红色
|
||
'#00FF00', # 绿色
|
||
'#0000FF', # 蓝色
|
||
'#FF00FF', # 紫色
|
||
'#00FFFF', # 青色
|
||
'#FFFF00', # 黄色
|
||
'#FF8000', # 橙色
|
||
'#8000FF', # 紫罗兰
|
||
'#0080FF', # 天蓝
|
||
'#FF0080', # 粉红
|
||
]
|
||
|
||
def __init__(self, output_dir: str = "./flight_paths"):
|
||
"""
|
||
初始化飞行轨迹可视化器
|
||
|
||
Args:
|
||
output_dir: 输出目录,用于保存生成的轨迹图
|
||
"""
|
||
self.output_dir = Path(output_dir)
|
||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
# 设置中文字体
|
||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||
plt.rcParams['axes.unicode_minus'] = False
|
||
plt.rcParams['font.size'] = 12
|
||
|
||
warnings.filterwarnings('ignore')
|
||
|
||
def create_flight_path_map(self,
|
||
gps_folder: str,
|
||
hyperspectral_path: str,
|
||
output_filename: Optional[str] = None,
|
||
rgb_bands: Optional[List[int]] = None,
|
||
line_width: int = 2,
|
||
show_north_arrow: bool = True,
|
||
show_scale_bar: bool = True,
|
||
dpi: int = 300) -> str:
|
||
"""
|
||
创建飞行轨迹地图:在高光谱假彩色影像上绘制多架次飞行轨迹
|
||
|
||
Args:
|
||
gps_folder: GPS文件夹路径,包含多个.gps文件
|
||
hyperspectral_path: 高光谱影像文件路径 (.dat, .bsq, .tif等)
|
||
output_filename: 输出文件名(如果为None则自动生成)
|
||
rgb_bands: 用于RGB合成的三个波段索引 [R, G, B],默认为None自动选择(650,550,460nm)
|
||
line_width: 轨迹线宽
|
||
show_north_arrow: 是否显示指北针
|
||
show_scale_bar: 是否显示比例尺
|
||
dpi: 输出图像分辨率
|
||
|
||
Returns:
|
||
生成的地图文件路径
|
||
"""
|
||
if not GDAL_AVAILABLE:
|
||
raise ImportError("GDAL未安装,无法处理地理坐标转换")
|
||
|
||
print(f"正在生成飞行轨迹地图...")
|
||
|
||
# 读取高光谱影像
|
||
hyperspectral_img, geotransform, projection, width, height = self._read_hyperspectral(
|
||
hyperspectral_path, rgb_bands)
|
||
|
||
# 读取所有GPS文件
|
||
flight_data = self._read_gps_files(gps_folder)
|
||
|
||
if not flight_data:
|
||
raise ValueError(f"未在 {gps_folder} 中找到有效的.gps文件")
|
||
|
||
# 将GPS坐标转换为像素坐标
|
||
flight_pixels = self._convert_flights_to_pixels(
|
||
flight_data, geotransform, width, height, projection)
|
||
|
||
# 创建地图
|
||
if output_filename is None:
|
||
folder_name = Path(gps_folder).name
|
||
hs_name = Path(hyperspectral_path).stem
|
||
output_filename = f"{hs_name}_{folder_name}_flight_paths.png"
|
||
|
||
output_path = self.output_dir / output_filename
|
||
|
||
self._create_map_visualization(
|
||
hyperspectral_img, flight_pixels, flight_data,
|
||
str(output_path), line_width,
|
||
show_north_arrow, show_scale_bar, dpi,
|
||
geotransform, width, height
|
||
)
|
||
|
||
print(f"飞行轨迹地图已保存: {output_path}")
|
||
return str(output_path)
|
||
|
||
def _read_hyperspectral(self, hyperspectral_path: str,
|
||
rgb_bands: Optional[List[int]] = None) -> Tuple[np.ndarray, tuple, str, int, int]:
|
||
"""读取高光谱影像 - 使用650/550/460nm波长"""
|
||
dataset = gdal.Open(hyperspectral_path)
|
||
if dataset is None:
|
||
raise ValueError(f"无法打开高光谱影像: {hyperspectral_path}")
|
||
|
||
width = dataset.RasterXSize
|
||
height = dataset.RasterYSize
|
||
band_count = dataset.RasterCount
|
||
|
||
# 确定要读取的波段 - 使用指定波长 650, 550, 460nm
|
||
if rgb_bands is None:
|
||
if band_count >= 3:
|
||
try:
|
||
from src.utils.util import find_band_number
|
||
rgb_bands = [
|
||
find_band_number(650.0, hyperspectral_path), # Red ~650nm
|
||
find_band_number(550.0, hyperspectral_path), # Green ~550nm
|
||
find_band_number(460.0, hyperspectral_path) # Blue ~460nm
|
||
]
|
||
print(f" 根据波长选择RGB波段: R={rgb_bands[0]}, G={rgb_bands[1]}, B={rgb_bands[2]}")
|
||
except Exception as e:
|
||
print(f" 波长查找失败 ({e}),使用默认索引")
|
||
rgb_bands = [min(band_count-1, int(band_count*0.25)),
|
||
min(band_count-1, int(band_count*0.15)),
|
||
min(band_count-1, int(band_count*0.05))]
|
||
else:
|
||
rgb_bands = [0, 0, 0]
|
||
|
||
# 读取RGB波段
|
||
rgb_data = []
|
||
for band_idx in rgb_bands:
|
||
band = dataset.GetRasterBand(band_idx + 1)
|
||
band_data = band.ReadAsArray().astype(np.float32)
|
||
rgb_data.append(band_data)
|
||
|
||
# 堆叠为RGB图像
|
||
if len(rgb_data) == 3:
|
||
image_array = np.stack(rgb_data, axis=2)
|
||
else:
|
||
image_array = np.stack([rgb_data[0]]*3, axis=2)
|
||
|
||
geotransform = dataset.GetGeoTransform()
|
||
projection = dataset.GetProjection()
|
||
dataset = None
|
||
|
||
print(f" 读取影像: {width}x{height}x{image_array.shape[2]} (RGB)")
|
||
if projection:
|
||
proj_type = "投影坐标系" if "PROJCS" in projection else "地理坐标系"
|
||
print(f" 影像投影: {proj_type}")
|
||
|
||
return image_array, geotransform, projection, width, height
|
||
|
||
def _read_gps_files(self, gps_folder: str) -> Dict[str, pd.DataFrame]:
|
||
"""
|
||
读取文件夹中的所有.gps文件
|
||
|
||
文件格式: 日期、时间、三个姿态角、经度、纬度、高程
|
||
列: [date, time, pitch, roll, yaw, longitude, latitude, altitude]
|
||
"""
|
||
gps_folder_path = Path(gps_folder)
|
||
if not gps_folder_path.exists():
|
||
raise FileNotFoundError(f"GPS文件夹不存在: {gps_folder}")
|
||
|
||
gps_files = list(gps_folder_path.glob("*.gps"))
|
||
if not gps_files:
|
||
print(f"警告: 在 {gps_folder} 中未找到.gps文件")
|
||
return {}
|
||
|
||
print(f"找到 {len(gps_files)} 个GPS文件")
|
||
|
||
flight_data = {}
|
||
for gps_file in gps_files:
|
||
try:
|
||
# 读取GPS文件(制表符分隔)
|
||
df = pd.read_csv(gps_file, sep='\t', header=None,
|
||
names=['date', 'time', 'pitch', 'roll', 'yaw',
|
||
'longitude', 'latitude', 'altitude'])
|
||
|
||
# 确保数值类型
|
||
df['longitude'] = pd.to_numeric(df['longitude'], errors='coerce')
|
||
df['latitude'] = pd.to_numeric(df['latitude'], errors='coerce')
|
||
df['altitude'] = pd.to_numeric(df['altitude'], errors='coerce')
|
||
|
||
# 删除无效坐标
|
||
df = df.dropna(subset=['longitude', 'latitude'])
|
||
|
||
if len(df) > 0:
|
||
flight_data[gps_file.stem] = df
|
||
print(f" ✓ 读取 {gps_file.name}: {len(df)} 个轨迹点")
|
||
# 显示时间范围
|
||
start_time = df.iloc[0]['time']
|
||
end_time = df.iloc[-1]['time']
|
||
print(f" 时间范围: {start_time} - {end_time}")
|
||
else:
|
||
print(f" ✗ {gps_file.name}: 无有效数据")
|
||
|
||
except Exception as e:
|
||
print(f" ✗ 读取 {gps_file.name} 失败: {e}")
|
||
|
||
return flight_data
|
||
|
||
def _convert_flights_to_pixels(self, flight_data: Dict[str, pd.DataFrame],
|
||
geotransform: tuple, width: int, height: int,
|
||
projection: str = "") -> Dict[str, List[Tuple[float, float]]]:
|
||
"""将所有飞行轨迹的地理坐标转换为像素坐标"""
|
||
if geotransform is None:
|
||
print("警告: 无地理变换信息,无法转换坐标")
|
||
return {}
|
||
|
||
gt = geotransform
|
||
flight_pixels = {}
|
||
|
||
# 检查是否需要投影转换
|
||
needs_transform = projection and ("PROJCS" in projection or "GEOGCS" in projection)
|
||
transform = None
|
||
|
||
if needs_transform and GDAL_AVAILABLE:
|
||
try:
|
||
src_srs = osr.SpatialReference()
|
||
src_srs.ImportFromEPSG(4326) # WGS84
|
||
dst_srs = osr.SpatialReference()
|
||
dst_srs.ImportFromWkt(projection)
|
||
transform = osr.CoordinateTransformation(src_srs, dst_srs)
|
||
print(" ✓ 已创建WGS84到影像投影的坐标转换")
|
||
except Exception as e:
|
||
print(f" ⚠ 坐标转换创建失败: {e}")
|
||
|
||
for flight_name, df in flight_data.items():
|
||
pixel_coords = []
|
||
|
||
for _, row in df.iterrows():
|
||
lon = float(row['longitude'])
|
||
lat = float(row['latitude'])
|
||
|
||
if transform is not None:
|
||
try:
|
||
proj_x, proj_y, _ = transform.TransformPoint(lat, lon)
|
||
x = (proj_x - gt[0]) / gt[1]
|
||
y = (proj_y - gt[3]) / gt[5]
|
||
except:
|
||
x = (lon - gt[0]) / gt[1]
|
||
y = (lat - gt[3]) / gt[5]
|
||
else:
|
||
x = (lon - gt[0]) / gt[1]
|
||
y = (lat - gt[3]) / gt[5]
|
||
|
||
# 限制在图像范围内
|
||
x = max(0, min(x, width - 1))
|
||
y = max(0, min(y, height - 1))
|
||
|
||
pixel_coords.append((x, y))
|
||
|
||
flight_pixels[flight_name] = pixel_coords
|
||
|
||
print(f" 已转换 {len(flight_pixels)} 个架次的坐标")
|
||
return flight_pixels
|
||
|
||
def _create_false_color_image(self, image_array: np.ndarray) -> np.ndarray:
|
||
"""创建假彩色RGB图像 - 应用线性拉伸"""
|
||
if image_array.shape[2] != 3:
|
||
if len(image_array.shape) == 2 or image_array.shape[2] == 1:
|
||
if len(image_array.shape) == 2:
|
||
image_array = np.stack([image_array]*3, axis=2)
|
||
else:
|
||
image_array = np.repeat(image_array, 3, axis=2)
|
||
|
||
def simple_linear_stretch(data, min_percent=1, max_percent=99):
|
||
valid_data = data[np.isfinite(data)]
|
||
if len(valid_data) == 0:
|
||
return np.zeros_like(data, dtype=np.float32)
|
||
|
||
p_low = np.percentile(valid_data, min_percent)
|
||
p_high = np.percentile(valid_data, max_percent)
|
||
|
||
if p_high - p_low < 1e-8:
|
||
data_min = valid_data.min()
|
||
data_max = valid_data.max()
|
||
if data_max > data_min:
|
||
stretched = (data - data_min) / (data_max - data_min)
|
||
else:
|
||
stretched = np.zeros_like(data, dtype=np.float32)
|
||
else:
|
||
stretched = (data - p_low) / (p_high - p_low)
|
||
|
||
stretched = np.clip(stretched, 0.0, 1.0)
|
||
return stretched
|
||
|
||
r_stretched = simple_linear_stretch(image_array[:, :, 0])
|
||
g_stretched = simple_linear_stretch(image_array[:, :, 1])
|
||
b_stretched = simple_linear_stretch(image_array[:, :, 2])
|
||
|
||
rgb_image = np.stack([r_stretched, g_stretched, b_stretched], axis=2)
|
||
rgb_image = np.nan_to_num(rgb_image, nan=0.0)
|
||
rgb_image = np.clip(rgb_image, 0.0, 1.0)
|
||
|
||
# Gamma校正增加亮度
|
||
gamma = 0.85
|
||
rgb_image = np.power(rgb_image, gamma)
|
||
|
||
# 映射到0-255
|
||
rgb_image = (rgb_image * 255).astype(np.uint8)
|
||
|
||
return rgb_image
|
||
|
||
def _create_map_visualization(self, image_array: np.ndarray,
|
||
flight_pixels: Dict[str, List[Tuple[float, float]]],
|
||
flight_data: Dict[str, pd.DataFrame],
|
||
output_path: str,
|
||
line_width: int,
|
||
show_north_arrow: bool,
|
||
show_scale_bar: bool,
|
||
dpi: int,
|
||
geotransform: tuple,
|
||
width: int,
|
||
height: int):
|
||
"""创建地图可视化"""
|
||
figsize = (14, 10)
|
||
fig, ax = plt.subplots(figsize=figsize, dpi=150)
|
||
|
||
# 处理背景图像
|
||
rgb_image = self._create_false_color_image(image_array)
|
||
ax.imshow(rgb_image, interpolation='bilinear')
|
||
|
||
# 绘制飞行轨迹 - 不同架次不同颜色
|
||
legend_elements = []
|
||
|
||
for idx, (flight_name, pixel_coords) in enumerate(flight_pixels.items()):
|
||
if len(pixel_coords) < 2:
|
||
continue
|
||
|
||
# 选择颜色
|
||
color = self.FLIGHT_COLORS[idx % len(self.FLIGHT_COLORS)]
|
||
|
||
# 提取x,y坐标
|
||
x_coords = [p[0] for p in pixel_coords]
|
||
y_coords = [p[1] for p in pixel_coords]
|
||
|
||
# 绘制轨迹线
|
||
ax.plot(x_coords, y_coords, color=color, linewidth=line_width,
|
||
alpha=0.8, solid_capstyle='round')
|
||
|
||
# 标记起点和终点
|
||
ax.plot(x_coords[0], y_coords[0], 'o', color=color, markersize=8,
|
||
markeredgecolor='white', markeredgewidth=1)
|
||
ax.plot(x_coords[-1], y_coords[-1], 's', color=color, markersize=8,
|
||
markeredgecolor='white', markeredgewidth=1)
|
||
|
||
# 获取时间范围用于图例
|
||
df = flight_data[flight_name]
|
||
start_time = df.iloc[0]['time']
|
||
end_time = df.iloc[-1]['time']
|
||
|
||
# 创建图例元素
|
||
legend_label = f"{flight_name}: {start_time} - {end_time}"
|
||
legend_elements.append(
|
||
mpatches.Patch(color=color, label=legend_label)
|
||
)
|
||
|
||
# 添加图例
|
||
if legend_elements:
|
||
ax.legend(handles=legend_elements, loc='lower right',
|
||
frameon=True, facecolor='white', edgecolor='gray',
|
||
fontsize=9, title='飞行轨迹 (起点→终点)')
|
||
|
||
# 添加指北针
|
||
if show_north_arrow:
|
||
self._add_north_arrow(ax, width, height)
|
||
|
||
# 添加比例尺
|
||
if show_scale_bar and geotransform is not None:
|
||
self._add_scale_bar(ax, geotransform, width, height)
|
||
|
||
# 设置标题
|
||
ax.set_title('多架次飞行轨迹图', fontsize=16, fontweight='bold', pad=20)
|
||
|
||
|
||
# 隐藏坐标轴刻度
|
||
ax.set_xticks([])
|
||
ax.set_yticks([])
|
||
|
||
# 添加网格
|
||
ax.grid(True, alpha=0.2, linestyle='--')
|
||
|
||
plt.tight_layout()
|
||
plt.savefig(output_path, dpi=dpi, bbox_inches='tight', pad_inches=0.1, facecolor='white')
|
||
plt.close(fig)
|
||
|
||
def _add_north_arrow(self, ax, width: int, height: int):
|
||
"""添加指北针"""
|
||
arrow_x = width * 0.92
|
||
arrow_y = height * 0.88
|
||
|
||
arrow = FancyArrowPatch((arrow_x, arrow_y), (arrow_x, arrow_y - height*0.08),
|
||
color='black', linewidth=3, arrowstyle='->', mutation_scale=20)
|
||
ax.add_patch(arrow)
|
||
|
||
ax.text(arrow_x, arrow_y - height*0.1, 'N', fontsize=14, fontweight='bold',
|
||
color='black', ha='center', va='center',
|
||
path_effects=[path_effects.withStroke(linewidth=3, foreground='white')])
|
||
|
||
def _add_scale_bar(self, ax, geotransform: tuple, width: int, height: int):
|
||
"""添加比例尺"""
|
||
if geotransform is None:
|
||
return
|
||
|
||
pixel_size_x = abs(geotransform[1])
|
||
image_width_meters = width * pixel_size_x
|
||
scale_length_m = image_width_meters / 4
|
||
|
||
scale_options = [1000, 500, 200, 100, 50, 20, 10, 5, 2, 1]
|
||
scale_meters = next((s for s in scale_options if s <= scale_length_m), 1)
|
||
scale_pixels = int(scale_meters / pixel_size_x)
|
||
|
||
bar_x = width * 0.08
|
||
bar_y = height * 0.92
|
||
|
||
ax.plot([bar_x, bar_x + scale_pixels], [bar_y, bar_y], color='black', linewidth=4)
|
||
ax.plot([bar_x, bar_x], [bar_y, bar_y + 8], color='black', linewidth=2)
|
||
ax.plot([bar_x + scale_pixels, bar_x + scale_pixels], [bar_y, bar_y + 8], color='black', linewidth=2)
|
||
|
||
ax.text(bar_x + scale_pixels/2, bar_y + 15, f'{scale_meters} m',
|
||
fontsize=11, ha='center', va='bottom', fontweight='bold',
|
||
bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))
|
||
|
||
def batch_create_maps(self, gps_folder: str,
|
||
hyperspectral_folder: str,
|
||
output_subdir: str = "flight_paths") -> Dict[str, str]:
|
||
"""
|
||
批量创建飞行轨迹地图
|
||
|
||
Args:
|
||
gps_folder: 包含多个子文件夹(每个子文件夹包含.gps文件)的文件夹
|
||
hyperspectral_folder: 包含高光谱影像的文件夹
|
||
output_subdir: 输出子目录
|
||
|
||
Returns:
|
||
生成的地图文件路径字典
|
||
"""
|
||
gps_folder_path = Path(gps_folder)
|
||
hs_folder_path = Path(hyperspectral_folder)
|
||
|
||
if not gps_folder_path.exists():
|
||
raise FileNotFoundError(f"GPS文件夹不存在: {gps_folder}")
|
||
if not hs_folder_path.exists():
|
||
raise FileNotFoundError(f"高光谱文件夹不存在: {hyperspectral_folder}")
|
||
|
||
output_dir = self.output_dir / output_subdir
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
map_paths = {}
|
||
|
||
# 查找所有包含.gps文件的子文件夹
|
||
gps_subfolders = [d for d in gps_folder_path.iterdir() if d.is_dir() and list(d.glob("*.gps"))]
|
||
|
||
# 查找高光谱影像
|
||
hs_files = []
|
||
for ext in ['*.dat', '*.bsq', '*.tif', '*.tiff']:
|
||
hs_files.extend(list(hs_folder_path.glob(ext)))
|
||
|
||
if not hs_files:
|
||
print(f"警告: 在 {hyperspectral_folder} 中未找到高光谱影像")
|
||
return map_paths
|
||
|
||
print(f"找到 {len(gps_subfolders)} 个GPS子文件夹和 {len(hs_files)} 个高光谱影像")
|
||
|
||
# 简单匹配:使用第一个高光谱影像与所有GPS文件夹组合
|
||
hs_file = hs_files[0]
|
||
|
||
for gps_subfolder in gps_subfolders:
|
||
try:
|
||
output_filename = f"{hs_file.stem}_{gps_subfolder.name}_flight_paths.png"
|
||
map_path = self.create_flight_path_map(
|
||
gps_folder=str(gps_subfolder),
|
||
hyperspectral_path=str(hs_file),
|
||
output_filename=output_filename,
|
||
dpi=200
|
||
)
|
||
map_paths[gps_subfolder.name] = map_path
|
||
print(f"✓ 生成: {gps_subfolder.name}")
|
||
|
||
except Exception as e:
|
||
print(f"✗ 处理 {gps_subfolder.name} 失败: {e}")
|
||
|
||
print(f"批量生成完成,共生成 {len(map_paths)} 个飞行轨迹图")
|
||
return map_paths
|
||
|
||
|
||
# 测试代码
|
||
if __name__ == "__main__":
|
||
print("FlightPathVisualizer类已创建")
|
||
print("用法示例:")
|
||
print(" visualizer = FlightPathVisualizer(output_dir='./flight_maps')")
|
||
print(" map_path = visualizer.create_flight_path_map(")
|
||
print(" gps_folder='./gps_data',")
|
||
print(" hyperspectral_path='./hyperspectral.dat'")
|
||
print(" )")
|
||
|
||
|
||
visualizer = FlightPathVisualizer(output_dir=r"E:\code\WQ\pipeline_result\work_dir\14_visualization\flight_maps")
|
||
# 生成飞行轨迹图
|
||
map_path = visualizer.create_flight_path_map(
|
||
gps_folder=r"D:\BaiduNetdiskDownload\20250902\gps", # GPS文件夹路径
|
||
hyperspectral_path=r"E:\code\WQ\pipeline_result\work_dir\3_deglint\deglint_goodman.bsq", # 高光谱影像路径
|
||
output_filename="flight_paths.png",
|
||
line_width=2,
|
||
dpi=300
|
||
) |