815 lines
31 KiB
Python
815 lines
31 KiB
Python
|
||
import numpy as np
|
||
from scipy import stats
|
||
from scipy.spatial.distance import pdist, squareform
|
||
from sklearn.cluster import KMeans
|
||
from sklearn.preprocessing import StandardScaler
|
||
import spectral as spy
|
||
from spectral import envi
|
||
import matplotlib.pyplot as plt
|
||
import os
|
||
from typing import Tuple, List, Dict, Optional, Union
|
||
import warnings
|
||
import xml.etree.ElementTree as ET
|
||
from shapely.geometry import Polygon, Point
|
||
import warnings
|
||
warnings.filterwarnings('ignore')
|
||
|
||
|
||
class HyperspectralDistanceMetrics:
    """Distance metrics for hyperspectral data.

    All methods are static. The pairwise metrics take two 1-D spectra and
    return a scalar. The ``vectorized_*`` variants take a sample matrix
    ``X`` of shape (n_samples, n_bands) and a center/reference matrix of
    shape (n_clusters, n_bands) and return an (n_samples, n_clusters)
    distance matrix.
    """

    @staticmethod
    def cosine_similarity(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Cosine similarity between two spectra; returns 0 if either norm is zero."""
        dot_product = np.dot(spectrum1, spectrum2)
        norm1 = np.linalg.norm(spectrum1)
        norm2 = np.linalg.norm(spectrum2)
        return dot_product / (norm1 * norm2) if norm1 != 0 and norm2 != 0 else 0

    @staticmethod
    def euclidean_distance(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Euclidean (L2) distance between two spectra."""
        return np.linalg.norm(spectrum1 - spectrum2)

    @staticmethod
    def information_divergence(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Symmetrized Kullback-Leibler divergence (spectral information divergence)."""
        # Small constant avoids division by zero and log(0).
        eps = 1e-10
        spectrum1 = spectrum1 + eps
        spectrum2 = spectrum2 + eps

        # Normalize to probability distributions.
        p = spectrum1 / np.sum(spectrum1)
        q = spectrum2 / np.sum(spectrum2)

        # Average the two directed KL divergences to make the measure symmetric.
        kl_pq = np.sum(p * np.log(p / q))
        kl_qp = np.sum(q * np.log(q / p))

        return (kl_pq + kl_qp) / 2

    @staticmethod
    def correlation_coefficient(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Pearson correlation coefficient between two spectra."""
        return np.corrcoef(spectrum1, spectrum2)[0, 1]

    @staticmethod
    def jeffreys_matusita_distance(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Jeffreys-Matusita distance: sqrt(2 * (1 - Bhattacharyya coefficient))."""
        eps = 1e-10
        # Convert to probability distributions.
        p = (spectrum1 + eps) / np.sum(spectrum1 + eps)
        q = (spectrum2 + eps) / np.sum(spectrum2 + eps)

        # Bhattacharyya coefficient.
        bc = np.sum(np.sqrt(p * q))

        # J-M distance = sqrt(2 * (1 - bc))
        jm_distance = np.sqrt(2 * (1 - bc))
        return jm_distance

    @staticmethod
    def spectral_angle_mapper(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Spectral angle (radians) between two spectra."""
        cos_sim = HyperspectralDistanceMetrics.cosine_similarity(spectrum1, spectrum2)
        # Clip guards arccos against tiny floating-point excursions outside [-1, 1].
        angle = np.arccos(np.clip(cos_sim, -1, 1))
        return angle

    @staticmethod
    def sid_sa_combined(spectrum1: np.ndarray, spectrum2: np.ndarray,
                        alpha: float = 0.5) -> float:
        """Weighted combination of spectral information divergence and spectral angle.

        Parameters:
            alpha: weight of the SID term; (1 - alpha) weights the SAM term.
        """
        sid = HyperspectralDistanceMetrics.information_divergence(spectrum1, spectrum2)
        sam = HyperspectralDistanceMetrics.spectral_angle_mapper(spectrum1, spectrum2)

        # Normalize each term before combining.
        sid_norm = sid / (1 + sid)  # maps [0, inf) into [0, 1)
        sam_norm = sam / (np.pi / 2)  # maps [0, pi/2] into [0, 1]

        return alpha * sid_norm + (1 - alpha) * sam_norm

    # ===== Vectorized distance computations =====

    @staticmethod
    def vectorized_cosine_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Vectorized cosine distance (1 - cosine similarity).

        X: (n_samples, n_features), centers: (n_clusters, n_features)
        Returns: (n_samples, n_clusters)
        """
        X_norm = np.linalg.norm(X, axis=1, keepdims=True)
        centers_norm = np.linalg.norm(centers, axis=1, keepdims=True)

        # Replace zero norms with 1 to avoid division by zero.
        X_norm = np.where(X_norm == 0, 1, X_norm)
        centers_norm = np.where(centers_norm == 0, 1, centers_norm)

        similarity = np.dot(X, centers.T) / (X_norm * centers_norm.T)
        return 1 - np.clip(similarity, -1, 1)

    @staticmethod
    def vectorized_euclidean_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Vectorized Euclidean distance via broadcasting.

        X: (n_samples, n_features), centers: (n_clusters, n_features)
        Returns: (n_samples, n_clusters)
        """
        diff = X[:, np.newaxis, :] - centers[np.newaxis, :, :]
        return np.linalg.norm(diff, axis=2)

    @staticmethod
    def vectorized_correlation_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Vectorized correlation distance: 1 - |Pearson r|.

        Bug fix: the previous implementation passed ``X.T``/``centers.T`` to
        ``np.corrcoef``. Since ``np.corrcoef`` treats ROWS as variables, that
        correlated bands instead of spectra and failed with a shape error
        whenever n_samples != n_clusters. Samples and centers must be stacked
        as rows, then the cross block of the correlation matrix is sliced out.
        """
        corr_matrix = np.corrcoef(X, centers)[:len(X), len(X):]
        return 1 - np.abs(corr_matrix)

    @staticmethod
    def vectorized_information_divergence(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Vectorized symmetric KL divergence between every sample and every center."""
        eps = 1e-10
        X_norm = X + eps
        centers_norm = centers + eps

        # Normalize rows to probability distributions.
        X_prob = X_norm / np.sum(X_norm, axis=1, keepdims=True)
        centers_prob = centers_norm / np.sum(centers_norm, axis=1, keepdims=True)

        # Broadcast to (n_samples, n_clusters, n_bands) and reduce over bands.
        kl_pq = np.sum(X_prob[:, np.newaxis, :] * np.log(X_prob[:, np.newaxis, :] / centers_prob[np.newaxis, :, :]), axis=2)
        kl_qp = np.sum(centers_prob[np.newaxis, :, :] * np.log(centers_prob[np.newaxis, :, :] / X_prob[:, np.newaxis, :]), axis=2)

        return (kl_pq + kl_qp) / 2

    @staticmethod
    def vectorized_jeffreys_matusita_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Vectorized Jeffreys-Matusita distance."""
        eps = 1e-10

        # Convert rows to probability distributions.
        X_prob = (X + eps) / np.sum(X + eps, axis=1, keepdims=True)
        centers_prob = (centers + eps) / np.sum(centers + eps, axis=1, keepdims=True)

        # Bhattacharyya coefficient for every (sample, center) pair.
        bc = np.sum(np.sqrt(X_prob[:, np.newaxis, :] * centers_prob[np.newaxis, :, :]), axis=2)

        return np.sqrt(2 * (1 - bc))

    @staticmethod
    def vectorized_sid_sa_combined(X: np.ndarray, centers: np.ndarray, alpha: float = 0.5) -> np.ndarray:
        """Vectorized SID + SAM combined distance.

        Parameters:
            alpha: weight of the SID term; (1 - alpha) weights the SAM term.
        """
        sid = HyperspectralDistanceMetrics.vectorized_information_divergence(X, centers)
        sam = HyperspectralDistanceMetrics.vectorized_spectral_angle(X, centers)

        # Normalize each term before combining.
        sid_norm = sid / (1 + sid)  # [0, inf) -> [0, 1)
        sam_norm = sam / (np.pi / 2)  # [0, pi/2] -> [0, 1]

        return alpha * sid_norm + (1 - alpha) * sam_norm

    @staticmethod
    def vectorized_spectral_angle(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Vectorized spectral angle (radians) between every sample and every center."""
        X_norm = np.linalg.norm(X, axis=1, keepdims=True)
        centers_norm = np.linalg.norm(centers, axis=1, keepdims=True)

        # Replace zero norms with 1 to avoid division by zero.
        X_norm = np.where(X_norm == 0, 1, X_norm)
        centers_norm = np.where(centers_norm == 0, 1, centers_norm)

        similarity = np.dot(X, centers.T) / (X_norm * centers_norm.T)
        similarity = np.clip(similarity, -1, 1)

        return np.arccos(similarity)
|
||
class HyperspectralClassification:
    """Supervised classifier for hyperspectral images based on spectral distances.

    Each pixel is assigned the class of the nearest reference spectrum under
    the chosen distance metric.
    """

    def __init__(self, random_state: int = 42, distance_params: Dict[str, Dict] = None):
        """
        Parameters:
            random_state: kept for API compatibility; no stochastic step is
                performed in this classifier.
            distance_params: optional per-method hyperparameters that are
                merged over the defaults, e.g. {'sid_sa': {'alpha': 0.7}}.
        """
        self.random_state = random_state
        self.distance_metrics = HyperspectralDistanceMetrics()
        self.reference_spectra_ = None  # dict of reference spectra from the last fit
        self.labels_ = None             # label array from the last fit
        self.class_names_ = None        # class names in label-index order

        # Default hyperparameters for each distance metric.
        self.default_distance_params = {
            'cosine': {},
            'euclidean': {},
            'correlation': {},
            'information_divergence': {},
            'sam': {},
            'jm_distance': {'alpha': 0.5},  # Bhattacharyya coefficient weight
            'sid_sa': {'alpha': 0.5}  # weight between SID and SAM
        }

        # Merge user-supplied parameters over the defaults.
        self.distance_params = self.default_distance_params.copy()
        if distance_params:
            self.distance_params.update(distance_params)

        # Vectorized distance functions keyed by method name.
        self.vectorized_distance_functions = {
            'cosine': self.distance_metrics.vectorized_cosine_distance,
            'euclidean': self.distance_metrics.vectorized_euclidean_distance,
            'correlation': self.distance_metrics.vectorized_correlation_distance,
            'information_divergence': self.distance_metrics.vectorized_information_divergence,
            'sam': self.distance_metrics.vectorized_spectral_angle,
            'jm_distance': self.distance_metrics.vectorized_jeffreys_matusita_distance,
            'sid_sa': self.distance_metrics.vectorized_sid_sa_combined
        }

    def fit_predict_with_references(self, X: np.ndarray,
                                    reference_spectra: Dict[str, np.ndarray],
                                    method: str = 'euclidean') -> np.ndarray:
        """
        Supervised classification against reference spectra.

        Parameters:
            X: hyperspectral data, shape (n_samples, n_bands) or (rows, cols, n_bands)
            reference_spectra: {class_name: spectrum of shape (n_bands,)}
            method: distance metric name (a key of vectorized_distance_functions)

        Returns:
            For 3-D input, an (rows, cols) int array where non-finite pixels
            are marked -1; otherwise an (n_samples,) label array. Labels index
            into self.class_names_.

        Raises:
            ValueError: on an unknown method, empty references, or no valid pixels.
        """
        if method not in self.vectorized_distance_functions:
            raise ValueError(f"不支持的距离度量方法: {method}")

        if not reference_spectra:
            raise ValueError("必须提供参考光谱")

        # For image input, flatten to (n_pixels, n_bands) and drop pixels
        # containing NaN/Inf; they are restored as label -1 below.
        if X.ndim == 3:
            rows, cols, bands = X.shape
            X_reshaped = X.reshape(-1, bands)
            valid_mask = np.isfinite(X_reshaped).all(axis=1)
            X_valid = X_reshaped[valid_mask]
        else:
            X_valid = X
            rows, cols = None, None

        if len(X_valid) == 0:
            raise ValueError("没有有效的光谱数据")

        # Stack references into a 2-D (n_classes, n_bands) array, flattening
        # any accidentally multi-dimensional spectra first.
        class_names = list(reference_spectra.keys())
        cleaned_reference_spectra = []
        for name in class_names:
            spectrum = reference_spectra[name]
            if spectrum.ndim > 1:
                spectrum = spectrum.flatten()
            cleaned_reference_spectra.append(spectrum)
        reference_array = np.array(cleaned_reference_spectra)

        labels = self._classify_pixels(X_valid, reference_array, method)

        # For image input, scatter labels back into the (rows, cols) grid;
        # invalid pixels keep the sentinel label -1.
        if X.ndim == 3:
            result = np.full((rows * cols,), -1, dtype=int)
            result[valid_mask] = labels
            result = result.reshape(rows, cols)
        else:
            result = labels

        self.reference_spectra_ = reference_spectra
        self.class_names_ = class_names
        self.labels_ = result
        return result

    def _classify_pixels(self, X: np.ndarray, reference_spectra: np.ndarray, method: str) -> np.ndarray:
        """
        Assign every pixel to its nearest reference spectrum.

        Parameters:
            X: pixel spectra (n_samples, n_bands)
            reference_spectra: reference spectra (n_classes, n_bands)
            method: distance metric name

        Returns:
            Label array (n_samples,) of class indices.
        """
        if method not in self.vectorized_distance_functions or self.vectorized_distance_functions[method] is None:
            raise ValueError(f"不支持向量化计算的距离度量方法: {method}")

        distance_fn = self.vectorized_distance_functions[method]

        # Bug fix: user-supplied hyperparameters in self.distance_params were
        # previously never forwarded to the vectorized functions, so the
        # 'sid_sa' alpha setting was silently ignored.
        if method == 'sid_sa':
            alpha = self.distance_params.get(method, {}).get('alpha', 0.5)
            distances = distance_fn(X, reference_spectra, alpha=alpha)
        else:
            distances = distance_fn(X, reference_spectra)

        # Nearest reference spectrum wins.
        return np.argmin(distances, axis=1)

    def _classify_with_complex_distance(self, X: np.ndarray, reference_spectra: np.ndarray, method: str) -> np.ndarray:
        """
        Non-vectorized fallback classification for composite metrics.

        Iterates per pixel and per class; kept for metrics without a
        vectorized implementation. Unknown methods fall back to Euclidean.
        """
        print(f"使用 {method} 距离度量进行分类...")

        n_samples = X.shape[0]
        n_classes = reference_spectra.shape[0]
        labels = np.zeros(n_samples, dtype=int)

        # For each pixel, find the reference spectrum with minimal distance.
        for i in range(n_samples):
            pixel = X[i]
            min_distance = float('inf')
            best_class = 0

            for j in range(n_classes):
                reference = reference_spectra[j]

                if method == 'jm_distance':
                    distance = self.distance_metrics.jeffreys_matusita_distance(pixel, reference)
                elif method == 'sid_sa':
                    alpha = self.distance_params[method].get('alpha', 0.5)
                    distance = self.distance_metrics.sid_sa_combined(pixel, reference, alpha)
                else:
                    # Default: Euclidean distance.
                    distance = self.distance_metrics.euclidean_distance(pixel, reference)

                if distance < min_distance:
                    min_distance = distance
                    best_class = j

            labels[i] = best_class

        return labels

    def fit_predict_all_methods(self, X: np.ndarray, reference_spectra: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Classify with every available distance metric.

        Returns a dict mapping method name to its label array, or to None for
        methods that raised an exception.
        """
        results = {}
        methods = list(self.vectorized_distance_functions.keys())

        print("开始使用不同距离度量进行分类...")
        for method in methods:
            print(f"正在使用 {method} 距离度量...")
            try:
                results[method] = self.fit_predict_with_references(X, reference_spectra, method)
                print(f"✓ {method} 分类完成")
            except Exception as e:
                # A failing metric should not abort the whole comparison run.
                print(f"✗ {method} 分类失败: {e}")
                results[method] = None

        return results
|
||
|
||
class HyperspectralImageProcessor:
    """Loading, ROI parsing, result saving and visualization for hyperspectral images."""

    def __init__(self):
        self.data = None         # last loaded image cube (rows, cols, bands)
        self.header = None       # ENVI header metadata dict
        self.wavelengths = None  # per-band wavelengths (np.ndarray) or None

    def load_image(self, hdr_path: str) -> Tuple[np.ndarray, Dict]:
        """Load an ENVI-format hyperspectral image.

        Note: this method was previously defined twice in the class (the
        second definition shadowed the first); the duplicate has been removed.

        Parameters:
            hdr_path: path to the ENVI header (.hdr) file

        Returns:
            (data, header) where data is the image cube and header the
            ENVI metadata dict.

        Raises:
            IOError: if the file cannot be opened or read.
        """
        try:
            # Read the ENVI file.
            img = envi.open(hdr_path)
            data = img.load()
            header = dict(img.metadata)

            # Extract wavelength information when present in the header.
            if 'wavelength' in header:
                wavelengths = np.array([float(w) for w in header['wavelength']])
            else:
                wavelengths = None

            self.data = data
            self.header = header
            self.wavelengths = wavelengths

            print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}")
            if wavelengths is not None:
                print(f"波长范围: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")

            return data, header

        except Exception as e:
            raise IOError(f"加载图像失败: {e}")

    def parse_roi_xml(self, xml_path: str) -> Dict[str, List[Tuple[float, float]]]:
        """
        Parse an ENVI ROI XML file and extract each region's vertex coordinates.

        Parameters:
            xml_path: path to the XML file

        Returns:
            Dict mapping ROI name to a coordinate list [(x1,y1), (x2,y2), ...].

        Raises:
            IOError: if the XML file cannot be parsed.
        """
        try:
            tree = ET.parse(xml_path)
            root = tree.getroot()

            roi_coordinates = {}

            for region in root.findall('Region'):
                name = region.get('name')
                coordinates_text = None

                # Find the Coordinates element anywhere under this region.
                coord_elem = region.find('.//Coordinates')
                if coord_elem is not None and coord_elem.text:
                    coordinates_text = coord_elem.text.strip()

                if coordinates_text:
                    # Coordinates come as a whitespace-separated flat list:
                    # "x1 y1 x2 y2 ...".
                    coords = []
                    parts = coordinates_text.split()
                    for i in range(0, len(parts), 2):
                        if i + 1 < len(parts):
                            x = float(parts[i])
                            y = float(parts[i + 1])
                            coords.append((x, y))

                    roi_coordinates[name] = coords

            return roi_coordinates

        except Exception as e:
            raise IOError(f"解析XML文件失败: {e}")

    def extract_reference_spectra(self, data: np.ndarray,
                                  roi_coordinates: Dict[str, List[Tuple[float, float]]]) -> Dict[str, np.ndarray]:
        """
        Extract mean reference spectra from the image for each ROI polygon.

        Parameters:
            data: hyperspectral image cube (rows, cols, bands)
            roi_coordinates: ROI name -> polygon vertex list

        Returns:
            Dict mapping ROI name to its mean spectrum of shape (bands,).
            ROIs with no valid pixels are skipped with a warning.
        """
        reference_spectra = {}
        rows, cols, bands = data.shape

        for roi_name, coords in roi_coordinates.items():
            polygon = Polygon(coords)

            # Performance: only scan the polygon's bounding box (clipped to
            # the image). Pixels outside it can never be inside the polygon,
            # so the result is identical to a full-image scan.
            minx, miny, maxx, maxy = polygon.bounds
            col_start = max(int(np.floor(minx)), 0)
            col_end = min(int(np.ceil(maxx)), cols - 1)
            row_start = max(int(np.floor(miny)), 0)
            row_end = min(int(np.ceil(maxy)), rows - 1)

            pixels_in_roi = []
            for i in range(row_start, row_end + 1):
                for j in range(col_start, col_end + 1):
                    # ENVI convention: x is the column index, y is the row index.
                    if polygon.contains(Point(j, i)):
                        pixels_in_roi.append((i, j))

            if len(pixels_in_roi) == 0:
                print(f"警告: ROI '{roi_name}' 中没有找到像素")
                continue

            # Collect the pixel spectra, keeping only finite, full-length ones.
            spectra = []
            for i, j in pixels_in_roi:
                spectrum = data[i, j, :]
                # Guard against readers that return extra singleton dims.
                if spectrum.ndim > 1:
                    spectrum = spectrum.flatten()

                if np.isfinite(spectrum).all() and spectrum.size == bands:
                    spectra.append(spectrum)

            if len(spectra) == 0:
                print(f"警告: ROI '{roi_name}' 中没有有效的光谱数据")
                continue

            # Mean spectrum over the ROI; ensure a 1-D result.
            avg_spectrum = np.mean(spectra, axis=0)
            if avg_spectrum.ndim > 1:
                avg_spectrum = avg_spectrum.flatten()

            reference_spectra[roi_name] = avg_spectrum

            print(f"ROI '{roi_name}': 找到 {len(spectra)} 个有效像素, 光谱维度 {avg_spectrum.shape}")

        return reference_spectra

    def save_single_band_image(self, data: np.ndarray, output_path: str,
                               method_name: str = "", original_header: Dict = None) -> None:
        """Save a single-band label image as a raw .dat file plus an ENVI .hdr.

        Parameters:
            data: 2-D label array to save (converted to int32)
            output_path: target path; the extension is forced to .dat
            method_name: method label written into the header description
            original_header: optional source-image header to copy metadata from

        Raises:
            IOError: if writing fails.
        """
        try:
            # Make sure the output directory exists.
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # Force a .dat extension.
            if not output_path.lower().endswith('.dat'):
                output_path = output_path.rsplit('.', 1)[0] + '.dat'

            # Label images are stored as int32 (ENVI data type 3).
            if data.dtype != np.int32:
                data_to_save = data.astype(np.int32)
            else:
                data_to_save = data

            # Write the raw binary data.
            data_to_save.tofile(output_path)

            # Write the companion ENVI header.
            hdr_path = output_path.rsplit('.', 1)[0] + '.hdr'
            self._create_cluster_hdr_file(hdr_path, data_to_save.shape, method_name, original_header)

            print(f"聚类结果已保存到:")
            print(f" 数据文件: {output_path}")
            print(f" 头文件: {hdr_path}")
            print(f"数据类型: {data_to_save.dtype}, 形状: {data_to_save.shape}")

        except Exception as e:
            raise IOError(f"保存图像失败: {e}")

    def _create_cluster_hdr_file(self, hdr_path: str, data_shape: Tuple,
                                 method_name: str, original_header: Dict = None) -> None:
        """Write an ENVI header describing a single-band int32 label image.

        Failures are reported but not raised, so a bad header does not undo
        an already-written data file.
        """
        try:
            if original_header is not None:
                # Copy the key fields from the source image header.
                samples = original_header.get('samples', data_shape[1] if len(data_shape) > 1 else data_shape[0])
                lines = original_header.get('lines', data_shape[0] if len(data_shape) > 1 else 1)
                bands = 1  # label image is single-band
                interleave = 'bip'
                data_type = 3  # ENVI data types: 3=int32, 4=float32
                byte_order = original_header.get('byte order', 0)
                wavelength_units = original_header.get('wavelength units', 'nm')
            else:
                # Fallback defaults for a 2-D label image.
                if len(data_shape) == 2:
                    lines, samples = data_shape
                else:
                    lines = data_shape[0]
                    samples = data_shape[1] if len(data_shape) > 1 else data_shape[0]
                bands = 1
                interleave = 'bsq'
                data_type = 3  # int32
                byte_order = 0
                wavelength_units = 'Unknown'

            # Emit the header in ENVI's key = value format.
            with open(hdr_path, 'w') as f:
                f.write("ENVI\n")
                f.write("description = {\n")
                f.write(f" 聚类结果 - {method_name}\n")
                f.write(" 单波段聚类标签图像\n")
                f.write("}\n")
                f.write(f"samples = {samples}\n")
                f.write(f"lines = {lines}\n")
                f.write(f"bands = {bands}\n")
                f.write(f"header offset = 0\n")
                f.write(f"file type = ENVI Standard\n")
                f.write(f"data type = {data_type}\n")
                f.write(f"interleave = {interleave}\n")
                f.write(f"byte order = {byte_order}\n")

                # Mirror the presence of wavelength info from the source header.
                if original_header and 'wavelength' in original_header:
                    f.write("wavelength = {\n")
                    f.write(" 类别标签\n")
                    f.write("}\n")
                    f.write(f"wavelength units = {wavelength_units}\n")

                # Band name annotation.
                f.write("band names = {\n")
                f.write(f" 聚类结果_{method_name}\n")
                f.write("}\n")

            print(f"成功创建头文件: {hdr_path}")

        except Exception as e:
            print(f"创建头文件失败: {e}")

    def visualize_clusters(self, cluster_results: Dict[str, np.ndarray],
                           save_path: Optional[str] = None) -> None:
        """Plot all classification maps side by side.

        Parameters:
            cluster_results: method name -> label image (None entries skipped)
            save_path: optional path to save the figure as an image

        Bug fix: the previous implementation raised NameError on its loop
        variable when no subplot was drawn (empty dict or all-None results).
        """
        n_methods = len(cluster_results)
        if n_methods == 0:
            return

        fig, axes = plt.subplots(2, (n_methods + 1) // 2, figsize=(15, 10))
        axes = np.atleast_1d(axes).flatten()

        used = 0
        for method, result in cluster_results.items():
            if result is not None and used < len(axes):
                ax = axes[used]
                im = ax.imshow(result, cmap='tab10')
                ax.set_title(f'{method.replace("_", " ").title()}')
                ax.axis('off')
                plt.colorbar(im, ax=ax, shrink=0.8)
                used += 1

        # Hide any unused subplots.
        for ax in axes[used:]:
            ax.axis('off')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"可视化结果已保存到: {save_path}")

        plt.show()
|
||
|
||
def main():
    """CLI entry point: parse arguments and run the supervised classification.

    Returns the exit code from run_hsi_classification (0 on success, 1 on
    failure).
    """
    import argparse

    parser = argparse.ArgumentParser(description='高光谱图像监督分类分析')
    parser.add_argument('input_file', help='输入ENVI格式高光谱图像文件 (.hdr)')
    parser.add_argument('xml_file', help='ENVI ROI XML文件路径')
    parser.add_argument('--output_dir', '-o', default='output',
                        help='输出目录 (默认: output)')
    # Bug fix: 'sam' is a supported distance method in the classifier but
    # was missing from the CLI choices, making it unreachable on its own.
    parser.add_argument('--method', '-m', default='all',
                        choices=['all', 'euclidean', 'cosine', 'correlation',
                                 'information_divergence', 'sam', 'jm_distance', 'sid_sa'],
                        help='分类距离度量方法 (默认: all,使用所有方法)')
    parser.add_argument('--distance_params', '-p', type=str, default=None,
                        help='距离度量方法的超参数,JSON格式字符串,例如: {"jm_distance": {"alpha": 0.7}}')
    parser.add_argument('--visualize', '-v', action='store_true',
                        help='是否生成可视化结果')

    args = parser.parse_args()

    # Parse the optional JSON hyperparameter string; fall back to defaults
    # on malformed input instead of aborting.
    distance_params = None
    if args.distance_params:
        import json
        try:
            distance_params = json.loads(args.distance_params)
        except json.JSONDecodeError:
            print(f"警告: 无法解析距离参数 '{args.distance_params}',使用默认参数")
            distance_params = None

    try:
        # Delegate the actual pipeline to run_hsi_classification.
        return run_hsi_classification(
            input_file=args.input_file,
            xml_file=args.xml_file,
            output_dir=args.output_dir,
            method=args.method,
            distance_params=distance_params,
            visualize=args.visualize
        )

    except Exception as e:
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
|
||
|
||
|
||
|
||
def run_hsi_classification(input_file, xml_file, output_dir='output', method='all',
                           distance_params=None, visualize=False):
    """
    Run the supervised hyperspectral classification pipeline.

    Parameters:
        input_file: ENVI hyperspectral image header file (.hdr)
        xml_file: ENVI ROI XML file path
        output_dir: output directory (default: 'output')
        method: distance metric name, or 'all' to run every method
        distance_params: optional per-method hyperparameter dict
        visualize: if True, save and show a figure of the results

    Returns:
        0 on success, 1 on failure.
    """
    try:
        # Create the output directory.
        os.makedirs(output_dir, exist_ok=True)

        # Initialize the image processor.
        processor = HyperspectralImageProcessor()

        # Load the image.
        print(f"加载高光谱图像: {input_file}")
        data, header = processor.load_image(input_file)

        # Parse the ROI XML file.
        print(f"\n解析ROI XML文件: {xml_file}")
        roi_coordinates = processor.parse_roi_xml(xml_file)

        if not roi_coordinates:
            raise ValueError("XML文件中没有找到有效的ROI区域")

        print(f"找到 {len(roi_coordinates)} 个ROI区域")
        for roi_name, coords in roi_coordinates.items():
            print(f" {roi_name}: {len(coords)} 个顶点")

        # Extract reference spectra from the image.
        print(f"\n图像数据形状: {data.shape}")
        print("提取参考光谱...")
        reference_spectra = processor.extract_reference_spectra(data, roi_coordinates)

        if not reference_spectra:
            raise ValueError("无法从ROI区域提取有效的参考光谱")

        print(f"成功提取 {len(reference_spectra)} 个参考光谱:")
        for roi_name, spectrum in reference_spectra.items():
            print(f" {roi_name}: 形状 {spectrum.shape}, 数据类型 {spectrum.dtype}")

        # Initialize the classifier.
        classifier = HyperspectralClassification(distance_params=distance_params)

        # Classify with the requested method(s).
        print(f"\n开始分类分析 (类别数: {len(reference_spectra)}, 方法: {method})...")
        if method == 'all':
            results = classifier.fit_predict_all_methods(data, reference_spectra)
        else:
            result = classifier.fit_predict_with_references(data, reference_spectra, method)
            results = {method: result}

        # Save the results. The loop variable is named method_name so it no
        # longer shadows the `method` parameter.
        print("\n保存分类结果...")
        saved_files = []
        for method_name, result in results.items():
            if result is not None:
                output_file = os.path.join(output_dir, f'classification_{method_name}.dat')
                processor.save_single_band_image(result, output_file, method_name, header)
                saved_files.append((method_name, output_file))

        # Print classification statistics.
        print("\n=== 分类统计信息 ===")
        for method_name, result in results.items():
            if result is not None:
                unique_labels = np.unique(result[result >= 0])
                n_classes_found = len(unique_labels)
                print(f"\n{method_name}: 识别 {n_classes_found} 个类别")

                # Pixel count per class.
                for class_idx, class_name in enumerate(classifier.class_names_):
                    pixel_count = np.sum(result == class_idx)
                    print(f" {class_name}: {pixel_count} 个像素")

        # Bug fix: the `visualize` flag was previously accepted but ignored;
        # now it produces and saves the comparison figure.
        if visualize:
            vis_path = os.path.join(output_dir, 'classification_visualization.png')
            processor.visualize_clusters(results, vis_path)

        print("\n✓ 分类分析完成!")
        print(f"输出目录: {output_dir}")
        print(f"生成文件: {len(saved_files)} 个")
        return 0

    except Exception as e:
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
||
|
||
if __name__ == '__main__':
    # Use sys.exit rather than the builtin exit(): exit() is injected by the
    # site module for interactive use and is not guaranteed to exist when
    # scripts run with -S or under some frozen/embedded interpreters.
    import sys
    sys.exit(main())