Initial commit of WQ_GUI
This commit is contained in:
1
src/postprocessing/__init__.py
Normal file
1
src/postprocessing/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
327
src/postprocessing/box_plot.py
Normal file
327
src/postprocessing/box_plot.py
Normal file
@ -0,0 +1,327 @@
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
import os
|
||||
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
def plot_individual_boxplots(csv_file_path, save_dir="boxplots"):
|
||||
"""
|
||||
为每个数据列单独绘制箱型图并保存
|
||||
|
||||
参数:
|
||||
csv_file_path: CSV文件路径
|
||||
save_dir: 保存图片的目录
|
||||
"""
|
||||
try:
|
||||
# 读取CSV文件
|
||||
df = pd.read_csv(csv_file_path)
|
||||
|
||||
# 获取第五列之后的数据列(索引从0开始,第五列索引为4)
|
||||
data_columns = df.iloc[:, 4:]
|
||||
|
||||
# 检查是否有数据列
|
||||
if data_columns.empty:
|
||||
print("错误:CSV文件中没有足够的列(至少需要5列)")
|
||||
return
|
||||
|
||||
# 创建保存目录
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
print(f"创建目录: {save_dir}")
|
||||
|
||||
# 为每个数据列单独绘制箱型图
|
||||
for column in data_columns.columns:
|
||||
# 移除空值
|
||||
clean_data = data_columns[column].dropna()
|
||||
|
||||
if len(clean_data) == 0:
|
||||
print(f"跳过列 '{column}': 没有有效数据")
|
||||
continue
|
||||
|
||||
# 创建新图形
|
||||
plt.figure(figsize=(8, 6))
|
||||
|
||||
# 绘制箱型图
|
||||
box_plot = plt.boxplot([clean_data], labels=[column], patch_artist=True,
|
||||
showfliers=False)
|
||||
|
||||
# 美化箱型图
|
||||
box_plot['boxes'][0].set_facecolor('lightblue')
|
||||
box_plot['boxes'][0].set_alpha(0.7)
|
||||
|
||||
# 添加散点
|
||||
x_pos = np.random.normal(1, 0.04, size=len(clean_data))
|
||||
plt.scatter(x_pos, clean_data, alpha=0.6, s=30, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
|
||||
# 设置标题和标签
|
||||
plt.title(f'{column} - 箱型图', fontsize=14, fontweight='bold')
|
||||
plt.xlabel('数据列', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
|
||||
# 添加统计信息到图上
|
||||
stats_text = f'数据点数: {len(clean_data)}\n均值: {clean_data.mean():.2f}\n中位数: {clean_data.median():.2f}\n标准差: {clean_data.std():.2f}'
|
||||
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
|
||||
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
|
||||
|
||||
# 添加网格
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图片
|
||||
safe_column_name = column.replace('/', '_').replace('\\', '_').replace(':', '_')
|
||||
save_path = os.path.join(save_dir, f'{safe_column_name}_boxplot.png')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"已保存: {save_path}")
|
||||
|
||||
# 关闭图形以释放内存
|
||||
plt.close()
|
||||
|
||||
print(f"\n所有箱型图已保存到目录: {save_dir}")
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"错误:找不到文件 {csv_file_path}")
|
||||
except Exception as e:
|
||||
print(f"错误:{str(e)}")
|
||||
|
||||
def plot_individual_boxplots_seaborn(csv_file_path, save_dir="boxplots_seaborn"):
|
||||
"""
|
||||
使用seaborn为每个数据列单独绘制箱型图并保存
|
||||
|
||||
参数:
|
||||
csv_file_path: CSV文件路径
|
||||
save_dir: 保存图片的目录
|
||||
"""
|
||||
try:
|
||||
# 读取CSV文件
|
||||
df = pd.read_csv(csv_file_path)
|
||||
|
||||
# 获取第五列之后的数据列
|
||||
data_columns = df.iloc[:, 4:]
|
||||
|
||||
if data_columns.empty:
|
||||
print("错误:CSV文件中没有足够的列(至少需要5列)")
|
||||
return
|
||||
|
||||
# 创建保存目录
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
print(f"创建目录: {save_dir}")
|
||||
|
||||
# 为每个数据列单独绘制箱型图
|
||||
for column in data_columns.columns:
|
||||
# 移除空值
|
||||
clean_data = data_columns[column].dropna()
|
||||
|
||||
if len(clean_data) == 0:
|
||||
print(f"跳过列 '{column}': 没有有效数据")
|
||||
continue
|
||||
|
||||
# 创建新图形
|
||||
plt.figure(figsize=(8, 6))
|
||||
|
||||
# 创建数据框用于seaborn
|
||||
plot_data = pd.DataFrame({
|
||||
'列名': [column] * len(clean_data),
|
||||
'数值': clean_data
|
||||
})
|
||||
|
||||
# 使用seaborn绘制箱型图和散点
|
||||
sns.boxplot(data=plot_data, x='列名', y='数值', palette='Set2')
|
||||
sns.stripplot(data=plot_data, x='列名', y='数值',
|
||||
color='red', alpha=0.6, size=5, jitter=True)
|
||||
|
||||
# 设置标题和标签
|
||||
plt.title(f'{column} - 箱型图 (Seaborn)', fontsize=14, fontweight='bold')
|
||||
plt.xlabel('数据列', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
|
||||
# 添加统计信息
|
||||
stats_text = f'数据点数: {len(clean_data)}\n均值: {clean_data.mean():.2f}\n中位数: {clean_data.median():.2f}\n标准差: {clean_data.std():.2f}'
|
||||
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
|
||||
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))
|
||||
|
||||
# 添加网格
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图片
|
||||
safe_column_name = column.replace('/', '_').replace('\\', '_').replace(':', '_')
|
||||
save_path = os.path.join(save_dir, f'{safe_column_name}_boxplot_seaborn.png')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"已保存: {save_path}")
|
||||
|
||||
# 关闭图形以释放内存
|
||||
plt.close()
|
||||
|
||||
print(f"\n所有箱型图已保存到目录: {save_dir}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误:{str(e)}")
|
||||
|
||||
def plot_boxplot_with_scatter(csv_file_path):
|
||||
"""
|
||||
读取CSV文件并绘制第五列之后数据列的箱型图,同时标注散点
|
||||
|
||||
参数:
|
||||
csv_file_path: CSV文件路径
|
||||
"""
|
||||
try:
|
||||
# 读取CSV文件
|
||||
df = pd.read_csv(csv_file_path)
|
||||
|
||||
# 获取第五列之后的数据列(索引从0开始,第五列索引为4)
|
||||
data_columns = df.iloc[:, 4:] # 从第五列开始的所有列
|
||||
|
||||
# 检查是否有数据列
|
||||
if data_columns.empty:
|
||||
print("错误:CSV文件中没有足够的列(至少需要5列)")
|
||||
return
|
||||
|
||||
# 设置图形大小
|
||||
plt.figure(figsize=(12, 8))
|
||||
|
||||
# 准备数据用于绘制箱型图
|
||||
box_data = []
|
||||
labels = []
|
||||
|
||||
for column in data_columns.columns:
|
||||
# 移除空值
|
||||
clean_data = data_columns[column].dropna()
|
||||
if len(clean_data) > 0:
|
||||
box_data.append(clean_data)
|
||||
labels.append(column)
|
||||
|
||||
# 绘制箱型图
|
||||
box_plot = plt.boxplot(box_data, labels=labels, patch_artist=True,
|
||||
showfliers=False) # 不显示异常值点,因为我们要自己绘制散点
|
||||
|
||||
# 美化箱型图
|
||||
colors = plt.cm.Set3(np.linspace(0, 1, len(box_data)))
|
||||
for patch, color in zip(box_plot['boxes'], colors):
|
||||
patch.set_facecolor(color)
|
||||
patch.set_alpha(0.7)
|
||||
|
||||
# 在每个箱型图上添加散点
|
||||
for i, data in enumerate(box_data):
|
||||
# 为每个数据点添加一些随机的x轴偏移,避免重叠
|
||||
x_pos = np.random.normal(i + 1, 0.04, size=len(data))
|
||||
|
||||
# 绘制散点
|
||||
plt.scatter(x_pos, data, alpha=0.6, s=20, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
|
||||
# 设置标题和标签
|
||||
plt.title('数据列箱型图(带散点标注)', fontsize=16, fontweight='bold')
|
||||
plt.xlabel('数据列', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
|
||||
# 旋转x轴标签以避免重叠
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
|
||||
# 添加网格
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 显示图形
|
||||
plt.show()
|
||||
|
||||
# 打印统计信息
|
||||
print(f"成功绘制了 {len(labels)} 个数据列的箱型图")
|
||||
print("数据列名称:", labels)
|
||||
|
||||
# 显示每列的基本统计信息
|
||||
print("\n各列基本统计信息:")
|
||||
for column in labels:
|
||||
data = data_columns[column].dropna()
|
||||
print(f"{column}: 数据点数={len(data)}, 均值={data.mean():.2f}, 中位数={data.median():.2f}")
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"错误:找不到文件 {csv_file_path}")
|
||||
except Exception as e:
|
||||
print(f"错误:{str(e)}")
|
||||
|
||||
def plot_boxplot_with_seaborn(csv_file_path):
|
||||
"""
|
||||
使用seaborn绘制更美观的箱型图(可选方法)
|
||||
|
||||
参数:
|
||||
csv_file_path: CSV文件路径
|
||||
"""
|
||||
try:
|
||||
# 读取CSV文件
|
||||
df = pd.read_csv(csv_file_path)
|
||||
|
||||
# 获取第五列之后的数据列
|
||||
data_columns = df.iloc[:, 4:]
|
||||
|
||||
if data_columns.empty:
|
||||
print("错误:CSV文件中没有足够的列(至少需要5列)")
|
||||
return
|
||||
|
||||
# 将数据转换为长格式用于seaborn
|
||||
melted_data = pd.melt(data_columns, var_name='列名', value_name='数值')
|
||||
melted_data = melted_data.dropna() # 移除空值
|
||||
|
||||
# 设置图形大小
|
||||
plt.figure(figsize=(12, 8))
|
||||
|
||||
# 使用seaborn绘制箱型图和散点
|
||||
sns.boxplot(data=melted_data, x='列名', y='数值', palette='Set3')
|
||||
sns.stripplot(data=melted_data, x='列名', y='数值',
|
||||
color='red', alpha=0.6, size=4, jitter=True)
|
||||
|
||||
# 设置标题和标签
|
||||
plt.title('数据列箱型图(Seaborn版本)', fontsize=16, fontweight='bold')
|
||||
plt.xlabel('数据列', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
|
||||
# 旋转x轴标签
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
|
||||
# 添加网格
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 显示图形
|
||||
plt.show()
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误:{str(e)}")
|
||||
|
||||
# 主程序
|
||||
if __name__ == "__main__":
|
||||
# 请修改为您的CSV文件路径
|
||||
csv_file_path = r"E:\code\WQ\yaobao925\output.csv" # 替换为您的CSV文件路径
|
||||
|
||||
print("请选择绘图方法:")
|
||||
print("1. 使用matplotlib绘制(所有列在一张图)")
|
||||
print("2. 使用seaborn绘制(所有列在一张图)")
|
||||
print("3. 分别绘制每列并保存(matplotlib版本)")
|
||||
print("4. 分别绘制每列并保存(seaborn版本)")
|
||||
|
||||
choice = input("请输入选择(1-4):").strip()
|
||||
|
||||
if choice == "1":
|
||||
plot_boxplot_with_scatter(csv_file_path)
|
||||
elif choice == "2":
|
||||
plot_boxplot_with_seaborn(csv_file_path)
|
||||
elif choice == "3":
|
||||
plot_individual_boxplots(csv_file_path)
|
||||
elif choice == "4":
|
||||
plot_individual_boxplots_seaborn(csv_file_path)
|
||||
else:
|
||||
print("默认使用分别绘制并保存(seaborn版本)...")
|
||||
plot_individual_boxplots_seaborn(csv_file_path)
|
||||
545
src/postprocessing/flight_path.py
Normal file
545
src/postprocessing/flight_path.py
Normal file
@ -0,0 +1,545 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
飞行轨迹可视化模块 - 在多架次GPS数据上绘制飞行轨迹
|
||||
|
||||
支持功能:
|
||||
1. 读取多个.gps文件(每个文件代表一个架次)
|
||||
2. 在高光谱假彩色影像上绘制飞行轨迹
|
||||
3. 不同架次使用不同颜色
|
||||
4. 图例显示架次起始到结束的时间段
|
||||
5. 添加指北针、比例尺
|
||||
6. 保存为PNG图像
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, List, Dict, Union
|
||||
import warnings
|
||||
from matplotlib.patches import FancyArrowPatch
|
||||
import matplotlib.patheffects as path_effects
|
||||
from datetime import datetime
|
||||
from matplotlib.colors import ListedColormap
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
# 性能优化配置
|
||||
plt.rcParams['agg.path.chunksize'] = 10000
|
||||
plt.rcParams['path.simplify'] = True
|
||||
plt.rcParams['path.simplify_threshold'] = 0.1
|
||||
|
||||
# 导入GDAL用于影像读写
|
||||
try:
|
||||
from osgeo import gdal, osr
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
print("警告: GDAL未安装,地理坐标转换功能可能无法正常工作")
|
||||
|
||||
|
||||
class FlightPathVisualizer:
|
||||
"""飞行轨迹可视化类 - 在高光谱假彩色影像上绘制多架次飞行轨迹"""
|
||||
|
||||
# 预定义颜色方案(不同架次使用不同颜色)
|
||||
FLIGHT_COLORS = [
|
||||
'#FF0000', # 红色
|
||||
'#00FF00', # 绿色
|
||||
'#0000FF', # 蓝色
|
||||
'#FF00FF', # 紫色
|
||||
'#00FFFF', # 青色
|
||||
'#FFFF00', # 黄色
|
||||
'#FF8000', # 橙色
|
||||
'#8000FF', # 紫罗兰
|
||||
'#0080FF', # 天蓝
|
||||
'#FF0080', # 粉红
|
||||
]
|
||||
|
||||
def __init__(self, output_dir: str = "./flight_paths"):
|
||||
"""
|
||||
初始化飞行轨迹可视化器
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录,用于保存生成的轨迹图
|
||||
"""
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
plt.rcParams['font.size'] = 12
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
def create_flight_path_map(self,
|
||||
gps_folder: str,
|
||||
hyperspectral_path: str,
|
||||
output_filename: Optional[str] = None,
|
||||
rgb_bands: Optional[List[int]] = None,
|
||||
line_width: int = 2,
|
||||
show_north_arrow: bool = True,
|
||||
show_scale_bar: bool = True,
|
||||
dpi: int = 300) -> str:
|
||||
"""
|
||||
创建飞行轨迹地图:在高光谱假彩色影像上绘制多架次飞行轨迹
|
||||
|
||||
Args:
|
||||
gps_folder: GPS文件夹路径,包含多个.gps文件
|
||||
hyperspectral_path: 高光谱影像文件路径 (.dat, .bsq, .tif等)
|
||||
output_filename: 输出文件名(如果为None则自动生成)
|
||||
rgb_bands: 用于RGB合成的三个波段索引 [R, G, B],默认为None自动选择(650,550,460nm)
|
||||
line_width: 轨迹线宽
|
||||
show_north_arrow: 是否显示指北针
|
||||
show_scale_bar: 是否显示比例尺
|
||||
dpi: 输出图像分辨率
|
||||
|
||||
Returns:
|
||||
生成的地图文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法处理地理坐标转换")
|
||||
|
||||
print(f"正在生成飞行轨迹地图...")
|
||||
|
||||
# 读取高光谱影像
|
||||
hyperspectral_img, geotransform, projection, width, height = self._read_hyperspectral(
|
||||
hyperspectral_path, rgb_bands)
|
||||
|
||||
# 读取所有GPS文件
|
||||
flight_data = self._read_gps_files(gps_folder)
|
||||
|
||||
if not flight_data:
|
||||
raise ValueError(f"未在 {gps_folder} 中找到有效的.gps文件")
|
||||
|
||||
# 将GPS坐标转换为像素坐标
|
||||
flight_pixels = self._convert_flights_to_pixels(
|
||||
flight_data, geotransform, width, height, projection)
|
||||
|
||||
# 创建地图
|
||||
if output_filename is None:
|
||||
folder_name = Path(gps_folder).name
|
||||
hs_name = Path(hyperspectral_path).stem
|
||||
output_filename = f"{hs_name}_{folder_name}_flight_paths.png"
|
||||
|
||||
output_path = self.output_dir / output_filename
|
||||
|
||||
self._create_map_visualization(
|
||||
hyperspectral_img, flight_pixels, flight_data,
|
||||
str(output_path), line_width,
|
||||
show_north_arrow, show_scale_bar, dpi,
|
||||
geotransform, width, height
|
||||
)
|
||||
|
||||
print(f"飞行轨迹地图已保存: {output_path}")
|
||||
return str(output_path)
|
||||
|
||||
def _read_hyperspectral(self, hyperspectral_path: str,
|
||||
rgb_bands: Optional[List[int]] = None) -> Tuple[np.ndarray, tuple, str, int, int]:
|
||||
"""读取高光谱影像 - 使用650/550/460nm波长"""
|
||||
dataset = gdal.Open(hyperspectral_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开高光谱影像: {hyperspectral_path}")
|
||||
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
|
||||
# 确定要读取的波段 - 使用指定波长 650, 550, 460nm
|
||||
if rgb_bands is None:
|
||||
if band_count >= 3:
|
||||
try:
|
||||
from src.utils.util import find_band_number
|
||||
rgb_bands = [
|
||||
find_band_number(650.0, hyperspectral_path), # Red ~650nm
|
||||
find_band_number(550.0, hyperspectral_path), # Green ~550nm
|
||||
find_band_number(460.0, hyperspectral_path) # Blue ~460nm
|
||||
]
|
||||
print(f" 根据波长选择RGB波段: R={rgb_bands[0]}, G={rgb_bands[1]}, B={rgb_bands[2]}")
|
||||
except Exception as e:
|
||||
print(f" 波长查找失败 ({e}),使用默认索引")
|
||||
rgb_bands = [min(band_count-1, int(band_count*0.25)),
|
||||
min(band_count-1, int(band_count*0.15)),
|
||||
min(band_count-1, int(band_count*0.05))]
|
||||
else:
|
||||
rgb_bands = [0, 0, 0]
|
||||
|
||||
# 读取RGB波段
|
||||
rgb_data = []
|
||||
for band_idx in rgb_bands:
|
||||
band = dataset.GetRasterBand(band_idx + 1)
|
||||
band_data = band.ReadAsArray().astype(np.float32)
|
||||
rgb_data.append(band_data)
|
||||
|
||||
# 堆叠为RGB图像
|
||||
if len(rgb_data) == 3:
|
||||
image_array = np.stack(rgb_data, axis=2)
|
||||
else:
|
||||
image_array = np.stack([rgb_data[0]]*3, axis=2)
|
||||
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
dataset = None
|
||||
|
||||
print(f" 读取影像: {width}x{height}x{image_array.shape[2]} (RGB)")
|
||||
if projection:
|
||||
proj_type = "投影坐标系" if "PROJCS" in projection else "地理坐标系"
|
||||
print(f" 影像投影: {proj_type}")
|
||||
|
||||
return image_array, geotransform, projection, width, height
|
||||
|
||||
def _read_gps_files(self, gps_folder: str) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
读取文件夹中的所有.gps文件
|
||||
|
||||
文件格式: 日期、时间、三个姿态角、经度、纬度、高程
|
||||
列: [date, time, pitch, roll, yaw, longitude, latitude, altitude]
|
||||
"""
|
||||
gps_folder_path = Path(gps_folder)
|
||||
if not gps_folder_path.exists():
|
||||
raise FileNotFoundError(f"GPS文件夹不存在: {gps_folder}")
|
||||
|
||||
gps_files = list(gps_folder_path.glob("*.gps"))
|
||||
if not gps_files:
|
||||
print(f"警告: 在 {gps_folder} 中未找到.gps文件")
|
||||
return {}
|
||||
|
||||
print(f"找到 {len(gps_files)} 个GPS文件")
|
||||
|
||||
flight_data = {}
|
||||
for gps_file in gps_files:
|
||||
try:
|
||||
# 读取GPS文件(制表符分隔)
|
||||
df = pd.read_csv(gps_file, sep='\t', header=None,
|
||||
names=['date', 'time', 'pitch', 'roll', 'yaw',
|
||||
'longitude', 'latitude', 'altitude'])
|
||||
|
||||
# 确保数值类型
|
||||
df['longitude'] = pd.to_numeric(df['longitude'], errors='coerce')
|
||||
df['latitude'] = pd.to_numeric(df['latitude'], errors='coerce')
|
||||
df['altitude'] = pd.to_numeric(df['altitude'], errors='coerce')
|
||||
|
||||
# 删除无效坐标
|
||||
df = df.dropna(subset=['longitude', 'latitude'])
|
||||
|
||||
if len(df) > 0:
|
||||
flight_data[gps_file.stem] = df
|
||||
print(f" ✓ 读取 {gps_file.name}: {len(df)} 个轨迹点")
|
||||
# 显示时间范围
|
||||
start_time = df.iloc[0]['time']
|
||||
end_time = df.iloc[-1]['time']
|
||||
print(f" 时间范围: {start_time} - {end_time}")
|
||||
else:
|
||||
print(f" ✗ {gps_file.name}: 无有效数据")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ 读取 {gps_file.name} 失败: {e}")
|
||||
|
||||
return flight_data
|
||||
|
||||
def _convert_flights_to_pixels(self, flight_data: Dict[str, pd.DataFrame],
|
||||
geotransform: tuple, width: int, height: int,
|
||||
projection: str = "") -> Dict[str, List[Tuple[float, float]]]:
|
||||
"""将所有飞行轨迹的地理坐标转换为像素坐标"""
|
||||
if geotransform is None:
|
||||
print("警告: 无地理变换信息,无法转换坐标")
|
||||
return {}
|
||||
|
||||
gt = geotransform
|
||||
flight_pixels = {}
|
||||
|
||||
# 检查是否需要投影转换
|
||||
needs_transform = projection and ("PROJCS" in projection or "GEOGCS" in projection)
|
||||
transform = None
|
||||
|
||||
if needs_transform and GDAL_AVAILABLE:
|
||||
try:
|
||||
src_srs = osr.SpatialReference()
|
||||
src_srs.ImportFromEPSG(4326) # WGS84
|
||||
dst_srs = osr.SpatialReference()
|
||||
dst_srs.ImportFromWkt(projection)
|
||||
transform = osr.CoordinateTransformation(src_srs, dst_srs)
|
||||
print(" ✓ 已创建WGS84到影像投影的坐标转换")
|
||||
except Exception as e:
|
||||
print(f" ⚠ 坐标转换创建失败: {e}")
|
||||
|
||||
for flight_name, df in flight_data.items():
|
||||
pixel_coords = []
|
||||
|
||||
for _, row in df.iterrows():
|
||||
lon = float(row['longitude'])
|
||||
lat = float(row['latitude'])
|
||||
|
||||
if transform is not None:
|
||||
try:
|
||||
proj_x, proj_y, _ = transform.TransformPoint(lat, lon)
|
||||
x = (proj_x - gt[0]) / gt[1]
|
||||
y = (proj_y - gt[3]) / gt[5]
|
||||
except:
|
||||
x = (lon - gt[0]) / gt[1]
|
||||
y = (lat - gt[3]) / gt[5]
|
||||
else:
|
||||
x = (lon - gt[0]) / gt[1]
|
||||
y = (lat - gt[3]) / gt[5]
|
||||
|
||||
# 限制在图像范围内
|
||||
x = max(0, min(x, width - 1))
|
||||
y = max(0, min(y, height - 1))
|
||||
|
||||
pixel_coords.append((x, y))
|
||||
|
||||
flight_pixels[flight_name] = pixel_coords
|
||||
|
||||
print(f" 已转换 {len(flight_pixels)} 个架次的坐标")
|
||||
return flight_pixels
|
||||
|
||||
def _create_false_color_image(self, image_array: np.ndarray) -> np.ndarray:
|
||||
"""创建假彩色RGB图像 - 应用线性拉伸"""
|
||||
if image_array.shape[2] != 3:
|
||||
if len(image_array.shape) == 2 or image_array.shape[2] == 1:
|
||||
if len(image_array.shape) == 2:
|
||||
image_array = np.stack([image_array]*3, axis=2)
|
||||
else:
|
||||
image_array = np.repeat(image_array, 3, axis=2)
|
||||
|
||||
def simple_linear_stretch(data, min_percent=1, max_percent=99):
|
||||
valid_data = data[np.isfinite(data)]
|
||||
if len(valid_data) == 0:
|
||||
return np.zeros_like(data, dtype=np.float32)
|
||||
|
||||
p_low = np.percentile(valid_data, min_percent)
|
||||
p_high = np.percentile(valid_data, max_percent)
|
||||
|
||||
if p_high - p_low < 1e-8:
|
||||
data_min = valid_data.min()
|
||||
data_max = valid_data.max()
|
||||
if data_max > data_min:
|
||||
stretched = (data - data_min) / (data_max - data_min)
|
||||
else:
|
||||
stretched = np.zeros_like(data, dtype=np.float32)
|
||||
else:
|
||||
stretched = (data - p_low) / (p_high - p_low)
|
||||
|
||||
stretched = np.clip(stretched, 0.0, 1.0)
|
||||
return stretched
|
||||
|
||||
r_stretched = simple_linear_stretch(image_array[:, :, 0])
|
||||
g_stretched = simple_linear_stretch(image_array[:, :, 1])
|
||||
b_stretched = simple_linear_stretch(image_array[:, :, 2])
|
||||
|
||||
rgb_image = np.stack([r_stretched, g_stretched, b_stretched], axis=2)
|
||||
rgb_image = np.nan_to_num(rgb_image, nan=0.0)
|
||||
rgb_image = np.clip(rgb_image, 0.0, 1.0)
|
||||
|
||||
# Gamma校正增加亮度
|
||||
gamma = 0.85
|
||||
rgb_image = np.power(rgb_image, gamma)
|
||||
|
||||
# 映射到0-255
|
||||
rgb_image = (rgb_image * 255).astype(np.uint8)
|
||||
|
||||
return rgb_image
|
||||
|
||||
def _create_map_visualization(self, image_array: np.ndarray,
|
||||
flight_pixels: Dict[str, List[Tuple[float, float]]],
|
||||
flight_data: Dict[str, pd.DataFrame],
|
||||
output_path: str,
|
||||
line_width: int,
|
||||
show_north_arrow: bool,
|
||||
show_scale_bar: bool,
|
||||
dpi: int,
|
||||
geotransform: tuple,
|
||||
width: int,
|
||||
height: int):
|
||||
"""创建地图可视化"""
|
||||
figsize = (14, 10)
|
||||
fig, ax = plt.subplots(figsize=figsize, dpi=150)
|
||||
|
||||
# 处理背景图像
|
||||
rgb_image = self._create_false_color_image(image_array)
|
||||
ax.imshow(rgb_image, interpolation='bilinear')
|
||||
|
||||
# 绘制飞行轨迹 - 不同架次不同颜色
|
||||
legend_elements = []
|
||||
|
||||
for idx, (flight_name, pixel_coords) in enumerate(flight_pixels.items()):
|
||||
if len(pixel_coords) < 2:
|
||||
continue
|
||||
|
||||
# 选择颜色
|
||||
color = self.FLIGHT_COLORS[idx % len(self.FLIGHT_COLORS)]
|
||||
|
||||
# 提取x,y坐标
|
||||
x_coords = [p[0] for p in pixel_coords]
|
||||
y_coords = [p[1] for p in pixel_coords]
|
||||
|
||||
# 绘制轨迹线
|
||||
ax.plot(x_coords, y_coords, color=color, linewidth=line_width,
|
||||
alpha=0.8, solid_capstyle='round')
|
||||
|
||||
# 标记起点和终点
|
||||
ax.plot(x_coords[0], y_coords[0], 'o', color=color, markersize=8,
|
||||
markeredgecolor='white', markeredgewidth=1)
|
||||
ax.plot(x_coords[-1], y_coords[-1], 's', color=color, markersize=8,
|
||||
markeredgecolor='white', markeredgewidth=1)
|
||||
|
||||
# 获取时间范围用于图例
|
||||
df = flight_data[flight_name]
|
||||
start_time = df.iloc[0]['time']
|
||||
end_time = df.iloc[-1]['time']
|
||||
|
||||
# 创建图例元素
|
||||
legend_label = f"{flight_name}: {start_time} - {end_time}"
|
||||
legend_elements.append(
|
||||
mpatches.Patch(color=color, label=legend_label)
|
||||
)
|
||||
|
||||
# 添加图例
|
||||
if legend_elements:
|
||||
ax.legend(handles=legend_elements, loc='lower right',
|
||||
frameon=True, facecolor='white', edgecolor='gray',
|
||||
fontsize=9, title='飞行轨迹 (起点→终点)')
|
||||
|
||||
# 添加指北针
|
||||
if show_north_arrow:
|
||||
self._add_north_arrow(ax, width, height)
|
||||
|
||||
# 添加比例尺
|
||||
if show_scale_bar and geotransform is not None:
|
||||
self._add_scale_bar(ax, geotransform, width, height)
|
||||
|
||||
# 设置标题
|
||||
ax.set_title('多架次飞行轨迹图', fontsize=16, fontweight='bold', pad=20)
|
||||
|
||||
|
||||
# 隐藏坐标轴刻度
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.2, linestyle='--')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=dpi, bbox_inches='tight', pad_inches=0.1, facecolor='white')
|
||||
plt.close(fig)
|
||||
|
||||
def _add_north_arrow(self, ax, width: int, height: int):
|
||||
"""添加指北针"""
|
||||
arrow_x = width * 0.92
|
||||
arrow_y = height * 0.88
|
||||
|
||||
arrow = FancyArrowPatch((arrow_x, arrow_y), (arrow_x, arrow_y - height*0.08),
|
||||
color='black', linewidth=3, arrowstyle='->', mutation_scale=20)
|
||||
ax.add_patch(arrow)
|
||||
|
||||
ax.text(arrow_x, arrow_y - height*0.1, 'N', fontsize=14, fontweight='bold',
|
||||
color='black', ha='center', va='center',
|
||||
path_effects=[path_effects.withStroke(linewidth=3, foreground='white')])
|
||||
|
||||
def _add_scale_bar(self, ax, geotransform: tuple, width: int, height: int):
|
||||
"""添加比例尺"""
|
||||
if geotransform is None:
|
||||
return
|
||||
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
image_width_meters = width * pixel_size_x
|
||||
scale_length_m = image_width_meters / 4
|
||||
|
||||
scale_options = [1000, 500, 200, 100, 50, 20, 10, 5, 2, 1]
|
||||
scale_meters = next((s for s in scale_options if s <= scale_length_m), 1)
|
||||
scale_pixels = int(scale_meters / pixel_size_x)
|
||||
|
||||
bar_x = width * 0.08
|
||||
bar_y = height * 0.92
|
||||
|
||||
ax.plot([bar_x, bar_x + scale_pixels], [bar_y, bar_y], color='black', linewidth=4)
|
||||
ax.plot([bar_x, bar_x], [bar_y, bar_y + 8], color='black', linewidth=2)
|
||||
ax.plot([bar_x + scale_pixels, bar_x + scale_pixels], [bar_y, bar_y + 8], color='black', linewidth=2)
|
||||
|
||||
ax.text(bar_x + scale_pixels/2, bar_y + 15, f'{scale_meters} m',
|
||||
fontsize=11, ha='center', va='bottom', fontweight='bold',
|
||||
bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))
|
||||
|
||||
def batch_create_maps(self, gps_folder: str,
|
||||
hyperspectral_folder: str,
|
||||
output_subdir: str = "flight_paths") -> Dict[str, str]:
|
||||
"""
|
||||
批量创建飞行轨迹地图
|
||||
|
||||
Args:
|
||||
gps_folder: 包含多个子文件夹(每个子文件夹包含.gps文件)的文件夹
|
||||
hyperspectral_folder: 包含高光谱影像的文件夹
|
||||
output_subdir: 输出子目录
|
||||
|
||||
Returns:
|
||||
生成的地图文件路径字典
|
||||
"""
|
||||
gps_folder_path = Path(gps_folder)
|
||||
hs_folder_path = Path(hyperspectral_folder)
|
||||
|
||||
if not gps_folder_path.exists():
|
||||
raise FileNotFoundError(f"GPS文件夹不存在: {gps_folder}")
|
||||
if not hs_folder_path.exists():
|
||||
raise FileNotFoundError(f"高光谱文件夹不存在: {hyperspectral_folder}")
|
||||
|
||||
output_dir = self.output_dir / output_subdir
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
map_paths = {}
|
||||
|
||||
# 查找所有包含.gps文件的子文件夹
|
||||
gps_subfolders = [d for d in gps_folder_path.iterdir() if d.is_dir() and list(d.glob("*.gps"))]
|
||||
|
||||
# 查找高光谱影像
|
||||
hs_files = []
|
||||
for ext in ['*.dat', '*.bsq', '*.tif', '*.tiff']:
|
||||
hs_files.extend(list(hs_folder_path.glob(ext)))
|
||||
|
||||
if not hs_files:
|
||||
print(f"警告: 在 {hyperspectral_folder} 中未找到高光谱影像")
|
||||
return map_paths
|
||||
|
||||
print(f"找到 {len(gps_subfolders)} 个GPS子文件夹和 {len(hs_files)} 个高光谱影像")
|
||||
|
||||
# 简单匹配:使用第一个高光谱影像与所有GPS文件夹组合
|
||||
hs_file = hs_files[0]
|
||||
|
||||
for gps_subfolder in gps_subfolders:
|
||||
try:
|
||||
output_filename = f"{hs_file.stem}_{gps_subfolder.name}_flight_paths.png"
|
||||
map_path = self.create_flight_path_map(
|
||||
gps_folder=str(gps_subfolder),
|
||||
hyperspectral_path=str(hs_file),
|
||||
output_filename=output_filename,
|
||||
dpi=200
|
||||
)
|
||||
map_paths[gps_subfolder.name] = map_path
|
||||
print(f"✓ 生成: {gps_subfolder.name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 处理 {gps_subfolder.name} 失败: {e}")
|
||||
|
||||
print(f"批量生成完成,共生成 {len(map_paths)} 个飞行轨迹图")
|
||||
return map_paths
|
||||
|
||||
|
||||
# 测试代码
|
||||
if __name__ == "__main__":
|
||||
print("FlightPathVisualizer类已创建")
|
||||
print("用法示例:")
|
||||
print(" visualizer = FlightPathVisualizer(output_dir='./flight_maps')")
|
||||
print(" map_path = visualizer.create_flight_path_map(")
|
||||
print(" gps_folder='./gps_data',")
|
||||
print(" hyperspectral_path='./hyperspectral.dat'")
|
||||
print(" )")
|
||||
|
||||
|
||||
visualizer = FlightPathVisualizer(output_dir=r"E:\code\WQ\pipeline_result\work_dir\9_visualization\flight_maps")
|
||||
# 生成飞行轨迹图
|
||||
map_path = visualizer.create_flight_path_map(
|
||||
gps_folder=r"D:\BaiduNetdiskDownload\20250902\gps", # GPS文件夹路径
|
||||
hyperspectral_path=r"E:\code\WQ\pipeline_result\work_dir\3_deglint\deglint_goodman.bsq", # 高光谱影像路径
|
||||
output_filename="flight_paths.png",
|
||||
line_width=2,
|
||||
dpi=300
|
||||
)
|
||||
2186
src/postprocessing/map.py
Normal file
2186
src/postprocessing/map.py
Normal file
File diff suppressed because it is too large
Load Diff
2561
src/postprocessing/map_beifeng.py
Normal file
2561
src/postprocessing/map_beifeng.py
Normal file
File diff suppressed because it is too large
Load Diff
184
src/postprocessing/plot_spectrum_by_parameter.py
Normal file
184
src/postprocessing/plot_spectrum_by_parameter.py
Normal file
@ -0,0 +1,184 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.colors as mcolors
|
||||
from pathlib import Path
|
||||
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
def load_and_plot_spectrum_by_parameters():
|
||||
"""
|
||||
加载数据并为每个水质参数绘制光谱曲线图
|
||||
"""
|
||||
try:
|
||||
# 数据文件路径
|
||||
data_file = Path(r"E:\code\WQ\yaobao925\spectral.csv")
|
||||
|
||||
if not data_file.exists():
|
||||
print(f"错误:数据文件不存在 - {data_file}")
|
||||
return
|
||||
|
||||
# 读取数据
|
||||
print("正在加载数据...")
|
||||
data = pd.read_csv(data_file)
|
||||
|
||||
print(f"数据形状: {data.shape}")
|
||||
print(f"列名: {list(data.columns[:15])}...") # 显示前15个列名
|
||||
|
||||
# 找到光谱数据的起始列(通常是数字列名)
|
||||
spectrum_start_idx = None
|
||||
for i, col in enumerate(data.columns):
|
||||
try:
|
||||
float(col)
|
||||
spectrum_start_idx = i
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if spectrum_start_idx is None:
|
||||
print("错误:未找到光谱数据列")
|
||||
return
|
||||
|
||||
print(f"光谱数据从第 {spectrum_start_idx + 1} 列开始")
|
||||
|
||||
# 分离水质参数和光谱数据
|
||||
water_quality_data = data.iloc[:, :spectrum_start_idx]
|
||||
spectrum_data = data.iloc[:, spectrum_start_idx:]
|
||||
|
||||
# 获取波长信息
|
||||
try:
|
||||
# 尝试直接转换为浮点数
|
||||
wavelengths = spectrum_data.columns.astype(float)
|
||||
except ValueError:
|
||||
# 如果包含字母,提取数字部分
|
||||
import re
|
||||
wavelengths = []
|
||||
for col in spectrum_data.columns:
|
||||
# 提取数字部分
|
||||
numbers = re.findall(r'\d+\.?\d*', str(col))
|
||||
if numbers:
|
||||
wavelengths.append(float(numbers[0]))
|
||||
else:
|
||||
# 如果没有数字,使用列索引
|
||||
wavelengths.append(float(len(wavelengths)))
|
||||
wavelengths = np.array(wavelengths)
|
||||
|
||||
print(f"波长范围: {wavelengths.min():.1f} - {wavelengths.max():.1f} nm")
|
||||
print(f"光谱数据形状: {spectrum_data.shape}")
|
||||
print(f"水质参数: {list(water_quality_data.columns)}")
|
||||
|
||||
# 过滤波长范围到374-1011nm
|
||||
wavelength_mask = (wavelengths >= 374) & (wavelengths <= 1011)
|
||||
filtered_wavelengths = wavelengths[wavelength_mask]
|
||||
filtered_spectrum_data = spectrum_data.iloc[:, wavelength_mask]
|
||||
|
||||
print(f"过滤后波长范围: {filtered_wavelengths.min():.1f} - {filtered_wavelengths.max():.1f} nm")
|
||||
print(f"过滤后光谱数据形状: {filtered_spectrum_data.shape}")
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = Path(r'E:\code\WQ\yaobao925\plot')
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 为每个水质参数绘制光谱图
|
||||
for param_idx, parameter_name in enumerate(water_quality_data.columns):
|
||||
print(f"\n[{param_idx+1}/{len(water_quality_data.columns)}] 处理参数: {parameter_name}")
|
||||
|
||||
# 获取当前参数的数据
|
||||
parameter_values = water_quality_data[parameter_name]
|
||||
|
||||
# 过滤掉空值
|
||||
valid_mask = ~parameter_values.isna()
|
||||
if valid_mask.sum() == 0:
|
||||
print(f"参数 '{parameter_name}' 没有有效数据,跳过")
|
||||
continue
|
||||
|
||||
valid_param_values = parameter_values[valid_mask]
|
||||
valid_spectrum_data = filtered_spectrum_data[valid_mask]
|
||||
|
||||
print(f"有效样本数: {len(valid_param_values)}")
|
||||
|
||||
# 创建图形和轴
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
|
||||
# 归一化参数值到[0,1]范围,用于颜色映射
|
||||
param_min = valid_param_values.min()
|
||||
param_max = valid_param_values.max()
|
||||
|
||||
if param_max == param_min:
|
||||
# 如果所有值相同,使用中等颜色
|
||||
normalized_values = np.full(len(valid_param_values), 0.5)
|
||||
else:
|
||||
normalized_values = ((valid_param_values - param_min) / (param_max - param_min)).values
|
||||
|
||||
# 创建蓝红颜色映射(蓝色到红色)
|
||||
colormap = plt.cm.coolwarm # 蓝色(低值)到红色(高值)
|
||||
|
||||
# 绘制每条光谱曲线
|
||||
for i, (idx, spectrum) in enumerate(valid_spectrum_data.iterrows()):
|
||||
# 处理光谱数据中的空值
|
||||
spectrum_values = pd.Series(spectrum.values).fillna(0).values
|
||||
|
||||
# 根据参数值确定颜色
|
||||
color = colormap(normalized_values[i])
|
||||
alpha = 0.6 if len(valid_param_values) > 50 else 0.8 # 样本多时降低透明度
|
||||
|
||||
ax.plot(filtered_wavelengths, spectrum_values, color=color, alpha=alpha, linewidth=0.8)
|
||||
|
||||
# 设置图形属性
|
||||
ax.set_xlabel('波长 (nm)', fontsize=12)
|
||||
ax.set_ylabel('光谱强度', fontsize=12)
|
||||
ax.set_title(f'{parameter_name} 光谱曲线图\n参数范围: {param_min:.4f} - {param_max:.4f}',
|
||||
fontsize=14, fontweight='bold')
|
||||
|
||||
# 设置坐标轴范围,限制在374-1011nm
|
||||
ax.set_xlim(374, 1011)
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 创建颜色条
|
||||
sm = plt.cm.ScalarMappable(cmap=colormap,
|
||||
norm=plt.Normalize(vmin=param_min, vmax=param_max))
|
||||
sm.set_array([])
|
||||
cbar = plt.colorbar(sm, ax=ax, shrink=0.8)
|
||||
cbar.set_label(f'{parameter_name} 数值', rotation=270, labelpad=20, fontsize=12)
|
||||
|
||||
# 添加统计信息文本框
|
||||
stats_text = f'样本数: {len(valid_param_values)}\n'
|
||||
stats_text += f'均值: {valid_param_values.mean():.4f}\n'
|
||||
stats_text += f'标准差: {valid_param_values.std():.4f}'
|
||||
|
||||
ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
|
||||
verticalalignment='top', bbox=dict(boxstyle='round',
|
||||
facecolor='wheat', alpha=0.8), fontsize=10)
|
||||
|
||||
# 优化布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图片
|
||||
# 清理参数名称,用于文件名
|
||||
safe_param_name = "".join(c for c in parameter_name if c.isalnum() or c in ('-', '_', '.')).rstrip()
|
||||
output_file = output_dir / f"{safe_param_name}_spectrum.png"
|
||||
plt.savefig(output_file, dpi=300, bbox_inches='tight')
|
||||
plt.close() # 关闭图形释放内存
|
||||
|
||||
print(f"图片已保存到: {output_file}")
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"所有光谱图绘制完成!")
|
||||
print(f"输出目录: {output_dir}")
|
||||
print(f"{'='*80}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理过程中出现错误: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
load_and_plot_spectrum_by_parameters()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
636
src/postprocessing/point_map.py
Normal file
636
src/postprocessing/point_map.py
Normal file
@ -0,0 +1,636 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
采样点地图生成模块 - 在高光谱假彩色影像上标注采样点
|
||||
|
||||
支持功能:
|
||||
1. 读取高光谱影像并生成假彩色RGB图像
|
||||
2. 读取CSV文件中的采样点坐标(前两列为纬度、经度)
|
||||
3. 在影像上标注红色采样点
|
||||
4. 添加指北针、图例和比例尺
|
||||
5. 支持地理坐标系转换
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, List, Dict, Union
|
||||
import warnings
|
||||
from matplotlib.patches import FancyArrowPatch
|
||||
import matplotlib.patheffects as path_effects
|
||||
|
||||
# 性能优化配置
|
||||
plt.rcParams['agg.path.chunksize'] = 10000 # 提高矢量渲染性能
|
||||
plt.rcParams['path.simplify'] = True
|
||||
plt.rcParams['path.simplify_threshold'] = 0.1
|
||||
|
||||
# 导入GDAL用于影像读写
|
||||
try:
|
||||
from osgeo import gdal, osr
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
print("警告: GDAL未安装,地理坐标转换功能可能无法正常工作")
|
||||
|
||||
|
||||
class SamplingPointMap:
|
||||
"""采样点地图生成类 - 在高光谱假彩色影像上标注采样点"""
|
||||
|
||||
def __init__(self, output_dir: str = "./point_maps", fast_mode: bool = False):
|
||||
"""
|
||||
初始化采样点地图生成器
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录,用于保存生成的地图
|
||||
fast_mode: 是否启用快速模式(降低质量换取速度)
|
||||
"""
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.fast_mode = fast_mode
|
||||
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
plt.rcParams['font.size'] = 12
|
||||
|
||||
# 性能优化设置
|
||||
if fast_mode:
|
||||
plt.rcParams['figure.dpi'] = 150
|
||||
plt.rcParams['savefig.dpi'] = 150
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
else:
|
||||
plt.rcParams['figure.dpi'] = 300
|
||||
plt.rcParams['savefig.dpi'] = 300
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
def create_sampling_point_map(self,
|
||||
hyperspectral_path: str,
|
||||
csv_path: str,
|
||||
output_filename: Optional[str] = None,
|
||||
rgb_bands: Optional[List[int]] = None,
|
||||
point_color: str = 'red',
|
||||
point_size: int = 80,
|
||||
point_alpha: float = 0.8,
|
||||
show_north_arrow: bool = True,
|
||||
show_scale_bar: bool = True,
|
||||
show_legend: bool = True,
|
||||
dpi: int = None,
|
||||
downsample: bool = False) -> str:
|
||||
"""
|
||||
创建采样点地图:在高光谱假彩色影像上标注采样点
|
||||
|
||||
Args:
|
||||
hyperspectral_path: 高光谱影像文件路径 (.dat, .bsq, .tif等)
|
||||
csv_path: 采样点CSV文件路径(前两列为纬度、经度)
|
||||
output_filename: 输出文件名(如果为None则自动生成)
|
||||
rgb_bands: 用于RGB合成的三个波段索引 [R, G, B],默认为None自动选择
|
||||
point_color: 采样点颜色
|
||||
point_size: 采样点大小
|
||||
point_alpha: 采样点透明度
|
||||
show_north_arrow: 是否显示指北针
|
||||
show_scale_bar: 是否显示比例尺
|
||||
show_legend: 是否显示图例
|
||||
dpi: 输出图像分辨率(None时使用fast_mode设置)
|
||||
downsample: 是否对图像进行下采样以加快速度(大影像推荐启用)
|
||||
|
||||
Returns:
|
||||
生成的地图文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法处理地理坐标转换")
|
||||
|
||||
print(f"正在生成采样点地图...{' (快速模式)' if self.fast_mode else ''}")
|
||||
|
||||
# 读取高光谱影像 - 优化:仅读取需要的RGB波段
|
||||
hyperspectral_img, geotransform, projection, width, height, sample_factor = self._read_hyperspectral(
|
||||
hyperspectral_path, rgb_bands, downsample)
|
||||
|
||||
# 读取采样点
|
||||
sampling_points = self._read_sampling_points(csv_path)
|
||||
|
||||
# 生成假彩色图像 - 应用线性拉伸
|
||||
rgb_image = self._create_false_color_image(hyperspectral_img)
|
||||
|
||||
# 将地理坐标转换为像素坐标 - 支持投影系转换和下采样
|
||||
pixel_coords = self._geo_to_pixel(sampling_points, geotransform, width, height, projection, sample_factor)
|
||||
|
||||
# 创建地图
|
||||
if output_filename is None:
|
||||
csv_name = Path(csv_path).stem
|
||||
hs_name = Path(hyperspectral_path).stem
|
||||
output_filename = f"{hs_name}_{csv_name}_sampling_map.png"
|
||||
|
||||
output_path = self.output_dir / output_filename
|
||||
|
||||
# 使用更优化的绘图设置
|
||||
if dpi is None:
|
||||
dpi = 150 if self.fast_mode else 200
|
||||
|
||||
self._create_map_visualization(
|
||||
rgb_image, pixel_coords, sampling_points,
|
||||
str(output_path), point_color, point_size, point_alpha,
|
||||
show_north_arrow, show_scale_bar, show_legend, dpi,
|
||||
geotransform, width, height, downsample, projection, sample_factor
|
||||
)
|
||||
|
||||
print(f"采样点地图已保存: {output_path}")
|
||||
return str(output_path)
|
||||
|
||||
def _read_hyperspectral(self, hyperspectral_path: str,
|
||||
rgb_bands: Optional[List[int]] = None,
|
||||
downsample: bool = False) -> Tuple[np.ndarray, tuple, str, int, int]:
|
||||
"""优化版:读取高光谱影像 - 仅读取需要的RGB波段"""
|
||||
dataset = gdal.Open(hyperspectral_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开高光谱影像: {hyperspectral_path}")
|
||||
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
|
||||
# 确定要读取的波段 - 优先使用指定波长 (650nm, 550nm, 460nm)
|
||||
if rgb_bands is None:
|
||||
if band_count >= 3:
|
||||
try:
|
||||
# 使用find_band_number根据波长查找RGB波段
|
||||
from src.utils.util import find_band_number
|
||||
rgb_bands = [
|
||||
find_band_number(650.0, hyperspectral_path), # Red ~650nm
|
||||
find_band_number(550.0, hyperspectral_path), # Green ~550nm
|
||||
find_band_number(460.0, hyperspectral_path) # Blue ~460nm
|
||||
]
|
||||
print(f" 根据波长选择RGB波段: R={rgb_bands[0]}, G={rgb_bands[1]}, B={rgb_bands[2]}")
|
||||
except Exception as e:
|
||||
print(f" 波长查找失败 ({e}),使用默认索引")
|
||||
# 回退到基于索引的选择
|
||||
rgb_bands = [min(band_count-1, int(band_count*0.25)),
|
||||
min(band_count-1, int(band_count*0.15)),
|
||||
min(band_count-1, int(band_count*0.05))]
|
||||
else:
|
||||
rgb_bands = [0, 0, 0]
|
||||
|
||||
# 下采样控制 - 用户反馈下采样读取会导致像素值全为0
|
||||
if downsample and (width > 2000 or height > 2000):
|
||||
print(f" ⚠ 下采样暂被禁用(会导致像素值全0),使用原始分辨率: {width}x{height}")
|
||||
sample_factor = 1
|
||||
target_width = width
|
||||
target_height = height
|
||||
else:
|
||||
sample_factor = 1
|
||||
target_width = width
|
||||
target_height = height
|
||||
|
||||
# 只读取需要的RGB波段(性能关键优化)
|
||||
rgb_data = []
|
||||
for band_idx in rgb_bands:
|
||||
band = dataset.GetRasterBand(band_idx + 1)
|
||||
# 直接使用完整分辨率读取,避免下采样导致像素值为0的问题
|
||||
band_data = band.ReadAsArray().astype(np.float32)
|
||||
rgb_data.append(band_data)
|
||||
|
||||
# 堆叠为RGB图像 (height, width, 3)
|
||||
if len(rgb_data) == 3:
|
||||
image_array = np.stack(rgb_data, axis=2)
|
||||
else:
|
||||
# 如果只有1个波段,复制为RGB
|
||||
image_array = np.stack([rgb_data[0]]*3, axis=2)
|
||||
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
# 释放数据集
|
||||
dataset = None
|
||||
|
||||
# 更新尺寸信息
|
||||
final_width = target_width if sample_factor > 1 else width
|
||||
final_height = target_height if sample_factor > 1 else height
|
||||
|
||||
print(f" 读取影像: {final_width}x{final_height}x{image_array.shape[2]} (RGB)")
|
||||
if projection:
|
||||
proj_type = "投影坐标系" if "PROJCS" in projection else "地理坐标系"
|
||||
print(f" 影像投影: {proj_type}")
|
||||
if sample_factor > 1:
|
||||
print(f" 下采样因子: {sample_factor}")
|
||||
|
||||
return image_array, geotransform, projection, final_width, final_height, sample_factor
|
||||
|
||||
def _read_sampling_points(self, csv_path: str) -> pd.DataFrame:
|
||||
"""读取采样点CSV文件"""
|
||||
if not Path(csv_path).exists():
|
||||
raise FileNotFoundError(f"CSV文件不存在: {csv_path}")
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
# 检查前两列是否为纬度和经度
|
||||
if len(df.columns) < 2:
|
||||
raise ValueError("CSV文件至少需要两列(纬度、经度)")
|
||||
|
||||
# 假设前两列是纬度和经度
|
||||
lat_col = df.columns[0]
|
||||
lon_col = df.columns[1]
|
||||
|
||||
# 重命名列
|
||||
df = df.rename(columns={lat_col: 'latitude', lon_col: 'longitude'})
|
||||
|
||||
# 确保数值类型
|
||||
df['latitude'] = pd.to_numeric(df['latitude'], errors='coerce')
|
||||
df['longitude'] = pd.to_numeric(df['longitude'], errors='coerce')
|
||||
|
||||
# 删除无效的坐标
|
||||
df = df.dropna(subset=['latitude', 'longitude'])
|
||||
|
||||
print(f"读取到 {len(df)} 个采样点")
|
||||
return df
|
||||
|
||||
def _create_false_color_image(self, image_array: np.ndarray,
|
||||
rgb_bands: Optional[List[int]] = None) -> np.ndarray:
|
||||
"""创建假彩色RGB图像 - 应用线性拉伸和Gamma校正"""
|
||||
# 由于_read_hyperspectral已返回RGB图像,这里仅进行最终处理
|
||||
if image_array.shape[2] != 3:
|
||||
# 确保是3通道
|
||||
if len(image_array.shape) == 2 or image_array.shape[2] == 1:
|
||||
if len(image_array.shape) == 2:
|
||||
image_array = np.stack([image_array]*3, axis=2)
|
||||
else:
|
||||
image_array = np.repeat(image_array, 3, axis=2)
|
||||
|
||||
print(f" 处理前图像范围: R[{image_array[:,:,0].min():.3f}-{image_array[:,:,0].max():.3f}], "
|
||||
f"G[{image_array[:,:,1].min():.3f}-{image_array[:,:,1].max():.3f}], "
|
||||
f"B[{image_array[:,:,2].min():.3f}-{image_array[:,:,2].max():.3f}]")
|
||||
|
||||
# 增强型线性拉伸 - 解决图像太暗的问题
|
||||
def simple_linear_stretch(data, min_percent=1, max_percent=99):
|
||||
"""增强对比度的线性拉伸"""
|
||||
valid_data = data[np.isfinite(data)]
|
||||
if len(valid_data) == 0:
|
||||
return np.zeros_like(data, dtype=np.float32)
|
||||
|
||||
# 计算百分位数,使用更激进的拉伸 (1%-99%)
|
||||
p_low = np.percentile(valid_data, min_percent)
|
||||
p_high = np.percentile(valid_data, max_percent)
|
||||
|
||||
if p_high - p_low < 1e-8:
|
||||
# 如果数据范围太小,使用最小最大值归一化
|
||||
data_min = valid_data.min()
|
||||
data_max = valid_data.max()
|
||||
if data_max > data_min:
|
||||
stretched = (data - data_min) / (data_max - data_min)
|
||||
else:
|
||||
stretched = np.zeros_like(data, dtype=np.float32)
|
||||
else:
|
||||
stretched = (data - p_low) / (p_high - p_low)
|
||||
|
||||
# 允许轻微过饱和以增加对比度
|
||||
stretched = np.clip(stretched, 0.0, 1.05)
|
||||
stretched = np.clip(stretched, 0.0, 1.0) # 最终确保在[0,1]
|
||||
return stretched
|
||||
|
||||
# 对每个通道进行拉伸
|
||||
r_stretched = simple_linear_stretch(image_array[:, :, 0])
|
||||
g_stretched = simple_linear_stretch(image_array[:, :, 1])
|
||||
b_stretched = simple_linear_stretch(image_array[:, :, 2])
|
||||
|
||||
# 合成为RGB图像
|
||||
rgb_image = np.stack([r_stretched, g_stretched, b_stretched], axis=2)
|
||||
rgb_image = np.nan_to_num(rgb_image, nan=0.0)
|
||||
|
||||
# 最终确保范围在[0,1],并轻微增强对比度
|
||||
rgb_image = np.clip(rgb_image, 0.0, 1.0)
|
||||
|
||||
# 可选:Gamma校正增加亮度(解决太暗问题)
|
||||
gamma = 1 # <1会增加亮度
|
||||
rgb_image = np.power(rgb_image, gamma)
|
||||
|
||||
# 映射到0-255范围(uint8),这样imshow显示效果更好
|
||||
rgb_image = (rgb_image * 255).astype(np.uint8)
|
||||
|
||||
print(f" 处理后图像范围: [0-255] (Gamma={gamma})")
|
||||
|
||||
return rgb_image
|
||||
|
||||
def _geo_to_pixel(self, sampling_points: pd.DataFrame,
|
||||
geotransform: tuple, width: int, height: int,
|
||||
projection: str = "", sample_factor: int = 1) -> List[Tuple[float, float]]:
|
||||
"""
|
||||
使用GDAL进行地理坐标到像素坐标的投影变换 - 支持下采样
|
||||
|
||||
原始点位坐标格式: 41.66054612 124.2208338 (WGS84地理坐标: 纬度,经度)
|
||||
高光谱影像通常使用UTM或其他投影坐标系
|
||||
当图像下采样时,sample_factor > 1,需要相应缩放坐标
|
||||
"""
|
||||
if geotransform is None or len(sampling_points) == 0:
|
||||
# 如果没有地理变换信息,使用图像中心
|
||||
return [(width/2, height/2) for _ in range(len(sampling_points))]
|
||||
|
||||
pixel_coords = []
|
||||
gt = geotransform
|
||||
|
||||
# 检查是否需要投影转换
|
||||
needs_transform = False
|
||||
if projection and ("PROJCS" in projection or "GEOGCS" in projection):
|
||||
needs_transform = True
|
||||
print(f" 检测到影像投影: {projection[:80]}...")
|
||||
|
||||
# 创建坐标转换对象(WGS84 -> 影像投影)
|
||||
transform = None
|
||||
if needs_transform and GDAL_AVAILABLE:
|
||||
try:
|
||||
# 源坐标系: WGS84 (EPSG:4326)
|
||||
src_srs = osr.SpatialReference()
|
||||
src_srs.ImportFromEPSG(4326) # WGS84
|
||||
|
||||
# 目标坐标系: 影像的投影
|
||||
dst_srs = osr.SpatialReference()
|
||||
dst_srs.ImportFromWkt(projection)
|
||||
|
||||
# 创建坐标转换
|
||||
transform = osr.CoordinateTransformation(src_srs, dst_srs)
|
||||
print(" ✓ 已创建WGS84到影像投影的坐标转换")
|
||||
except Exception as e:
|
||||
print(f" ⚠ 坐标转换创建失败: {e},使用简化变换")
|
||||
transform = None
|
||||
|
||||
for _, row in sampling_points.iterrows():
|
||||
lon = float(row['longitude']) # 经度 (WGS84)
|
||||
lat = float(row['latitude']) # 纬度 (WGS84)
|
||||
|
||||
if transform is not None:
|
||||
# 使用GDAL进行投影转换: (经度, 纬度) -> (投影X, 投影Y)
|
||||
try:
|
||||
proj_x, proj_y, _ = transform.TransformPoint(lat, lon)
|
||||
# 再转换为像素坐标
|
||||
x = (proj_x - gt[0]) / gt[1]
|
||||
y = (proj_y - gt[3]) / gt[5]
|
||||
except Exception as e:
|
||||
# 转换失败时回退到直接计算
|
||||
x = (lon - gt[0]) / gt[1]
|
||||
y = (lat - gt[3]) / gt[5]
|
||||
else:
|
||||
# 直接使用仿射变换(坐标系一致的情况)
|
||||
x = (lon - gt[0]) / gt[1]
|
||||
y = (lat - gt[3]) / gt[5]
|
||||
|
||||
# 如果图像进行了下采样,需要相应缩放坐标
|
||||
if sample_factor > 1:
|
||||
x = x / sample_factor
|
||||
y = y / sample_factor
|
||||
|
||||
# 限制在图像范围内(使用下采样后的尺寸)
|
||||
x = max(0, min(x, width - 1))
|
||||
y = max(0, min(y, height - 1))
|
||||
|
||||
pixel_coords.append((x, y))
|
||||
|
||||
if transform is not None:
|
||||
print(f" ✓ 使用GDAL投影变换处理 {len(pixel_coords)} 个采样点")
|
||||
else:
|
||||
print(f" 使用直接仿射变换处理 {len(pixel_coords)} 个采样点")
|
||||
|
||||
return pixel_coords
|
||||
|
||||
def _create_map_visualization(self, rgb_image: np.ndarray,
|
||||
pixel_coords: List[Tuple[float, float]],
|
||||
sampling_points: pd.DataFrame,
|
||||
output_path: str,
|
||||
point_color: str,
|
||||
point_size: int,
|
||||
point_alpha: float,
|
||||
show_north_arrow: bool,
|
||||
show_scale_bar: bool,
|
||||
show_legend: bool,
|
||||
dpi: int,
|
||||
geotransform: tuple,
|
||||
width: int,
|
||||
height: int,
|
||||
downsample: bool = False,
|
||||
projection: str = "",
|
||||
sample_factor: int = 1):
|
||||
"""创建地图可视化 - 优化版"""
|
||||
# 使用更小的figure尺寸加快渲染
|
||||
figsize = (10, 8) if self.fast_mode or downsample else (12, 10)
|
||||
fig, ax = plt.subplots(figsize=figsize, dpi=100 if self.fast_mode else 150)
|
||||
|
||||
# 显示假彩色图像 - 现在已经是0-255的uint8格式
|
||||
print(f" 最终图像数据范围: [{rgb_image.min()}, {rgb_image.max()}] (uint8)")
|
||||
ax.imshow(rgb_image, interpolation='nearest' if self.fast_mode else 'bilinear')
|
||||
|
||||
# 绘制采样点 - 优化:使用scatter代替循环plot
|
||||
if pixel_coords:
|
||||
x_coords = [p[0] for p in pixel_coords]
|
||||
y_coords = [p[1] for p in pixel_coords]
|
||||
ax.scatter(x_coords, y_coords, c=point_color, s=point_size,
|
||||
alpha=point_alpha, edgecolors='white', linewidth=1.5)
|
||||
|
||||
# 添加指北针
|
||||
if show_north_arrow:
|
||||
self._add_north_arrow(ax, width, height, position='bottom-left', direction='down')
|
||||
|
||||
# 添加比例尺
|
||||
if show_scale_bar and geotransform is not None:
|
||||
self._add_scale_bar(ax, geotransform, width, height)
|
||||
|
||||
# 添加图例
|
||||
if show_legend:
|
||||
legend_text = f'采样点 (n={len(sampling_points)})'
|
||||
ax.plot([], [], 'o', color=point_color, markersize=8, label=legend_text)
|
||||
ax.legend(loc='lower right', frameon=True, facecolor='white', edgecolor='gray')
|
||||
|
||||
# 设置标题和标签
|
||||
ax.set_title('高光谱影像采样点分布图', fontsize=16, fontweight='bold', pad=20)
|
||||
|
||||
|
||||
# 隐藏坐标轴刻度
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
# 添加网格
|
||||
ax.grid(True, alpha=0.2, linestyle='--')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存参数 - 避免传递不兼容的参数
|
||||
save_kwargs = {
|
||||
'dpi': dpi,
|
||||
'bbox_inches': 'tight',
|
||||
'pad_inches': 0.05,
|
||||
'facecolor': 'white'
|
||||
}
|
||||
|
||||
# 仅添加matplotlib支持的参数
|
||||
if self.fast_mode:
|
||||
save_kwargs['dpi'] = min(dpi, 180) # 快速模式降低DPI
|
||||
|
||||
plt.savefig(output_path, **save_kwargs)
|
||||
plt.close(fig)
|
||||
|
||||
def _add_north_arrow(self, ax, width: int, height: int, position='top-right', direction='down',
|
||||
size=0.08, color='white', n_color='white', outline_color='black'):
|
||||
"""
|
||||
添加指北针,可配置位置、方向、大小、颜色。
|
||||
|
||||
参数:
|
||||
ax: matplotlib Axes对象
|
||||
width, height: 图像宽高(用于相对定位)
|
||||
position: 'top-left', 'top-right', 'bottom-left', 'bottom-right'
|
||||
direction: 'up', 'down', 'left', 'right' 箭头指向
|
||||
size: 箭头长度相对于高度的比例(0.05~0.12)
|
||||
color: 箭头颜色
|
||||
n_color: 'N' 文字颜色
|
||||
outline_color: 文字描边颜色
|
||||
"""
|
||||
# 位置映射(偏移系数)
|
||||
pos_map = {
|
||||
'top-left': (0.08, 0.88),
|
||||
'top-right': (0.92, 0.88),
|
||||
'bottom-left': (0.08, 0.12),
|
||||
'bottom-right': (0.92, 0.12),
|
||||
}
|
||||
arrow_x_ratio, arrow_y_ratio = pos_map.get(position, (0.92, 0.88))
|
||||
arrow_x = width * arrow_x_ratio
|
||||
arrow_y = height * arrow_y_ratio
|
||||
|
||||
# 方向映射(箭头终点偏移)
|
||||
direction_map = {
|
||||
'up': (0, +size),
|
||||
'down': (0, -size),
|
||||
'left': (-size, 0),
|
||||
'right': (+size, 0),
|
||||
}
|
||||
dx, dy = direction_map.get(direction, (0, -size))
|
||||
end_x = arrow_x + dx * width # 注意:dx是比例,乘以宽度/高度保持比例一致
|
||||
end_y = arrow_y + dy * height
|
||||
|
||||
# 箭头绘制
|
||||
arrow = FancyArrowPatch((arrow_x, arrow_y), (end_x, end_y),
|
||||
color=color, linewidth=3,
|
||||
arrowstyle='->', mutation_scale=20)
|
||||
ax.add_patch(arrow)
|
||||
|
||||
# N 文字位置:在箭头尾部或头部?通常放在箭头指向的反方向末端
|
||||
# 这里放在箭头尾部向外偏移一点(便于阅读)
|
||||
# 偏移系数根据方向决定
|
||||
offset_scale = 0.02 # 偏移量比例
|
||||
if direction == 'up':
|
||||
text_x = arrow_x
|
||||
text_y = arrow_y - height * offset_scale # 放在箭头下方
|
||||
elif direction == 'down':
|
||||
text_x = arrow_x
|
||||
text_y = arrow_y + height * offset_scale # 放在箭头上方
|
||||
elif direction == 'left':
|
||||
text_x = arrow_x + width * offset_scale
|
||||
text_y = arrow_y
|
||||
else: # right
|
||||
text_x = arrow_x - width * offset_scale
|
||||
text_y = arrow_y
|
||||
|
||||
ax.text(text_x, text_y, 'N', fontsize=14, fontweight='bold',
|
||||
color=n_color, ha='center', va='center',
|
||||
path_effects=[path_effects.withStroke(linewidth=3, foreground=outline_color)])
|
||||
|
||||
def _add_scale_bar(self, ax, geotransform: tuple, width: int, height: int):
|
||||
"""添加比例尺"""
|
||||
if geotransform is None:
|
||||
return
|
||||
|
||||
# 计算图像实际宽度(米)
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
image_width_meters = width * pixel_size_x
|
||||
|
||||
# 选择合适的比例尺长度(图像宽度的1/4)
|
||||
scale_length_m = image_width_meters / 4
|
||||
scale_length_pixels = width / 4
|
||||
|
||||
# 找到合适的刻度
|
||||
scale_options = [1000, 500, 200, 100, 50, 20, 10, 5, 2, 1]
|
||||
scale_meters = next((s for s in scale_options if s <= scale_length_m), 1)
|
||||
|
||||
scale_pixels = int(scale_meters / pixel_size_x)
|
||||
|
||||
# 在左下角添加比例尺
|
||||
bar_x = width * 0.08
|
||||
bar_y = height * 0.92
|
||||
|
||||
# 绘制比例尺线
|
||||
ax.plot([bar_x, bar_x + scale_pixels], [bar_y, bar_y], color='white', linewidth=4)
|
||||
|
||||
# 添加刻度线
|
||||
ax.plot([bar_x, bar_x], [bar_y, bar_y + 8], color='white', linewidth=2)
|
||||
ax.plot([bar_x + scale_pixels, bar_x + scale_pixels], [bar_y, bar_y + 8], color='white', linewidth=2)
|
||||
|
||||
# 添加文字
|
||||
ax.text(bar_x + scale_pixels/2, bar_y , f'{scale_meters} m',
|
||||
fontsize=11, ha='center', va='bottom', fontweight='bold',
|
||||
bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))
|
||||
|
||||
def batch_create_maps(self, hyperspectral_path: str,
|
||||
csv_folder: str,
|
||||
output_subdir: str = "sampling_maps",
|
||||
fast_mode: bool = True) -> Dict[str, str]:
|
||||
"""
|
||||
批量创建采样点地图
|
||||
|
||||
Args:
|
||||
hyperspectral_path: 高光谱影像路径
|
||||
csv_folder: 包含多个CSV文件的文件夹
|
||||
output_subdir: 输出子目录
|
||||
|
||||
Returns:
|
||||
生成的地图文件路径字典
|
||||
"""
|
||||
csv_folder_path = Path(csv_folder)
|
||||
if not csv_folder_path.exists():
|
||||
raise FileNotFoundError(f"CSV文件夹不存在: {csv_folder}")
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = self.output_dir / output_subdir
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
map_paths = {}
|
||||
|
||||
# 查找所有CSV文件
|
||||
csv_files = list(csv_folder_path.glob("*.csv"))
|
||||
|
||||
print(f"找到 {len(csv_files)} 个CSV文件,开始批量生成采样点地图... (快速模式: {fast_mode})")
|
||||
|
||||
for csv_file in csv_files:
|
||||
try:
|
||||
output_filename = f"{Path(hyperspectral_path).stem}_{csv_file.stem}_sampling_map.png"
|
||||
map_path = self.create_sampling_point_map(
|
||||
hyperspectral_path=hyperspectral_path,
|
||||
csv_path=str(csv_file),
|
||||
output_filename=output_filename,
|
||||
downsample=True, # 批量模式默认下采样
|
||||
dpi=120 if fast_mode else 200
|
||||
)
|
||||
map_paths[csv_file.name] = map_path
|
||||
print(f"✓ 生成: {csv_file.name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 处理 {csv_file.name} 失败: {e}")
|
||||
|
||||
print(f"批量生成完成,共生成 {len(map_paths)} 个采样点地图")
|
||||
return map_paths
|
||||
|
||||
|
||||
# 测试代码
|
||||
if __name__ == "__main__":
|
||||
# 示例用法
|
||||
map_generator = SamplingPointMap(output_dir="./point_maps")
|
||||
|
||||
# 测试代码已禁用,避免直接运行时出错
|
||||
map_generator_fast = SamplingPointMap(output_dir="./point_maps", fast_mode=True)
|
||||
map_path = map_generator_fast.create_sampling_point_map(
|
||||
hyperspectral_path=r"D:\BaiduNetdiskDownload\yaobao\result3.bsq",
|
||||
csv_path=r"E:\code\WQ\pipeline_result\work_dir\4_processed_data\processed_data.csv",
|
||||
downsample=True,
|
||||
dpi=150
|
||||
)
|
||||
print("测试代码已注释,请通过GUI或手动调用使用。")
|
||||
|
||||
print("SamplingPointMap类已创建,可以用于生成带采样点的地图。")
|
||||
print("性能优化功能:")
|
||||
print(" - fast_mode=True: 快速模式 (推荐用于预览)")
|
||||
print(" - downsample=True: 对大影像下采样 (推荐用于>2000x2000影像)")
|
||||
print(" - 使用: SamplingPointMap(fast_mode=True).create_sampling_point_map(...)")
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 1.4 MiB |
1578
src/postprocessing/report_word.py
Normal file
1578
src/postprocessing/report_word.py
Normal file
File diff suppressed because it is too large
Load Diff
BIN
src/postprocessing/reports/水质参数反演分析报告_20260330_164433.docx
Normal file
BIN
src/postprocessing/reports/水质参数反演分析报告_20260330_164433.docx
Normal file
Binary file not shown.
BIN
src/postprocessing/reports/水质参数反演分析报告_20260330_164511.docx
Normal file
BIN
src/postprocessing/reports/水质参数反演分析报告_20260330_164511.docx
Normal file
Binary file not shown.
BIN
src/postprocessing/reports/水质参数反演分析报告_20260330_165245.docx
Normal file
BIN
src/postprocessing/reports/水质参数反演分析报告_20260330_165245.docx
Normal file
Binary file not shown.
BIN
src/postprocessing/reports/水质参数反演分析报告_20260330_165536.docx
Normal file
BIN
src/postprocessing/reports/水质参数反演分析报告_20260330_165536.docx
Normal file
Binary file not shown.
BIN
src/postprocessing/reports/水质参数反演分析报告_20260331_154250.docx
Normal file
BIN
src/postprocessing/reports/水质参数反演分析报告_20260331_154250.docx
Normal file
Binary file not shown.
BIN
src/postprocessing/reports/水质参数反演分析报告_20260331_155128.docx
Normal file
BIN
src/postprocessing/reports/水质参数反演分析报告_20260331_155128.docx
Normal file
Binary file not shown.
BIN
src/postprocessing/reports/水质参数反演分析报告_20260331_155142.docx
Normal file
BIN
src/postprocessing/reports/水质参数反演分析报告_20260331_155142.docx
Normal file
Binary file not shown.
1185
src/postprocessing/visualization_reports.py
Normal file
1185
src/postprocessing/visualization_reports.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user