增加模块;增加主调用命令

This commit is contained in:
2026-01-07 16:36:47 +08:00
commit 2d4b170a45
109 changed files with 55763 additions and 0 deletions

View File

@ -0,0 +1,815 @@
import numpy as np
from scipy import stats
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import spectral as spy
from spectral import envi
import matplotlib.pyplot as plt
import os
from typing import Tuple, List, Dict, Optional, Union
import warnings
import xml.etree.ElementTree as ET
from shapely.geometry import Polygon, Point
import warnings
warnings.filterwarnings('ignore')
class HyperspectralDistanceMetrics:
"""高光谱数据距离度量类"""
@staticmethod
def cosine_similarity(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
"""余弦相似度"""
dot_product = np.dot(spectrum1, spectrum2)
norm1 = np.linalg.norm(spectrum1)
norm2 = np.linalg.norm(spectrum2)
return dot_product / (norm1 * norm2) if norm1 != 0 and norm2 != 0 else 0
@staticmethod
def euclidean_distance(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
"""欧氏距离"""
return np.linalg.norm(spectrum1 - spectrum2)
@staticmethod
def information_divergence(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
"""信息散度 (Kullback-Leibler divergence)"""
# 添加小常数避免除零
eps = 1e-10
spectrum1 = spectrum1 + eps
spectrum2 = spectrum2 + eps
# 归一化为概率分布
p = spectrum1 / np.sum(spectrum1)
q = spectrum2 / np.sum(spectrum2)
# 计算KL散度
kl_pq = np.sum(p * np.log(p / q))
kl_qp = np.sum(q * np.log(q / p))
return (kl_pq + kl_qp) / 2
@staticmethod
def correlation_coefficient(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
"""相关系数"""
return np.corrcoef(spectrum1, spectrum2)[0, 1]
@staticmethod
def jeffreys_matusita_distance(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
"""J-M距离 (Jeffreys-Matusita Distance)"""
eps = 1e-10
# 转换为概率分布
p = (spectrum1 + eps) / np.sum(spectrum1 + eps)
q = (spectrum2 + eps) / np.sum(spectrum2 + eps)
# Bhattacharyya系数
bc = np.sum(np.sqrt(p * q))
# J-M距离 = sqrt(2 * (1 - bc))
jm_distance = np.sqrt(2 * (1 - bc))
return jm_distance
@staticmethod
def spectral_angle_mapper(spectrum1: np.ndarray, spectrum2: np.ndarray) -> float:
"""光谱角映射器 (Spectral Angle Mapper)"""
cos_sim = HyperspectralDistanceMetrics.cosine_similarity(spectrum1, spectrum2)
# 转换为角度(弧度)
angle = np.arccos(np.clip(cos_sim, -1, 1))
return angle
@staticmethod
def sid_sa_combined(spectrum1: np.ndarray, spectrum2: np.ndarray,
alpha: float = 0.5) -> float:
"""SID_SA结合法 (Spectral Information Divergence + Spectral Angle Mapper)"""
sid = HyperspectralDistanceMetrics.information_divergence(spectrum1, spectrum2)
sam = HyperspectralDistanceMetrics.spectral_angle_mapper(spectrum1, spectrum2)
# 标准化并结合
sid_norm = sid / (1 + sid) # 归一化到[0,1)
sam_norm = sam / (np.pi / 2) # 归一化到[0,1]
return alpha * sid_norm + (1 - alpha) * sam_norm
# ===== 向量化距离计算函数 =====
@staticmethod
def vectorized_cosine_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
"""向量化余弦距离计算"""
# X: (n_samples, n_features), centers: (n_clusters, n_features)
# 返回: (n_samples, n_clusters)
X_norm = np.linalg.norm(X, axis=1, keepdims=True)
centers_norm = np.linalg.norm(centers, axis=1, keepdims=True)
# 避免除零
X_norm = np.where(X_norm == 0, 1, X_norm)
centers_norm = np.where(centers_norm == 0, 1, centers_norm)
# 计算余弦相似度
similarity = np.dot(X, centers.T) / (X_norm * centers_norm.T)
# 转换为距离 (1 - 相似度)
return 1 - np.clip(similarity, -1, 1)
@staticmethod
def vectorized_euclidean_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
"""向量化欧氏距离计算"""
# 使用广播计算欧氏距离
# X: (n_samples, n_features), centers: (n_clusters, n_features)
# 返回: (n_samples, n_clusters)
diff = X[:, np.newaxis, :] - centers[np.newaxis, :, :]
return np.linalg.norm(diff, axis=2)
@staticmethod
def vectorized_correlation_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
"""向量化相关系数距离计算"""
# 计算相关系数矩阵
corr_matrix = np.corrcoef(X.T, centers.T)[:len(X), len(X):]
# 转换为距离 (1 - |相关系数|)
return 1 - np.abs(corr_matrix)
@staticmethod
def vectorized_information_divergence(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
"""向量化信息散度计算"""
eps = 1e-10
X_norm = X + eps
centers_norm = centers + eps
# 归一化为概率分布
X_prob = X_norm / np.sum(X_norm, axis=1, keepdims=True)
centers_prob = centers_norm / np.sum(centers_norm, axis=1, keepdims=True)
# 计算KL散度
# 使用广播: X_prob[:, np.newaxis, :] - centers_prob[np.newaxis, :, :]
kl_pq = np.sum(X_prob[:, np.newaxis, :] * np.log(X_prob[:, np.newaxis, :] / centers_prob[np.newaxis, :, :]), axis=2)
kl_qp = np.sum(centers_prob[np.newaxis, :, :] * np.log(centers_prob[np.newaxis, :, :] / X_prob[:, np.newaxis, :]), axis=2)
return (kl_pq + kl_qp) / 2
@staticmethod
def vectorized_jeffreys_matusita_distance(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
"""向量化J-M距离计算"""
eps = 1e-10
# 转换为概率分布
X_prob = (X + eps) / np.sum(X + eps, axis=1, keepdims=True)
centers_prob = (centers + eps) / np.sum(centers + eps, axis=1, keepdims=True)
# 计算Bhattacharyya系数
bc = np.sum(np.sqrt(X_prob[:, np.newaxis, :] * centers_prob[np.newaxis, :, :]), axis=2)
# J-M距离
return np.sqrt(2 * (1 - bc))
@staticmethod
def vectorized_sid_sa_combined(X: np.ndarray, centers: np.ndarray, alpha: float = 0.5) -> np.ndarray:
"""向量化SID_SA结合距离计算"""
# 计算SID距离
sid = HyperspectralDistanceMetrics.vectorized_information_divergence(X, centers)
# 计算SAM距离
sam = HyperspectralDistanceMetrics.vectorized_spectral_angle(X, centers)
# 归一化
sid_norm = sid / (1 + sid) # 归一化到[0,1)
sam_norm = sam / (np.pi / 2) # 归一化到[0,1]
# 结合
return alpha * sid_norm + (1 - alpha) * sam_norm
@staticmethod
def vectorized_spectral_angle(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
"""向量化光谱角距离计算"""
# 计算余弦相似度
X_norm = np.linalg.norm(X, axis=1, keepdims=True)
centers_norm = np.linalg.norm(centers, axis=1, keepdims=True)
X_norm = np.where(X_norm == 0, 1, X_norm)
centers_norm = np.where(centers_norm == 0, 1, centers_norm)
similarity = np.dot(X, centers.T) / (X_norm * centers_norm.T)
similarity = np.clip(similarity, -1, 1)
# 转换为角度
return np.arccos(similarity)
class HyperspectralClassification:
"""高光谱图像监督分类器"""
def __init__(self, random_state: int = 42, distance_params: Dict[str, Dict] = None):
self.random_state = random_state
self.distance_metrics = HyperspectralDistanceMetrics()
self.reference_spectra_ = None
self.labels_ = None
self.class_names_ = None
# 默认距离度量参数
self.default_distance_params = {
'cosine': {},
'euclidean': {},
'correlation': {},
'information_divergence': {},
'sam': {},
'jm_distance': {'alpha': 0.5}, # Bhattacharyya系数权重
'sid_sa': {'alpha': 0.5} # SID和SAM的权重
}
# 更新用户提供的参数
self.distance_params = self.default_distance_params.copy()
if distance_params:
self.distance_params.update(distance_params)
# 更新向量化距离函数,添加新的方法
self.vectorized_distance_functions = {
'cosine': self.distance_metrics.vectorized_cosine_distance,
'euclidean': self.distance_metrics.vectorized_euclidean_distance,
'correlation': self.distance_metrics.vectorized_correlation_distance,
'information_divergence': self.distance_metrics.vectorized_information_divergence,
'sam': self.distance_metrics.vectorized_spectral_angle,
'jm_distance': self.distance_metrics.vectorized_jeffreys_matusita_distance, # 新增
'sid_sa': self.distance_metrics.vectorized_sid_sa_combined # 新增
}
def fit_predict_with_references(self, X: np.ndarray,
reference_spectra: Dict[str, np.ndarray],
method: str = 'euclidean') -> np.ndarray:
"""
使用参考光谱进行监督分类
Parameters:
X: 高光谱数据,形状为 (n_samples, n_bands) 或 (rows, cols, n_bands)
reference_spectra: 参考光谱字典,键为类别名,值为光谱 (bands,)
method: 距离度量方法
Returns:
分类结果形状与X相同
"""
if method not in self.vectorized_distance_functions:
raise ValueError(f"不支持的距离度量方法: {method}")
if not reference_spectra:
raise ValueError("必须提供参考光谱")
print(f"DEBUG fit_predict_with_references: X.shape = {X.shape}")
print(f"DEBUG fit_predict_with_references: reference_spectra keys = {list(reference_spectra.keys())}")
for name, spectrum in reference_spectra.items():
print(f"DEBUG fit_predict_with_references: {name}.shape = {spectrum.shape}")
# 处理3D输入 (图像数据)
if X.ndim == 3:
rows, cols, bands = X.shape
X_reshaped = X.reshape(-1, bands)
# 移除NaN和无穷大值
valid_mask = np.isfinite(X_reshaped).all(axis=1)
X_valid = X_reshaped[valid_mask]
else:
X_reshaped = X
X_valid = X_reshaped
rows, cols = None, None
if len(X_valid) == 0:
raise ValueError("没有有效的光谱数据")
# 转换为参考光谱数组 - 确保是二维数组 (n_classes, n_bands)
class_names = list(reference_spectra.keys())
# 首先检查并确保每个参考光谱都是一维的
cleaned_reference_spectra = []
for name in class_names:
spectrum = reference_spectra[name]
if spectrum.ndim > 1:
spectrum = spectrum.flatten()
cleaned_reference_spectra.append(spectrum)
reference_array = np.array(cleaned_reference_spectra) # 现在应该是 (n_classes, n_bands)
print(f"DEBUG fit_predict_with_references: reference_array.shape = {reference_array.shape}")
# 进行分类
labels = self._classify_pixels(X_valid, reference_array, method)
# 如果是3D输入需要重塑回图像形状
if X.ndim == 3:
result = np.full((rows * cols,), -1, dtype=int)
result[valid_mask] = labels
result = result.reshape(rows, cols)
else:
result = labels
self.reference_spectra_ = reference_spectra
self.class_names_ = class_names
self.labels_ = result
return result
def _classify_pixels(self, X: np.ndarray, reference_spectra: np.ndarray, method: str) -> np.ndarray:
"""
使用参考光谱对像素进行分类
Parameters:
X: 像素光谱数据 (n_samples, n_bands)
reference_spectra: 参考光谱 (n_classes, n_bands)
method: 距离度量方法
Returns:
分类标签 (n_samples,)
"""
print(f"DEBUG: X.shape = {X.shape}")
print(f"DEBUG: reference_spectra.shape = {reference_spectra.shape}")
print(f"DEBUG: method = {method}")
# 检查是否有向量化距离函数
if method not in self.vectorized_distance_functions or self.vectorized_distance_functions[method] is None:
raise ValueError(f"不支持向量化计算的距离度量方法: {method}")
# 使用向量化距离计算
distances = self.vectorized_distance_functions[method](X, reference_spectra)
# 为每个像素分配最近的参考光谱类别
labels = np.argmin(distances, axis=1)
return labels
def _classify_with_complex_distance(self, X: np.ndarray, reference_spectra: np.ndarray, method: str) -> np.ndarray:
"""
使用复杂距离度量进行分类(非向量化)
"""
print(f"使用 {method} 距离度量进行分类...")
n_samples = X.shape[0]
n_classes = reference_spectra.shape[0]
labels = np.zeros(n_samples, dtype=int)
# 为每个像素计算与所有参考光谱的距离
for i in range(n_samples):
pixel = X[i]
min_distance = float('inf')
best_class = 0
for j in range(n_classes):
reference = reference_spectra[j]
if method == 'jm_distance':
distance = self.distance_metrics.jeffreys_matusita_distance(pixel, reference)
elif method == 'sid_sa':
alpha = self.distance_params[method].get('alpha', 0.5)
distance = self.distance_metrics.sid_sa_combined(pixel, reference, alpha)
else:
# 默认使用欧氏距离
distance = self.distance_metrics.euclidean_distance(pixel, reference)
if distance < min_distance:
min_distance = distance
best_class = j
labels[i] = best_class
return labels
def fit_predict_all_methods(self, X: np.ndarray, reference_spectra: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
"""使用所有距离度量方法进行分类"""
results = {}
methods = list(self.vectorized_distance_functions.keys())
print("开始使用不同距离度量进行分类...")
for method in methods:
print(f"正在使用 {method} 距离度量...")
try:
results[method] = self.fit_predict_with_references(X, reference_spectra, method)
print(f"{method} 分类完成")
except Exception as e:
print(f"{method} 分类失败: {e}")
results[method] = None
return results
class HyperspectralImageProcessor:
"""高光谱图像处理类"""
def __init__(self):
self.data = None
self.header = None
self.wavelengths = None
def load_image(self, hdr_path: str) -> Tuple[np.ndarray, Dict]:
try:
# 读取ENVI文件
img = envi.open(hdr_path)
data = img.load()
header = dict(img.metadata)
# 提取波长信息
if 'wavelength' in header:
wavelengths = np.array([float(w) for w in header['wavelength']])
else:
wavelengths = None
self.data = data
self.header = header
self.wavelengths = wavelengths
print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}")
if wavelengths is not None:
print(f"波长范围: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")
return data, header
except Exception as e:
raise IOError(f"加载图像失败: {e}")
def parse_roi_xml(self, xml_path: str) -> Dict[str, List[Tuple[float, float]]]:
"""
解析ENVI ROI XML文件提取每个区域的坐标
Parameters:
xml_path: XML文件路径
Returns:
字典键为ROI名称值为坐标列表 [(x1,y1), (x2,y2), ...]
"""
try:
tree = ET.parse(xml_path)
root = tree.getroot()
roi_coordinates = {}
for region in root.findall('Region'):
name = region.get('name')
coordinates_text = None
# 查找Coordinates元素
coord_elem = region.find('.//Coordinates')
if coord_elem is not None and coord_elem.text:
coordinates_text = coord_elem.text.strip()
if coordinates_text:
# 解析坐标字符串
coords = []
parts = coordinates_text.split()
# 坐标是成对的 (x y x y ...)
for i in range(0, len(parts), 2):
if i + 1 < len(parts):
x = float(parts[i])
y = float(parts[i + 1])
coords.append((x, y))
roi_coordinates[name] = coords
return roi_coordinates
except Exception as e:
raise IOError(f"解析XML文件失败: {e}")
def extract_reference_spectra(self, data: np.ndarray,
roi_coordinates: Dict[str, List[Tuple[float, float]]]) -> Dict[str, np.ndarray]:
"""
从图像中提取参考光谱
Parameters:
data: 高光谱图像数据 (rows, cols, bands)
roi_coordinates: ROI坐标字典
Returns:
字典键为ROI名称值为平均光谱 (bands,)
"""
reference_spectra = {}
rows, cols, bands = data.shape
for roi_name, coords in roi_coordinates.items():
# 创建多边形
polygon = Polygon(coords)
# 找到多边形内的所有像素
pixels_in_roi = []
for i in range(rows):
for j in range(cols):
point = Point(j, i) # 注意ENVI坐标系中x是列y是行
if polygon.contains(point):
pixels_in_roi.append((i, j))
if len(pixels_in_roi) == 0:
print(f"警告: ROI '{roi_name}' 中没有找到像素")
continue
# 提取像素光谱
spectra = []
for i, j in pixels_in_roi:
# 修正:确保提取的是一维光谱
spectrum = data[i, j, :]
# 确保是一维数组
if spectrum.ndim > 1:
spectrum = spectrum.flatten()
if np.isfinite(spectrum).all() and spectrum.size == bands: # 只使用有效光谱
spectra.append(spectrum)
if len(spectra) == 0:
print(f"警告: ROI '{roi_name}' 中没有有效的光谱数据")
continue
# 计算平均光谱 - 确保结果是一维
avg_spectrum = np.mean(spectra, axis=0)
if avg_spectrum.ndim > 1:
avg_spectrum = avg_spectrum.flatten()
reference_spectra[roi_name] = avg_spectrum
print(f"ROI '{roi_name}': 找到 {len(spectra)} 个有效像素, 光谱维度 {avg_spectrum.shape}")
return reference_spectra
def load_image(self, hdr_path: str) -> Tuple[np.ndarray, Dict]:
"""加载ENVI格式高光谱图像"""
try:
# 读取ENVI文件
img = envi.open(hdr_path)
data = img.load()
header = dict(img.metadata)
# 提取波长信息
if 'wavelength' in header:
wavelengths = np.array([float(w) for w in header['wavelength']])
else:
wavelengths = None
self.data = data
self.header = header
self.wavelengths = wavelengths
print(f"成功加载图像: 形状={data.shape}, 数据类型={data.dtype}")
if wavelengths is not None:
print(f"波长范围: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")
return data, header
except Exception as e:
raise IOError(f"加载图像失败: {e}")
def save_single_band_image(self, data: np.ndarray, output_path: str,
method_name: str = "", original_header: Dict = None) -> None:
"""保存单波段聚类结果图像为dat和hdr文件"""
try:
# 确保输出目录存在
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# 确保输出文件扩展名为.dat
if not output_path.lower().endswith('.dat'):
output_path = output_path.rsplit('.', 1)[0] + '.dat'
# 将数据转换为合适的格式聚类标签通常使用int32
if data.dtype != np.int32:
data_to_save = data.astype(np.int32)
else:
data_to_save = data
# 保存为二进制dat文件
data_to_save.tofile(output_path)
# 创建对应的hdr文件
hdr_path = output_path.rsplit('.', 1)[0] + '.hdr'
self._create_cluster_hdr_file(hdr_path, data_to_save.shape, method_name, original_header)
print(f"聚类结果已保存到:")
print(f" 数据文件: {output_path}")
print(f" 头文件: {hdr_path}")
print(f"数据类型: {data_to_save.dtype}, 形状: {data_to_save.shape}")
except Exception as e:
raise IOError(f"保存图像失败: {e}")
def _create_cluster_hdr_file(self, hdr_path: str, data_shape: Tuple,
method_name: str, original_header: Dict = None) -> None:
"""创建聚类结果的ENVI头文件"""
try:
# 从原始头文件获取基本信息
if original_header is not None:
# 复制原始头文件的关键信息
samples = original_header.get('samples', data_shape[1] if len(data_shape) > 1 else data_shape[0])
lines = original_header.get('lines', data_shape[0] if len(data_shape) > 1 else 1)
bands = 1 # 聚类结果是单波段
interleave = 'bip'
data_type = 3 # ENVI数据类型: 3=int32, 4=float32
byte_order = original_header.get('byte order', 0)
wavelength_units = original_header.get('wavelength units', 'nm')
else:
# 默认值适用于2D聚类结果
if len(data_shape) == 2:
lines, samples = data_shape
else:
lines = data_shape[0]
samples = data_shape[1] if len(data_shape) > 1 else data_shape[0]
bands = 1
interleave = 'bsq'
data_type = 3 # int32
byte_order = 0
wavelength_units = 'Unknown'
# 写入hdr文件
with open(hdr_path, 'w') as f:
f.write("ENVI\n")
f.write("description = {\n")
f.write(f" 聚类结果 - {method_name}\n")
f.write(" 单波段聚类标签图像\n")
f.write("}\n")
f.write(f"samples = {samples}\n")
f.write(f"lines = {lines}\n")
f.write(f"bands = {bands}\n")
f.write(f"header offset = 0\n")
f.write(f"file type = ENVI Standard\n")
f.write(f"data type = {data_type}\n")
f.write(f"interleave = {interleave}\n")
f.write(f"byte order = {byte_order}\n")
# 如果有波长信息,添加虚拟波长
if original_header and 'wavelength' in original_header:
f.write("wavelength = {\n")
f.write(" 类别标签\n")
f.write("}\n")
f.write(f"wavelength units = {wavelength_units}\n")
# 添加类别信息注释
f.write("band names = {\n")
f.write(f" 聚类结果_{method_name}\n")
f.write("}\n")
print(f"成功创建头文件: {hdr_path}")
except Exception as e:
print(f"创建头文件失败: {e}")
def visualize_clusters(self, cluster_results: Dict[str, np.ndarray],
save_path: Optional[str] = None) -> None:
"""可视化聚类结果"""
n_methods = len(cluster_results)
fig, axes = plt.subplots(2, (n_methods + 1) // 2, figsize=(15, 10))
axes = axes.flatten()
for i, (method, result) in enumerate(cluster_results.items()):
if result is not None and i < len(axes):
im = axes[i].imshow(result, cmap='tab10')
axes[i].set_title(f'{method.replace("_", " ").title()}')
axes[i].axis('off')
plt.colorbar(im, ax=axes[i], shrink=0.8)
# 隐藏多余的子图
for j in range(i + 1, len(axes)):
axes[j].axis('off')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"可视化结果已保存到: {save_path}")
plt.show()
def main():
"""主函数:处理高光谱图像监督分类"""
import argparse
parser = argparse.ArgumentParser(description='高光谱图像监督分类分析')
parser.add_argument('input_file', help='输入ENVI格式高光谱图像文件 (.hdr)')
parser.add_argument('xml_file', help='ENVI ROI XML文件路径')
parser.add_argument('--output_dir', '-o', default='output',
help='输出目录 (默认: output)')
parser.add_argument('--method', '-m', default='all',
choices=['all', 'euclidean', 'cosine', 'correlation',
'information_divergence', 'jm_distance', 'sid_sa'],
help='分类距离度量方法 (默认: all使用所有方法)')
parser.add_argument('--distance_params', '-p', type=str, default=None,
help='距离度量方法的超参数JSON格式字符串例如: {"jm_distance": {"alpha": 0.7}}')
parser.add_argument('--visualize', '-v', action='store_true',
help='是否生成可视化结果')
args = parser.parse_args()
# 解析距离参数
distance_params = None
if args.distance_params:
import json
try:
distance_params = json.loads(args.distance_params)
except json.JSONDecodeError:
print(f"警告: 无法解析距离参数 '{args.distance_params}',使用默认参数")
distance_params = None
try:
# 调用分类函数
return run_hsi_classification(
input_file=args.input_file,
xml_file=args.xml_file,
output_dir=args.output_dir,
method=args.method,
distance_params=distance_params,
visualize=args.visualize
)
except Exception as e:
print(f"✗ 处理失败: {e}")
import traceback
traceback.print_exc()
return 1
def run_hsi_classification(input_file, xml_file, output_dir='output', method='all',
distance_params=None, visualize=False):
"""
执行高光谱图像监督分类分析
参数:
input_file: 输入ENVI格式高光谱图像文件 (.hdr)
xml_file: ENVI ROI XML文件路径
output_dir: 输出目录 (默认: output)
method: 分类距离度量方法 ('all' 或具体方法名,默认: 'all')
distance_params: 距离度量方法的超参数字典 (默认: None)
visualize: 是否生成可视化结果 (默认: False)
返回:
成功返回0失败返回1
"""
try:
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
# 初始化处理器
processor = HyperspectralImageProcessor()
# 加载图像
print(f"加载高光谱图像: {input_file}")
data, header = processor.load_image(input_file)
# 解析XML文件并提取参考光谱
print(f"\n解析ROI XML文件: {xml_file}")
roi_coordinates = processor.parse_roi_xml(xml_file)
if not roi_coordinates:
raise ValueError("XML文件中没有找到有效的ROI区域")
print(f"找到 {len(roi_coordinates)} 个ROI区域")
for roi_name, coords in roi_coordinates.items():
print(f" {roi_name}: {len(coords)} 个顶点")
# 从图像中提取参考光谱
print(f"\n图像数据形状: {data.shape}")
print("提取参考光谱...")
reference_spectra = processor.extract_reference_spectra(data, roi_coordinates)
if not reference_spectra:
raise ValueError("无法从ROI区域提取有效的参考光谱")
print(f"成功提取 {len(reference_spectra)} 个参考光谱:")
for roi_name, spectrum in reference_spectra.items():
print(f" {roi_name}: 形状 {spectrum.shape}, 数据类型 {spectrum.dtype}")
# 初始化分类器
classifier = HyperspectralClassification(distance_params=distance_params)
# 根据指定方法进行分类
print(f"\n开始分类分析 (类别数: {len(reference_spectra)}, 方法: {method})...")
if method == 'all':
results = classifier.fit_predict_all_methods(data, reference_spectra)
else:
# 使用指定方法
result = classifier.fit_predict_with_references(data, reference_spectra, method)
results = {method: result}
# 保存结果
print("\n保存分类结果...")
saved_files = []
for method, result in results.items():
if result is not None:
output_file = os.path.join(output_dir, f'classification_{method}.dat')
processor.save_single_band_image(result, output_file, method, header)
saved_files.append((method, output_file))
# 输出统计信息
print("\n=== 分类统计信息 ===")
for method, result in results.items():
if result is not None:
unique_labels = np.unique(result[result >= 0])
n_classes_found = len(unique_labels)
print(f"\n{method}: 识别 {n_classes_found} 个类别")
# 显示每个类别的像素数量
for class_idx, class_name in enumerate(classifier.class_names_):
pixel_count = np.sum(result == class_idx)
print(f" {class_name}: {pixel_count} 个像素")
print("\n✓ 分类分析完成!")
print(f"输出目录: {output_dir}")
print(f"生成文件: {len(saved_files)}")
return 0
except Exception as e:
print(f"✗ 处理失败: {e}")
import traceback
traceback.print_exc()
return 1
if __name__ == '__main__':
exit(main())