增加模块;增加主调用命令
This commit is contained in:
815
supervize_cluster_method/supervize_cluster.py
Normal file
815
supervize_cluster_method/supervize_cluster.py
Normal file
@ -0,0 +1,815 @@
|
||||
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
from scipy.spatial.distance import pdist, squareform
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import spectral as spy
|
||||
from spectral import envi
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
from typing import Tuple, List, Dict, Optional, Union
|
||||
import warnings
|
||||
import xml.etree.ElementTree as ET
|
||||
from shapely.geometry import Polygon, Point
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
class HyperspectralDistanceMetrics:
    """Distance/similarity metrics for hyperspectral pixel spectra.

    All methods are stateless ``staticmethod``s.  The scalar variants compare
    two 1-D spectra; the ``vectorized_*`` variants compare a batch of spectra
    ``X`` of shape (n_samples, n_bands) against reference ``centers`` of shape
    (n_clusters, n_bands) and return an (n_samples, n_clusters) distance matrix.
    """

    @staticmethod
    def cosine_similarity(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Cosine similarity between two spectra (0 when either is all-zero)."""
        dot_product = np.dot(spectrum1, spectrum2)
        norm1 = np.linalg.norm(spectrum1)
        norm2 = np.linalg.norm(spectrum2)
        return dot_product / (norm1 * norm2) if norm1 != 0 and norm2 != 0 else 0

    @staticmethod
    def euclidean_distance(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Euclidean (L2) distance between two spectra."""
        return np.linalg.norm(spectrum1 - spectrum2)

    @staticmethod
    def information_divergence(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Symmetrised Kullback-Leibler divergence (spectral information divergence)."""
        # Small constant avoids division by zero and log(0).
        eps = 1e-10
        spectrum1 = spectrum1 + eps
        spectrum2 = spectrum2 + eps

        # Normalise the spectra into probability distributions.
        p = spectrum1 / np.sum(spectrum1)
        q = spectrum2 / np.sum(spectrum2)

        # Average of the two one-sided KL divergences.
        kl_pq = np.sum(p * np.log(p / q))
        kl_qp = np.sum(q * np.log(q / p))

        return (kl_pq + kl_qp) / 2

    @staticmethod
    def correlation_coefficient(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Pearson correlation coefficient between two spectra."""
        return np.corrcoef(spectrum1, spectrum2)[0, 1]

    @staticmethod
    def jeffreys_matusita_distance(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Jeffreys-Matusita distance, in [0, sqrt(2)]."""
        eps = 1e-10
        # Convert spectra to probability distributions.
        p = (spectrum1 + eps) / np.sum(spectrum1 + eps)
        q = (spectrum2 + eps) / np.sum(spectrum2 + eps)

        # Bhattacharyya coefficient.
        bc = np.sum(np.sqrt(p * q))

        # J-M distance = sqrt(2 * (1 - BC))
        jm_distance = np.sqrt(2 * (1 - bc))
        return jm_distance

    @staticmethod
    def spectral_angle_mapper(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
        """Spectral angle (radians) between two spectra (Spectral Angle Mapper)."""
        cos_sim = HyperspectralDistanceMetrics.cosine_similarity(spectrum1, spectrum2)
        # Clip protects arccos against tiny floating-point overshoot.
        angle = np.arccos(np.clip(cos_sim, -1, 1))
        return angle

    @staticmethod
    def sid_sa_combined(spectrum1: np.ndarray, spectrum2: np.ndarray,
                        alpha: float = 0.5) -> float:
        """Weighted blend of SID and SAM; ``alpha`` weights the SID term."""
        sid = HyperspectralDistanceMetrics.information_divergence(spectrum1, spectrum2)
        sam = HyperspectralDistanceMetrics.spectral_angle_mapper(spectrum1, spectrum2)

        # Normalise both terms before mixing.
        sid_norm = sid / (1 + sid)      # maps [0, inf) into [0, 1)
        sam_norm = sam / (np.pi / 2)    # maps [0, pi/2] into [0, 1]

        return alpha * sid_norm + (1 - alpha) * sam_norm

    # ===== Vectorised distance computations =====

    @staticmethod
    def vectorized_cosine_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Cosine distance matrix (1 - cosine similarity).

        X: (n_samples, n_features); centers: (n_clusters, n_features).
        Returns (n_samples, n_clusters).
        """
        X_norm = np.linalg.norm(X, axis=1, keepdims=True)
        centers_norm = np.linalg.norm(centers, axis=1, keepdims=True)

        # Avoid division by zero for all-zero rows.
        X_norm = np.where(X_norm == 0, 1, X_norm)
        centers_norm = np.where(centers_norm == 0, 1, centers_norm)

        similarity = np.dot(X, centers.T) / (X_norm * centers_norm.T)
        # Convert similarity to a distance.
        return 1 - np.clip(similarity, -1, 1)

    @staticmethod
    def vectorized_euclidean_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Euclidean distance matrix via broadcasting; returns (n_samples, n_clusters)."""
        diff = X[:, np.newaxis, :] - centers[np.newaxis, :, :]
        return np.linalg.norm(diff, axis=2)

    @staticmethod
    def vectorized_correlation_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Correlation distance matrix (1 - |Pearson r|); returns (n_samples, n_clusters).

        BUGFIX: the previous implementation passed ``X.T``/``centers.T`` to
        ``np.corrcoef``, which treats *bands* as the variables and raises a
        shape error whenever n_samples != n_clusters.  Each sample/center row
        must be a variable, so the un-transposed arrays are used.
        """
        corr_matrix = np.corrcoef(X, centers)[:len(X), len(X):]
        return 1 - np.abs(corr_matrix)

    @staticmethod
    def vectorized_information_divergence(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Symmetrised KL divergence matrix; returns (n_samples, n_clusters)."""
        eps = 1e-10
        X_norm = X + eps
        centers_norm = centers + eps

        # Normalise rows into probability distributions.
        X_prob = X_norm / np.sum(X_norm, axis=1, keepdims=True)
        centers_prob = centers_norm / np.sum(centers_norm, axis=1, keepdims=True)

        # Broadcast to (n_samples, n_clusters, n_bands) and reduce over bands.
        kl_pq = np.sum(X_prob[:, np.newaxis, :] * np.log(X_prob[:, np.newaxis, :] / centers_prob[np.newaxis, :, :]), axis=2)
        kl_qp = np.sum(centers_prob[np.newaxis, :, :] * np.log(centers_prob[np.newaxis, :, :] / X_prob[:, np.newaxis, :]), axis=2)

        return (kl_pq + kl_qp) / 2

    @staticmethod
    def vectorized_jeffreys_matusita_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """J-M distance matrix; returns (n_samples, n_clusters) in [0, sqrt(2)]."""
        eps = 1e-10

        # Convert rows to probability distributions.
        X_prob = (X + eps) / np.sum(X + eps, axis=1, keepdims=True)
        centers_prob = (centers + eps) / np.sum(centers + eps, axis=1, keepdims=True)

        # Bhattacharyya coefficient for every (sample, center) pair.
        bc = np.sum(np.sqrt(X_prob[:, np.newaxis, :] * centers_prob[np.newaxis, :, :]), axis=2)

        # J-M distance.
        return np.sqrt(2 * (1 - bc))

    @staticmethod
    def vectorized_sid_sa_combined(X: np.ndarray, centers: np.ndarray, alpha: float = 0.5) -> np.ndarray:
        """Weighted SID/SAM blend matrix; ``alpha`` weights the SID term."""
        # SID component.
        sid = HyperspectralDistanceMetrics.vectorized_information_divergence(X, centers)

        # SAM component.
        sam = HyperspectralDistanceMetrics.vectorized_spectral_angle(X, centers)

        # Normalise both terms before mixing.
        sid_norm = sid / (1 + sid)      # maps [0, inf) into [0, 1)
        sam_norm = sam / (np.pi / 2)    # maps [0, pi/2] into [0, 1]

        return alpha * sid_norm + (1 - alpha) * sam_norm

    @staticmethod
    def vectorized_spectral_angle(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
        """Spectral-angle matrix (radians); returns (n_samples, n_clusters)."""
        X_norm = np.linalg.norm(X, axis=1, keepdims=True)
        centers_norm = np.linalg.norm(centers, axis=1, keepdims=True)

        # Avoid division by zero for all-zero rows.
        X_norm = np.where(X_norm == 0, 1, X_norm)
        centers_norm = np.where(centers_norm == 0, 1, centers_norm)

        similarity = np.dot(X, centers.T) / (X_norm * centers_norm.T)
        similarity = np.clip(similarity, -1, 1)

        # Convert cosine similarity to an angle.
        return np.arccos(similarity)
|
||||
|
||||
|
||||
class HyperspectralClassification:
    """Supervised classifier for hyperspectral imagery.

    Each pixel is assigned the label of the nearest reference spectrum under
    a chosen distance metric; all metrics are delegated to the vectorised
    helpers on ``HyperspectralDistanceMetrics``.
    """

    def __init__(self, random_state: int = 42, distance_params: Dict[str, Dict] = None):
        """
        Parameters:
            random_state: kept for API compatibility; the reference-spectrum
                classifier itself uses no randomness.
            distance_params: optional per-method hyper-parameter overrides,
                e.g. {'sid_sa': {'alpha': 0.7}}.
        """
        self.random_state = random_state
        self.distance_metrics = HyperspectralDistanceMetrics()
        self.reference_spectra_ = None   # name -> spectrum dict after fitting
        self.labels_ = None              # last classification result
        self.class_names_ = None         # ordered class names after fitting

        # Default hyper-parameters per distance metric.
        self.default_distance_params = {
            'cosine': {},
            'euclidean': {},
            'correlation': {},
            'information_divergence': {},
            'sam': {},
            'jm_distance': {'alpha': 0.5},  # Bhattacharyya coefficient weight
            'sid_sa': {'alpha': 0.5}        # SID vs SAM weight
        }

        # Merge user-supplied overrides on top of the defaults.
        self.distance_params = self.default_distance_params.copy()
        if distance_params:
            self.distance_params.update(distance_params)

        # Vectorised distance functions, keyed by method name.
        self.vectorized_distance_functions = {
            'cosine': self.distance_metrics.vectorized_cosine_distance,
            'euclidean': self.distance_metrics.vectorized_euclidean_distance,
            'correlation': self.distance_metrics.vectorized_correlation_distance,
            'information_divergence': self.distance_metrics.vectorized_information_divergence,
            'sam': self.distance_metrics.vectorized_spectral_angle,
            'jm_distance': self.distance_metrics.vectorized_jeffreys_matusita_distance,
            'sid_sa': self.distance_metrics.vectorized_sid_sa_combined
        }

    def fit_predict_with_references(self, X: np.ndarray,
                                    reference_spectra: Dict[str, np.ndarray],
                                    method: str = 'euclidean') -> np.ndarray:
        """
        Supervised classification against a set of reference spectra.

        Parameters:
            X: hyperspectral data, shape (n_samples, n_bands) or (rows, cols, n_bands)
            reference_spectra: mapping of class name -> reference spectrum (n_bands,)
            method: distance metric key (see ``vectorized_distance_functions``)

        Returns:
            Label array: (rows, cols) for 3-D input with invalid pixels set
            to -1; otherwise (n_samples,).

        Raises:
            ValueError: unknown method, empty reference set, or no valid pixels.
        """
        if method not in self.vectorized_distance_functions:
            raise ValueError(f"不支持的距离度量方法: {method}")

        if not reference_spectra:
            raise ValueError("必须提供参考光谱")

        # Flatten image input and drop pixels containing NaN/Inf.
        if X.ndim == 3:
            rows, cols, bands = X.shape
            X_reshaped = X.reshape(-1, bands)
            valid_mask = np.isfinite(X_reshaped).all(axis=1)
            X_valid = X_reshaped[valid_mask]
        else:
            # NOTE(review): 2-D input is used as-is (no finiteness filtering),
            # mirroring the original behaviour.
            X_reshaped = X
            X_valid = X_reshaped
            rows, cols = None, None

        if len(X_valid) == 0:
            raise ValueError("没有有效的光谱数据")

        # Stack references into a (n_classes, n_bands) array, flattening any
        # accidentally multi-dimensional spectra.
        class_names = list(reference_spectra.keys())
        cleaned_reference_spectra = []
        for name in class_names:
            spectrum = reference_spectra[name]
            if spectrum.ndim > 1:
                spectrum = spectrum.flatten()
            cleaned_reference_spectra.append(spectrum)
        reference_array = np.array(cleaned_reference_spectra)

        # Classify the valid pixels.
        labels = self._classify_pixels(X_valid, reference_array, method)

        # Restore the image shape; -1 marks pixels that were filtered out.
        if X.ndim == 3:
            result = np.full((rows * cols,), -1, dtype=int)
            result[valid_mask] = labels
            result = result.reshape(rows, cols)
        else:
            result = labels

        self.reference_spectra_ = reference_spectra
        self.class_names_ = class_names
        self.labels_ = result
        return result

    def _classify_pixels(self, X: np.ndarray, reference_spectra: np.ndarray, method: str) -> np.ndarray:
        """
        Assign each pixel to the nearest reference spectrum.

        Parameters:
            X: pixel spectra (n_samples, n_bands)
            reference_spectra: reference spectra (n_classes, n_bands)
            method: distance metric key

        Returns:
            Integer labels (n_samples,) indexing into the reference rows.
        """
        import inspect

        # Ensure a vectorised implementation exists for this method.
        if method not in self.vectorized_distance_functions or self.vectorized_distance_functions[method] is None:
            raise ValueError(f"不支持向量化计算的距离度量方法: {method}")

        func = self.vectorized_distance_functions[method]

        # BUGFIX: configured hyper-parameters (e.g. the 'sid_sa' alpha) were
        # accepted in __init__ but never forwarded here.  Forward only the
        # parameters the metric function actually accepts.
        configured = self.distance_params.get(method, {})
        accepted = inspect.signature(func).parameters
        kwargs = {k: v for k, v in configured.items() if k in accepted}

        distances = func(X, reference_spectra, **kwargs)

        # Nearest reference wins.
        labels = np.argmin(distances, axis=1)

        return labels

    def _classify_with_complex_distance(self, X: np.ndarray, reference_spectra: np.ndarray, method: str) -> np.ndarray:
        """
        Non-vectorised fallback: classify pixel-by-pixel with a scalar metric.
        Kept for metrics without a vectorised implementation.
        """
        print(f"使用 {method} 距离度量进行分类...")

        n_samples = X.shape[0]
        n_classes = reference_spectra.shape[0]
        labels = np.zeros(n_samples, dtype=int)

        # Compare every pixel against every reference spectrum.
        for i in range(n_samples):
            pixel = X[i]
            min_distance = float('inf')
            best_class = 0

            for j in range(n_classes):
                reference = reference_spectra[j]

                if method == 'jm_distance':
                    distance = self.distance_metrics.jeffreys_matusita_distance(pixel, reference)
                elif method == 'sid_sa':
                    alpha = self.distance_params[method].get('alpha', 0.5)
                    distance = self.distance_metrics.sid_sa_combined(pixel, reference, alpha)
                else:
                    # Default to Euclidean distance.
                    distance = self.distance_metrics.euclidean_distance(pixel, reference)

                if distance < min_distance:
                    min_distance = distance
                    best_class = j

            labels[i] = best_class

        return labels

    def fit_predict_all_methods(self, X: np.ndarray, reference_spectra: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run the classification with every available distance metric.

        Returns a dict of method name -> label array (None for methods that
        failed; the failure is reported but not raised).
        """
        results = {}
        methods = list(self.vectorized_distance_functions.keys())

        print("开始使用不同距离度量进行分类...")
        for method in methods:
            print(f"正在使用 {method} 距离度量...")
            try:
                results[method] = self.fit_predict_with_references(X, reference_spectra, method)
                print(f"✓ {method} 分类完成")
            except Exception as e:
                # Best-effort: record the failure and keep going.
                print(f"✗ {method} 分类失败: {e}")
                results[method] = None

        return results
|
||||
|
||||
|
||||
class HyperspectralImageProcessor:
    """I/O helper for ENVI hyperspectral images, ROI files and result export.

    NOTE: the original class defined ``load_image`` twice with identical
    bodies; the duplicate has been removed.
    """

    def __init__(self):
        self.data = None          # last loaded image cube (rows, cols, bands)
        self.header = None        # ENVI header metadata dict
        self.wavelengths = None   # per-band wavelengths (nm) when present

    def load_image(self, hdr_path: str) -> Tuple[np.ndarray, Dict]:
        """Load an ENVI-format hyperspectral image.

        Parameters:
            hdr_path: path to the ENVI .hdr file.

        Returns:
            (data, header): image cube and header metadata dict.

        Raises:
            IOError: when the file cannot be read.
        """
        try:
            # Read the ENVI file pair (.hdr + data).
            img = envi.open(hdr_path)
            data = img.load()
            header = dict(img.metadata)

            # Extract per-band wavelengths when the header provides them.
            if 'wavelength' in header:
                wavelengths = np.array([float(w) for w in header['wavelength']])
            else:
                wavelengths = None

            self.data = data
            self.header = header
            self.wavelengths = wavelengths

            print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}")
            if wavelengths is not None:
                print(f"波长范围: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")

            return data, header

        except Exception as e:
            raise IOError(f"加载图像失败: {e}")

    def parse_roi_xml(self, xml_path: str) -> Dict[str, List[Tuple[float, float]]]:
        """
        Parse an ENVI ROI XML file and extract each region's polygon vertices.

        Parameters:
            xml_path: path to the XML file.

        Returns:
            Dict mapping ROI name -> vertex list [(x1,y1), (x2,y2), ...].

        Raises:
            IOError: when the file cannot be parsed.
        """
        try:
            tree = ET.parse(xml_path)
            root = tree.getroot()

            roi_coordinates = {}

            for region in root.findall('Region'):
                name = region.get('name')
                coordinates_text = None

                # Locate the (possibly nested) Coordinates element.
                coord_elem = region.find('.//Coordinates')
                if coord_elem is not None and coord_elem.text:
                    coordinates_text = coord_elem.text.strip()

                if coordinates_text:
                    # Coordinates come as a flat "x y x y ..." token stream.
                    coords = []
                    parts = coordinates_text.split()
                    for i in range(0, len(parts), 2):
                        if i + 1 < len(parts):
                            x = float(parts[i])
                            y = float(parts[i + 1])
                            coords.append((x, y))

                    roi_coordinates[name] = coords

            return roi_coordinates

        except Exception as e:
            raise IOError(f"解析XML文件失败: {e}")

    def extract_reference_spectra(self, data: np.ndarray,
                                  roi_coordinates: Dict[str, List[Tuple[float, float]]]) -> Dict[str, np.ndarray]:
        """
        Extract one mean reference spectrum per ROI from the image.

        Parameters:
            data: hyperspectral cube (rows, cols, bands)
            roi_coordinates: ROI name -> polygon vertex list

        Returns:
            Dict mapping ROI name -> mean spectrum (bands,).  ROIs with no
            valid pixels are skipped with a warning.
        """
        reference_spectra = {}
        rows, cols, bands = data.shape

        for roi_name, coords in roi_coordinates.items():
            polygon = Polygon(coords)

            # PERF: restrict the scan to the polygon's bounding box — points
            # outside it can never be contained, so the result is identical.
            min_x, min_y, max_x, max_y = polygon.bounds
            row_start = max(int(np.floor(min_y)), 0)
            row_stop = min(int(np.ceil(max_y)) + 1, rows)
            col_start = max(int(np.floor(min_x)), 0)
            col_stop = min(int(np.ceil(max_x)) + 1, cols)

            # Collect all pixel indices inside the polygon.
            pixels_in_roi = []
            for i in range(row_start, row_stop):
                for j in range(col_start, col_stop):
                    point = Point(j, i)  # ENVI convention: x is column, y is row
                    if polygon.contains(point):
                        pixels_in_roi.append((i, j))

            if len(pixels_in_roi) == 0:
                print(f"警告: ROI '{roi_name}' 中没有找到像素")
                continue

            # Gather the spectra, keeping only finite, correctly-sized ones.
            spectra = []
            for i, j in pixels_in_roi:
                spectrum = data[i, j, :]
                # Defensive: force a 1-D spectrum.
                if spectrum.ndim > 1:
                    spectrum = spectrum.flatten()

                if np.isfinite(spectrum).all() and spectrum.size == bands:
                    spectra.append(spectrum)

            if len(spectra) == 0:
                print(f"警告: ROI '{roi_name}' 中没有有效的光谱数据")
                continue

            # Mean spectrum over all valid pixels; force 1-D for safety.
            avg_spectrum = np.mean(spectra, axis=0)
            if avg_spectrum.ndim > 1:
                avg_spectrum = avg_spectrum.flatten()

            reference_spectra[roi_name] = avg_spectrum

            print(f"ROI '{roi_name}': 找到 {len(spectra)} 个有效像素, 光谱维度 {avg_spectrum.shape}")

        return reference_spectra

    def save_single_band_image(self, data: np.ndarray, output_path: str,
                               method_name: str = "", original_header: Dict = None) -> None:
        """Save a single-band label image as an ENVI .dat/.hdr pair.

        Parameters:
            data: label array (written as int32)
            output_path: target path; extension is forced to .dat
            method_name: metric name embedded in the header description
            original_header: source image header used to seed the new header

        Raises:
            IOError: when writing fails.
        """
        try:
            # Ensure the output directory exists (guard against a bare
            # filename, where dirname('') would make makedirs fail).
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

            # Force a .dat extension on the data file.
            if not output_path.lower().endswith('.dat'):
                output_path = output_path.rsplit('.', 1)[0] + '.dat'

            # Cluster labels are written as int32.
            if data.dtype != np.int32:
                data_to_save = data.astype(np.int32)
            else:
                data_to_save = data

            # Raw binary dump of the label array.
            data_to_save.tofile(output_path)

            # Companion ENVI header.
            hdr_path = output_path.rsplit('.', 1)[0] + '.hdr'
            self._create_cluster_hdr_file(hdr_path, data_to_save.shape, method_name, original_header)

            print(f"聚类结果已保存到:")
            print(f" 数据文件: {output_path}")
            print(f" 头文件: {hdr_path}")
            print(f"数据类型: {data_to_save.dtype}, 形状: {data_to_save.shape}")

        except Exception as e:
            raise IOError(f"保存图像失败: {e}")

    def _create_cluster_hdr_file(self, hdr_path: str, data_shape: Tuple,
                                 method_name: str, original_header: Dict = None) -> None:
        """Write an ENVI header describing a single-band int32 label image."""
        try:
            if original_header is not None:
                # Seed geometry from the source image header.
                samples = original_header.get('samples', data_shape[1] if len(data_shape) > 1 else data_shape[0])
                lines = original_header.get('lines', data_shape[0] if len(data_shape) > 1 else 1)
                bands = 1  # cluster result is a single band
                interleave = 'bip'
                data_type = 3  # ENVI data type: 3=int32, 4=float32
                byte_order = original_header.get('byte order', 0)
                wavelength_units = original_header.get('wavelength units', 'nm')
            else:
                # Defaults derived from the (typically 2-D) label array shape.
                if len(data_shape) == 2:
                    lines, samples = data_shape
                else:
                    lines = data_shape[0]
                    samples = data_shape[1] if len(data_shape) > 1 else data_shape[0]
                bands = 1
                interleave = 'bsq'
                data_type = 3  # int32
                byte_order = 0
                wavelength_units = 'Unknown'

            # Emit the header file.
            with open(hdr_path, 'w') as f:
                f.write("ENVI\n")
                f.write("description = {\n")
                f.write(f" 聚类结果 - {method_name}\n")
                f.write(" 单波段聚类标签图像\n")
                f.write("}\n")
                f.write(f"samples = {samples}\n")
                f.write(f"lines = {lines}\n")
                f.write(f"bands = {bands}\n")
                f.write(f"header offset = 0\n")
                f.write(f"file type = ENVI Standard\n")
                f.write(f"data type = {data_type}\n")
                f.write(f"interleave = {interleave}\n")
                f.write(f"byte order = {byte_order}\n")

                # Mirror wavelength metadata when the source image had it.
                if original_header and 'wavelength' in original_header:
                    f.write("wavelength = {\n")
                    f.write(" 类别标签\n")
                    f.write("}\n")
                    f.write(f"wavelength units = {wavelength_units}\n")

                # Band-name annotation for display tools.
                f.write("band names = {\n")
                f.write(f" 聚类结果_{method_name}\n")
                f.write("}\n")

            print(f"成功创建头文件: {hdr_path}")

        except Exception as e:
            # Best-effort: a failed header write is reported, not raised.
            print(f"创建头文件失败: {e}")

    def visualize_clusters(self, cluster_results: Dict[str, np.ndarray],
                           save_path: Optional[str] = None) -> None:
        """Plot every non-None classification result in a shared figure.

        Parameters:
            cluster_results: method name -> label image (or None)
            save_path: optional path to save the figure (300 dpi).
        """
        n_methods = len(cluster_results)
        fig, axes = plt.subplots(2, (n_methods + 1) // 2, figsize=(15, 10))
        axes = axes.flatten()

        # BUGFIX: track the last used index explicitly so an empty results
        # dict no longer raises NameError on the cleanup loop below.
        last_index = -1
        for i, (method, result) in enumerate(cluster_results.items()):
            last_index = i
            if result is not None and i < len(axes):
                im = axes[i].imshow(result, cmap='tab10')
                axes[i].set_title(f'{method.replace("_", " ").title()}')
                axes[i].axis('off')
                plt.colorbar(im, ax=axes[i], shrink=0.8)

        # Hide unused subplots.
        for j in range(last_index + 1, len(axes)):
            axes[j].axis('off')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"可视化结果已保存到: {save_path}")

        plt.show()
|
||||
|
||||
|
||||
def main():
    """CLI entry point for the supervised hyperspectral classifier.

    Parses command-line arguments, decodes the optional JSON hyper-parameter
    string, then delegates to ``run_hsi_classification``.  Returns the process
    exit code (0 on success, 1 on failure).
    """
    import argparse
    import json

    parser = argparse.ArgumentParser(description='高光谱图像监督分类分析')
    parser.add_argument('input_file', help='输入ENVI格式高光谱图像文件 (.hdr)')
    parser.add_argument('xml_file', help='ENVI ROI XML文件路径')
    parser.add_argument('--output_dir', '-o', default='output',
                        help='输出目录 (默认: output)')
    parser.add_argument('--method', '-m', default='all',
                        choices=['all', 'euclidean', 'cosine', 'correlation',
                                 'information_divergence', 'jm_distance', 'sid_sa'],
                        help='分类距离度量方法 (默认: all,使用所有方法)')
    parser.add_argument('--distance_params', '-p', type=str, default=None,
                        help='距离度量方法的超参数,JSON格式字符串,例如: {"jm_distance": {"alpha": 0.7}}')
    parser.add_argument('--visualize', '-v', action='store_true',
                        help='是否生成可视化结果')

    args = parser.parse_args()

    # Decode the optional JSON hyper-parameter overrides; fall back to the
    # defaults on malformed input instead of aborting.
    parsed_params = None
    if args.distance_params:
        try:
            parsed_params = json.loads(args.distance_params)
        except json.JSONDecodeError:
            print(f"警告: 无法解析距离参数 '{args.distance_params}',使用默认参数")
            parsed_params = None

    try:
        # Hand off to the library entry point and propagate its exit code.
        return run_hsi_classification(
            input_file=args.input_file,
            xml_file=args.xml_file,
            output_dir=args.output_dir,
            method=args.method,
            distance_params=parsed_params,
            visualize=args.visualize
        )

    except Exception as e:
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
|
||||
|
||||
|
||||
|
||||
|
||||
def run_hsi_classification(input_file, xml_file, output_dir='output', method='all',
                          distance_params=None, visualize=False):
    """
    Run the supervised hyperspectral classification pipeline end-to-end.

    Parameters:
        input_file: ENVI-format hyperspectral image (.hdr path)
        xml_file: ENVI ROI XML file with labelled training regions
        output_dir: directory for .dat/.hdr result pairs (default: 'output')
        method: 'all' to run every distance metric, or a single metric name
        distance_params: optional per-metric hyper-parameter dict
        visualize: flag accepted for API symmetry with the CLI
            NOTE(review): currently unused in this function body — no
            visualization call is made regardless of its value; confirm
            whether visualize_clusters should be wired in.

    Returns:
        0 on success, 1 on any failure (the exception is printed, not raised).
    """
    try:
        # Ensure the output directory exists.
        os.makedirs(output_dir, exist_ok=True)

        # I/O helper for image loading, ROI parsing and result export.
        processor = HyperspectralImageProcessor()

        # Load the hyperspectral cube and its ENVI header.
        print(f"加载高光谱图像: {input_file}")
        data, header = processor.load_image(input_file)

        # Parse the ROI polygons that define the training regions.
        print(f"\n解析ROI XML文件: {xml_file}")
        roi_coordinates = processor.parse_roi_xml(xml_file)

        if not roi_coordinates:
            raise ValueError("XML文件中没有找到有效的ROI区域")

        print(f"找到 {len(roi_coordinates)} 个ROI区域")
        for roi_name, coords in roi_coordinates.items():
            print(f" {roi_name}: {len(coords)} 个顶点")

        # Average the pixels inside each ROI into one reference spectrum.
        print(f"\n图像数据形状: {data.shape}")
        print("提取参考光谱...")
        reference_spectra = processor.extract_reference_spectra(data, roi_coordinates)

        if not reference_spectra:
            raise ValueError("无法从ROI区域提取有效的参考光谱")

        print(f"成功提取 {len(reference_spectra)} 个参考光谱:")
        for roi_name, spectrum in reference_spectra.items():
            print(f" {roi_name}: 形状 {spectrum.shape}, 数据类型 {spectrum.dtype}")

        # Nearest-reference classifier with optional per-metric parameters.
        classifier = HyperspectralClassification(distance_params=distance_params)

        # Classify with one metric or with every available metric.
        print(f"\n开始分类分析 (类别数: {len(reference_spectra)}, 方法: {method})...")
        if method == 'all':
            results = classifier.fit_predict_all_methods(data, reference_spectra)
        else:
            # Single named metric.
            result = classifier.fit_predict_with_references(data, reference_spectra, method)
            results = {method: result}

        # Write each successful result as an ENVI .dat/.hdr pair.
        # NOTE(review): the loop variable shadows the `method` parameter;
        # harmless here because the parameter is not read afterwards.
        print("\n保存分类结果...")
        saved_files = []
        for method, result in results.items():
            if result is not None:
                output_file = os.path.join(output_dir, f'classification_{method}.dat')
                processor.save_single_band_image(result, output_file, method, header)
                saved_files.append((method, output_file))

        # Per-method, per-class pixel statistics (label -1 = invalid pixel,
        # excluded from the unique-label count).
        print("\n=== 分类统计信息 ===")
        for method, result in results.items():
            if result is not None:
                unique_labels = np.unique(result[result >= 0])
                n_classes_found = len(unique_labels)
                print(f"\n{method}: 识别 {n_classes_found} 个类别")

                # Pixel count per reference class, in fit order.
                for class_idx, class_name in enumerate(classifier.class_names_):
                    pixel_count = np.sum(result == class_idx)
                    print(f" {class_name}: {pixel_count} 个像素")

        print("\n✓ 分类分析完成!")
        print(f"输出目录: {output_dir}")
        print(f"生成文件: {len(saved_files)} 个")
        return 0

    except Exception as e:
        # Best-effort CLI behaviour: report the failure and return an exit
        # code instead of propagating the exception.
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
|
||||
|
||||
if __name__ == '__main__':
    # Propagate main()'s return code to the shell.  `raise SystemExit` is
    # always available, unlike the site-module convenience `exit()`, which
    # is absent under `python -S` and in frozen builds.
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user