import os import sys import argparse import json import ast import time import numpy as np from osgeo import gdal import xml.etree.ElementTree as ET from sklearn.model_selection import GridSearchCV, train_test_split from sklearn.preprocessing import StandardScaler, LabelEncoder from sklearn.svm import SVC from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier from sklearn.linear_model import LogisticRegression from sklearn.neighbors import KNeighborsClassifier from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis from sklearn.cross_decomposition import PLSRegression from sklearn.decomposition import PCA from sklearn.metrics import classification_report, confusion_matrix, accuracy_score from sklearn.cluster import KMeans from scipy.spatial.distance import cdist from scipy.optimize import nnls import joblib from tqdm import tqdm import warnings from typing import Optional, Dict, Any, Tuple, List from dataclasses import dataclass, field # 可视化和数据保存导入 import matplotlib.pyplot as plt import seaborn as sns from sklearn.metrics import precision_recall_fscore_support # 导入梯度提升模型 try: from xgboost import XGBClassifier XGBOOST_AVAILABLE = True except ImportError: XGBOOST_AVAILABLE = False print("警告: XGBoost未安装,将跳过XGBoost模型") try: from lightgbm import LGBMClassifier LIGHTGBM_AVAILABLE = True except ImportError: LIGHTGBM_AVAILABLE = False print("警告: LightGBM未安装,将跳过LightGBM模型") try: from catboost import CatBoostClassifier CATBOOST_AVAILABLE = True except ImportError: CATBOOST_AVAILABLE = False print("警告: CatBoost未安装,将跳过CatBoost模型") warnings.filterwarnings('ignore') @dataclass class ClassificationConfig: """高光谱图像分类配置类""" # 输入文件配置 hyperspectral_path: Optional[str] = None roi_path: Optional[str] = None # 输出配置 output_path: Optional[str] = None # 模型配置 model_type: str = 'svm' use_grid_search: bool = True model_params: Dict[str, Any] = field(default_factory=dict) # 预测配置 batch_size: int = 10000 # 数据预处理配置 use_standardization: bool = True # 数据分割配置 test_size: float = 0.2 def __post_init__(self): """参数校验和默认值设置""" # 校验必需的文件路径 if not self.hyperspectral_path: raise ValueError("必须指定高光谱文件路径(hyperspectral_path)") if not self.roi_path: raise ValueError("必须指定ROI文件路径(roi_path)") if not self.output_path: raise ValueError("必须指定输出文件路径(output_path)") # 校验文件存在性 if not os.path.exists(self.hyperspectral_path): raise FileNotFoundError(f"高光谱文件不存在: {self.hyperspectral_path}") if not os.path.exists(self.roi_path): raise FileNotFoundError(f"ROI文件不存在: {self.roi_path}") # 校验模型类型 supported_models = self._get_supported_models() if self.model_type not in supported_models: raise ValueError(f"不支持的模型类型: {self.model_type}。支持的模型: {list(supported_models.keys())}") # 设置模型默认参数 self._set_model_default_params() def _get_supported_models(self) -> Dict[str, str]: """获取支持的模型列表""" models = { 'euclidean': 'Euclidean Distance', 'mahalanobis': 'Mahalanobis Distance', 'linear_discriminant': 'Linear Discriminant Analysis', 'quadratic_discriminant': 'Quadratic Discriminant Analysis', 'plsda': 'PLS Discriminant Analysis', 'spectral_angle_mapper': 'Spectral Angle Mapper', 'spectral_unmix': 'Spectral Unmixing', 'random_forest': 'Random Forest', 'svm': 'Support Vector Machine', 'logistic_regression': 'Logistic Regression', 'knn': 'K-Nearest Neighbors', 'adaboost': 'AdaBoost' } # 条件添加梯度提升模型 if XGBOOST_AVAILABLE: models['xgboost'] = 'XGBoost' if LIGHTGBM_AVAILABLE: models['lightgbm'] = 'LightGBM' if CATBOOST_AVAILABLE: models['catboost'] = 'CatBoost' return models def _set_model_default_params(self): """根据模型类型设置默认参数""" if self.model_type == 'svm': self.model_params.setdefault('C', 1.0) self.model_params.setdefault('kernel', 'rbf') self.model_params.setdefault('gamma', 'scale') elif self.model_type == 'random_forest': self.model_params.setdefault('n_estimators', 100) self.model_params.setdefault('max_depth', None) self.model_params.setdefault('min_samples_split', 2) elif self.model_type == 'logistic_regression': self.model_params.setdefault('C', 1.0) self.model_params.setdefault('max_iter', 1000) elif self.model_type == 'knn': self.model_params.setdefault('n_neighbors', 5) self.model_params.setdefault('weights', 'uniform') elif self.model_type == 'spectral_unmix': self.model_params.setdefault('n_endmembers', 10) elif self.model_type == 'adaboost': self.model_params.setdefault('n_estimators', 50) self.model_params.setdefault('learning_rate', 1.0) # 为梯度提升模型设置默认参数 if self.model_type == 'xgboost' and XGBOOST_AVAILABLE: self.model_params.setdefault('n_estimators', 100) self.model_params.setdefault('max_depth', 6) self.model_params.setdefault('learning_rate', 0.1) elif self.model_type == 'lightgbm' and LIGHTGBM_AVAILABLE: self.model_params.setdefault('n_estimators', 100) self.model_params.setdefault('max_depth', -1) self.model_params.setdefault('learning_rate', 0.1) elif self.model_type == 'catboost' and CATBOOST_AVAILABLE: self.model_params.setdefault('iterations', 100) self.model_params.setdefault('depth', 6) self.model_params.setdefault('learning_rate', 0.1) class ConfigurableHyperspectralClassifier: """可配置的高光谱图像分类器""" def __init__(self, config: ClassificationConfig): self.config = config # 初始化数据存储 self.data = None self.geotransform = None self.projection = None self.roi_data = {} self.scaler = StandardScaler() if config.use_standardization else None self.label_encoder = LabelEncoder() self.model = None self.class_names = [] # 注册所有可用的分类器 self.classifiers = { 'euclidean': EuclideanDistanceClassifier(), 'mahalanobis': MahalanobisDistanceClassifier(), 'linear_discriminant': LinearDiscriminantAnalysis(), 'quadratic_discriminant': QuadraticDiscriminantAnalysis(), 'plsda': PLSDAClassifier(), 'spectral_angle_mapper': SpectralAngleMapper(), 'spectral_unmix': SpectralUnmixClassifier(), 'random_forest': RandomForestClassifier(random_state=42), 'svm': SVC(random_state=42, probability=True), 'logistic_regression': LogisticRegression(random_state=42, max_iter=1000), 'knn': KNeighborsClassifier(), 'adaboost': AdaBoostClassifier(random_state=42) } # 条件添加梯度提升模型 if XGBOOST_AVAILABLE: self.classifiers['xgboost'] = XGBClassifier( random_state=42, use_label_encoder=False, eval_metric='logloss', **config.model_params ) if LIGHTGBM_AVAILABLE: self.classifiers['lightgbm'] = LGBMClassifier(random_state=42, verbose=-1) if CATBOOST_AVAILABLE: self.classifiers['catboost'] = CatBoostClassifier(random_state=42, verbose=False) def read_hyperspectral_data(self) -> None: """ 读取高光谱数据文件 """ print(f"正在读取高光谱文件: {self.config.hyperspectral_path}") # 使用GDAL读取数据 dataset = gdal.Open(self.config.hyperspectral_path) if dataset is None: raise ValueError(f"无法打开文件: {self.config.hyperspectral_path}") # 获取基本信息 cols = dataset.RasterXSize rows = dataset.RasterYSize bands = dataset.RasterCount print(f"图像尺寸: {cols} x {rows} x {bands}") # 读取所有波段数据 self.data = np.zeros((rows, cols, bands), dtype=np.float32) for band_idx in tqdm(range(bands), desc="读取波段"): band = dataset.GetRasterBand(band_idx + 1) band_data = band.ReadAsArray() self.data[:, :, band_idx] = band_data.astype(np.float32) # 获取地理信息 self.geotransform = dataset.GetGeoTransform() self.projection = dataset.GetProjection() dataset = None print("高光谱数据读取完成") def parse_roi_xml(self) -> None: """ 解析XML ROI文件 """ print(f"正在解析ROI文件: {self.config.roi_path}") tree = ET.parse(self.config.roi_path) root = tree.getroot() self.roi_data = {} self.class_names = [] # 支持两种不同的路径查找方式 for region in root.findall('.//Region'): name = region.get('name') color = region.get('color') coords = [] # 尝试第一种格式:单独的Point元素 for point in region.findall('.//Point'): x_attr = point.get('x') y_attr = point.get('y') if x_attr is not None and y_attr is not None: try: x = float(x_attr) y = float(y_attr) coords.append((x, y)) except ValueError: pass # 如果通过Point元素没有找到坐标,尝试第二种格式:Coordinates文本 if not coords: coords_element = region.find('.//Coordinates') if coords_element is not None and coords_element.text: coords_text = coords_element.text.strip() numbers = coords_text.split() if len(numbers) % 2 == 0: try: for i in range(0, len(numbers), 2): x = float(numbers[i]) y = float(numbers[i + 1]) coords.append((x, y)) except ValueError: coords = [] # 解析失败,清空坐标 if coords: self.roi_data[name] = { 'coords': coords, 'color': color } self.class_names.append(name) print(f"成功解析区域 '{name}',包含 {len(coords)} 个坐标点") else: print(f"警告: 区域 '{name}' 没有有效的坐标数据") if self.roi_data: print(f"解析完成,共找到 {len(self.class_names)} 个区域: {self.class_names}") else: print("警告: 没有找到任何有效的ROI区域数据") def extract_training_samples(self) -> Tuple[np.ndarray, np.ndarray]: """ 从ROI区域提取训练样本 Returns: X_train: 训练特征 (n_samples, n_bands) y_train: 训练标签 (n_samples,) """ print("正在提取训练样本...") if self.data is None: raise ValueError("请先读取高光谱数据") if not self.roi_data: raise ValueError("请先解析ROI文件") rows, cols, bands = self.data.shape X_samples = [] y_samples = [] for class_name, region_data in self.roi_data.items(): print(f"处理类别 '{class_name}'...") coords = region_data['coords'] # 将坐标转换为像素位置 pixel_coords = [] for x, y in coords: # 如果有地理变换矩阵,使用它进行坐标转换 if self.geotransform: # 从地理坐标转换为像素坐标 # GeoTransform: [top_left_x, pixel_width, 0, top_left_y, 0, pixel_height] col = int((x - self.geotransform[0]) / self.geotransform[1]) row = int((y - self.geotransform[3]) / self.geotransform[5]) else: # 假设坐标直接对应像素位置(列, 行) col = int(x) row = int(y) # 确保坐标在图像范围内 if 0 <= row < rows and 0 <= col < cols: pixel_coords.append((row, col)) # 如果坐标形成闭合多边形,去除重复的最后一个点 if pixel_coords and pixel_coords[0] == pixel_coords[-1]: pixel_coords = pixel_coords[:-1] if not pixel_coords: print(f" 警告: 类别 '{class_name}' 没有有效的像素坐标") continue # 提取该类别下的所有像素 class_pixels = [] if len(pixel_coords) >= 3: # 对于多边形区域,提取边界框内的所有像素 min_row = min(r for r, c in pixel_coords) max_row = max(r for r, c in pixel_coords) min_col = min(c for r, c in pixel_coords) max_col = max(c for r, c in pixel_coords) # 提取边界框内的所有像素 for row in range(max(0, min_row), min(rows, max_row + 1)): for col in range(max(0, min_col), min(cols, max_col + 1)): spectrum = self.data[row, col, :] class_pixels.append(spectrum) else: # 对于少量点,直接提取这些点的像素 for row, col in pixel_coords: spectrum = self.data[row, col, :] class_pixels.append(spectrum) if class_pixels: X_samples.extend(class_pixels) y_samples.extend([class_name] * len(class_pixels)) print(f" 类别 '{class_name}': {len(class_pixels)} 个样本") else: print(f" 警告: 类别 '{class_name}' 没有有效的样本") if not X_samples: raise ValueError("没有找到有效的训练样本") X_train = np.array(X_samples) y_train = np.array(y_samples) # 编码标签 y_train_encoded = self.label_encoder.fit_transform(y_train) print(f"训练样本提取完成: {X_train.shape[0]} 个样本, {X_train.shape[1]} 个波段, {len(np.unique(y_train_encoded))} 个类别") return X_train, y_train_encoded def train_model(self, X: np.ndarray, y: np.ndarray) -> None: """ 训练分类模型 Args: X: 训练特征 y: 训练标签 """ print(f"正在训练 {self.config.model_type} 模型...") # 数据标准化 if self.config.use_standardization and self.scaler is not None: X_scaled = self.scaler.fit_transform(X) print("应用数据标准化") else: X_scaled = X # 获取模型 if self.config.model_type not in self.classifiers: raise ValueError(f"不支持的模型类型: {self.config.model_type}") model = self.classifiers[self.config.model_type] # 设置模型参数 if hasattr(model, 'set_params'): model.set_params(**self.config.model_params) # 网格搜索调参 if self.config.use_grid_search and hasattr(model, '_get_param_grid'): print("执行网格搜索调参...") param_grid = model._get_param_grid() grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy', n_jobs=-1) grid_search.fit(X_scaled, y) self.model = grid_search.best_estimator_ print(f"最佳参数: {grid_search.best_params_}") print(f"最佳得分: {grid_search.best_score_:.4f}") else: # 直接训练模型 self.model = model self.model.fit(X_scaled, y) print("模型训练完成") # 模型评估和可视化输出 self._evaluate_and_visualize_model(X_scaled, y, self.config.model_type) def predict_image(self) -> None: """ 对整个图像进行预测并保存结果 """ print("正在对整个图像进行预测...") if self.data is None: raise ValueError("请先读取高光谱数据") if self.model is None: raise ValueError("请先训练模型") rows, cols, bands = self.data.shape # 预处理整个图像 print("预处理图像数据...") image_data = self.data.reshape(-1, bands) if self.config.use_standardization and self.scaler is not None: image_data_scaled = self.scaler.transform(image_data) else: image_data_scaled = image_data # 分批预测 predictions = [] n_pixels = image_data_scaled.shape[0] for i in tqdm(range(0, n_pixels, self.config.batch_size), desc="预测进度"): batch_end = min(i + self.config.batch_size, n_pixels) batch_data = image_data_scaled[i:batch_end] batch_predictions = self.model.predict(batch_data) predictions.extend(batch_predictions) predictions = np.array(predictions) prediction_image = predictions.reshape(rows, cols) # 保存结果 self._save_prediction_result(prediction_image) print(f"预测结果已保存到: {self.config.output_path}") def _save_prediction_result(self, prediction_image: np.ndarray) -> None: """ 保存预测结果到文件 Args: prediction_image: 预测结果图像 (rows, cols) """ # 确保输出目录存在 output_dir = os.path.dirname(self.config.output_path) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) print(f"创建输出目录: {output_dir}") driver = gdal.GetDriverByName('ENVI') rows, cols = prediction_image.shape # 创建输出文件 out_dataset = driver.Create( self.config.output_path, cols, rows, 1, # 单波段分类结果 gdal.GDT_Byte ) # 设置地理信息 if self.geotransform: out_dataset.SetGeoTransform(self.geotransform) if self.projection: out_dataset.SetProjection(self.projection) # 写入数据 out_band = out_dataset.GetRasterBand(1) out_band.WriteArray(prediction_image.astype(np.uint8)) # 设置颜色表(可选) self._set_color_table(out_band) out_dataset = None def _set_color_table(self, band) -> None: """ 设置分类结果的颜色表 Args: band: GDAL波段对象 """ # 创建颜色表 color_table = gdal.ColorTable() # 为每个类别设置颜色 for i, class_name in enumerate(self.class_names): # 生成不同颜色(这里使用简单的颜色方案) r = (i * 137) % 256 # 使用质数生成颜色 g = (i * 157) % 256 b = (i * 173) % 256 color_table.SetColorEntry(i + 1, (r, g, b, 255)) band.SetColorTable(color_table) def run_classification(self) -> None: """ 执行完整的分类流程 """ try: # 读取数据 self.read_hyperspectral_data() self.parse_roi_xml() # 提取训练样本 X_train, y_train = self.extract_training_samples() # 训练模型 self.train_model(X_train, y_train) # 预测图像 self.predict_image() print("分类任务完成!") except Exception as e: print(f"分类过程中发生错误: {str(e)}") raise def _evaluate_and_visualize_model(self, X_scaled, y, model_type): """ 评估模型并生成可视化输出 Args: X_scaled: 标准化后的特征 y: 标签 model_type: 模型类型 """ try: # 分割训练和验证集 X_train, X_val, y_train, y_val = train_test_split( X_scaled, y, test_size=self.config.test_size, random_state=42, stratify=y ) # 在验证集上进行预测 y_pred = self.model.predict(X_val) # 生成混淆矩阵图和评价指标CSV self._generate_evaluation_outputs(y_val, y_pred, model_type) except Exception as e: print(f"模型评估时出错: {e}") def _generate_evaluation_outputs(self, y_true, y_pred, model_type): """ 生成混淆矩阵图和评价指标CSV文件 Args: y_true: 真实标签 y_pred: 预测标签 model_type: 模型类型 """ try: # 生成混淆矩阵图 self._plot_confusion_matrix(y_true, y_pred, model_type) # 保存评价指标到CSV self._save_metrics_to_csv(y_true, y_pred, model_type) except Exception as e: print(f"生成评价输出时出错: {e}") def _plot_confusion_matrix(self, y_true, y_pred, model_type): """ 绘制混淆矩阵图 Args: y_true: 真实标签 y_pred: 预测标签 model_type: 模型类型 """ # 计算混淆矩阵 cm = confusion_matrix(y_true, y_pred) # 获取类别标签 unique_labels = sorted(list(set(y_true))) class_names = [f'Class {i}' for i in unique_labels] # 创建图表 plt.figure(figsize=(10, 8)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names) plt.title(f'Confusion Matrix - {model_type.upper()}', fontsize=16, fontweight='bold') plt.ylabel('True Label', fontsize=12) plt.xlabel('Predicted Label', fontsize=12) # 保存图表 filename = f'confusion_matrix_{model_type}.png' plt.savefig(filename, dpi=300, bbox_inches='tight') plt.close() print(f"混淆矩阵图已保存: {filename}") def _save_metrics_to_csv(self, y_true, y_pred, model_type): """ 保存评价指标到CSV文件 Args: y_true: 真实标签 y_pred: 预测标签 model_type: 模型类型 """ # 计算各种评价指标 accuracy = accuracy_score(y_true, y_pred) # 获取类别标签 unique_labels = sorted(list(set(y_true))) class_names = [f'Class {i}' for i in unique_labels] report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True) # 计算每类的精确率、召回率、F1分数 precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred, average=None) # 准备数据 metrics_data = { 'model_type': [model_type] * len(class_names), 'class': class_names, 'precision': precision, 'recall': recall, 'f1_score': f1, 'support': support, 'accuracy': [accuracy] * len(class_names) } # 添加总体指标 metrics_data['model_type'].extend([model_type, model_type, model_type]) metrics_data['class'].extend(['macro_avg', 'weighted_avg', 'overall_accuracy']) metrics_data['precision'].extend([ report['macro avg']['precision'], report['weighted avg']['precision'], accuracy ]) metrics_data['recall'].extend([ report['macro avg']['recall'], report['weighted avg']['recall'], accuracy ]) metrics_data['f1_score'].extend([ report['macro avg']['f1-score'], report['weighted avg']['f1-score'], accuracy ]) metrics_data['support'].extend([ report['macro avg']['support'], report['weighted avg']['support'], len(y_true) ]) metrics_data['accuracy'].extend([accuracy, accuracy, accuracy]) # 创建DataFrame并保存 import pandas as pd df_metrics = pd.DataFrame(metrics_data) filename = f'model_metrics_{model_type}.csv' df_metrics.to_csv(filename, index=False) print(f"模型评价指标已保存: {filename}") class BaseClassifier: """基类分类器""" def __init__(self, name): self.name = name self.model = None def fit(self, X, y): pass def predict(self, X): pass def predict_proba(self, X): pass def get_params(self): """获取参数字典""" return {} def set_params(self, **params): """设置参数""" for param, value in params.items(): if hasattr(self, param): setattr(self, param, value) return self class EuclideanDistanceClassifier(BaseClassifier): """欧氏距离分类器""" def __init__(self): super().__init__('Euclidean Distance') self.class_centroids = {} self.class_labels = [] def fit(self, X, y): """计算每个类别的平均光谱(质心)""" unique_labels = np.unique(y) self.class_labels = unique_labels for label in unique_labels: class_samples = X[y == label] centroid = np.mean(class_samples, axis=0) self.class_centroids[label] = centroid def predict(self, X): """预测:选择最近的质心""" predictions = [] for sample in X: distances = [] for label, centroid in self.class_centroids.items(): # 计算欧氏距离 distance = np.linalg.norm(sample - centroid) distances.append((label, distance)) # 选择距离最小的类别 best_label = min(distances, key=lambda x: x[1])[0] predictions.append(best_label) return np.array(predictions) class MahalanobisDistanceClassifier(BaseClassifier): """马氏距离分类器""" def __init__(self): super().__init__('Mahalanobis Distance') self.class_means = {} self.class_covariances = {} self.class_inv_covariances = {} self.class_labels = [] def fit(self, X, y): """计算每个类别的均值和协方差矩阵""" unique_labels = np.unique(y) self.class_labels = unique_labels for label in unique_labels: class_samples = X[y == label] mean = np.mean(class_samples, axis=0) covariance = np.cov(class_samples.T) # 添加小值以避免奇异性 covariance += np.eye(covariance.shape[0]) * 1e-6 self.class_means[label] = mean self.class_covariances[label] = covariance # 计算协方差矩阵的逆 try: inv_cov = np.linalg.inv(covariance) self.class_inv_covariances[label] = inv_cov except np.linalg.LinAlgError: # 如果矩阵不可逆,使用伪逆 self.class_inv_covariances[label] = np.linalg.pinv(covariance) def mahalanobis_distance(self, x, mean, inv_cov): """计算马氏距离""" diff = x - mean return np.sqrt(np.dot(np.dot(diff.T, inv_cov), diff)) def predict(self, X): """预测:使用马氏距离""" predictions = [] for sample in X: distances = [] for label, mean in self.class_means.items(): inv_cov = self.class_inv_covariances[label] distance = self.mahalanobis_distance(sample, mean, inv_cov) distances.append((label, distance)) # 选择距离最小的类别 best_label = min(distances, key=lambda x: x[1])[0] predictions.append(best_label) return np.array(predictions) class SpectralAngleMapper(BaseClassifier): """光谱角制图分类器""" def __init__(self): super().__init__('Spectral Angle Mapper') self.class_spectra = {} self.class_labels = [] def fit(self, X, y): """计算每个类别的平均光谱""" unique_labels = np.unique(y) self.class_labels = unique_labels for label in unique_labels: class_samples = X[y == label] # 计算平均光谱 mean_spectrum = np.mean(class_samples, axis=0) self.class_spectra[label] = mean_spectrum def spectral_angle(self, spectrum1, spectrum2): """计算光谱角""" dot_product = np.dot(spectrum1, spectrum2) norm1 = np.linalg.norm(spectrum1) norm2 = np.linalg.norm(spectrum2) if norm1 == 0 or norm2 == 0: return np.pi / 2 # 90度 cos_angle = dot_product / (norm1 * norm2) # 确保在有效范围内 cos_angle = np.clip(cos_angle, -1.0, 1.0) return np.arccos(cos_angle) def predict(self, X): """预测:使用最小光谱角""" predictions = [] for sample in X: angles = [] for label, ref_spectrum in self.class_spectra.items(): angle = self.spectral_angle(sample, ref_spectrum) angles.append((label, angle)) # 选择角度最小的类别 best_label = min(angles, key=lambda x: x[1])[0] predictions.append(best_label) return np.array(predictions) class SpectralUnmixClassifier(BaseClassifier): """光谱解混分类器""" def __init__(self, n_endmembers=10): super().__init__('Spectral Unmix') self.n_endmembers = n_endmembers self.endmembers = None self.kmeans = None self.class_endmember_map = {} def get_params(self, deep=True): """获取参数字典""" return {'n_endmembers': self.n_endmembers} def set_params(self, **params): """设置参数""" for param, value in params.items(): if param == 'n_endmembers': self.n_endmembers = value return self def fit(self, X, y): """使用K-means提取端元,并建立端元与类别的映射""" unique_labels = np.unique(y) # 使用K-means提取端元 self.kmeans = KMeans(n_clusters=self.n_endmembers, random_state=42) self.kmeans.fit(X) self.endmembers = self.kmeans.cluster_centers_ # 建立端元与类别的映射 self.class_endmember_map = {} for label in unique_labels: # 找到属于该类别的样本 class_samples = X[y == label] if len(class_samples) > 0: # 预测这些样本的端元标签 endmember_labels = self.kmeans.predict(class_samples) # 找到最常出现的端元 most_common_endmember = np.argmax(np.bincount(endmember_labels)) self.class_endmember_map[label] = most_common_endmember def spectral_unmix(self, spectrum): """使用非负最小二乘法进行光谱解混""" # 求解:spectrum = endmembers * abundances # 约束:abundances >= 0, sum(abundances) = 1 abundances, _ = nnls(self.endmembers.T, spectrum) # 归一化 if np.sum(abundances) > 0: abundances = abundances / np.sum(abundances) return abundances def predict(self, X): """预测:使用最大丰度端元对应的类别""" predictions = [] for sample in X: # 计算丰度 abundances = self.spectral_unmix(sample) # 找到最大丰度的端元 dominant_endmember = np.argmax(abundances) # 找到该端元对应的类别 predicted_label = None max_count = 0 for label, endmember in self.class_endmember_map.items(): if endmember == dominant_endmember: # 简单投票:选择第一个匹配的类别 predicted_label = label break # 如果没有找到匹配的类别,选择丰度最大的类别 if predicted_label is None: # 这里我们可以改进,比如使用训练样本中端元的分布 predicted_label = list(self.class_endmember_map.keys())[0] predictions.append(predicted_label) return np.array(predictions) class PLSDAClassifier(BaseClassifier): """偏最小二乘判别分析分类器""" def __init__(self, n_components=2): super().__init__('PLS-DA') self.n_components = n_components self.pls = None self.class_labels = [] def get_params(self, deep=True): """获取参数字典""" return {'n_components': self.n_components} def set_params(self, **params): """设置参数""" for param, value in params.items(): if param == 'n_components': self.n_components = value return self def fit(self, X, y): """训练PLS-DA模型""" unique_labels = np.unique(y) self.class_labels = unique_labels # 创建one-hot编码的响应矩阵 n_samples = len(y) n_classes = len(unique_labels) Y_onehot = np.zeros((n_samples, n_classes)) for i, label in enumerate(unique_labels): Y_onehot[y == label, i] = 1 # 训练PLS回归模型 self.pls = PLSRegression(n_components=self.n_components) self.pls.fit(X, Y_onehot) def predict(self, X): """预测:选择PLS得分最高的类别""" # 得到PLS预测得分 Y_pred = self.pls.predict(X) # 选择得分最高的类别 predictions = np.argmax(Y_pred, axis=1) # 映射回原始标签 predictions = [self.class_labels[p] for p in predictions] return np.array(predictions) class HyperspectralClassifier: """高光谱图像分类器""" def __init__(self): self.data = None self.geotransform = None self.projection = None self.roi_data = {} self.scaler = StandardScaler() self.label_encoder = LabelEncoder() self.model = None self.class_names = [] # 注册所有可用的分类器 self.classifiers = { 'euclidean': EuclideanDistanceClassifier(), 'mahalanobis': MahalanobisDistanceClassifier(), 'linear_discriminant': LinearDiscriminantAnalysis(), 'quadratic_discriminant': QuadraticDiscriminantAnalysis(), 'plsda': PLSDAClassifier(), 'spectral_angle_mapper': SpectralAngleMapper(), 'spectral_unmix': SpectralUnmixClassifier(), 'random_forest': RandomForestClassifier(random_state=42), 'svm': SVC(random_state=42, probability=True), 'logistic_regression': LogisticRegression(random_state=42, max_iter=1000), 'knn': KNeighborsClassifier(), 'adaboost': AdaBoostClassifier(random_state=42) } # 条件添加梯度提升模型 if XGBOOST_AVAILABLE: self.classifiers['xgboost'] = XGBClassifier(random_state=42, use_label_encoder=False, eval_metric='logloss') if LIGHTGBM_AVAILABLE: self.classifiers['lightgbm'] = LGBMClassifier(random_state=42, verbose=-1) if CATBOOST_AVAILABLE: self.classifiers['catboost'] = CatBoostClassifier(random_state=42, verbose=False) def read_hyperspectral_data(self, file_path): """ 读取高光谱数据文件 Args: file_path (str): 高光谱文件路径 """ print(f"正在读取高光谱文件: {file_path}") # 使用GDAL读取数据 dataset = gdal.Open(file_path) if dataset is None: raise ValueError(f"无法打开文件: {file_path}") # 获取基本信息 cols = dataset.RasterXSize rows = dataset.RasterYSize bands = dataset.RasterCount print(f"图像尺寸: {cols} x {rows} x {bands}") # 读取所有波段数据 self.data = np.zeros((rows, cols, bands), dtype=np.float32) for band_idx in tqdm(range(bands), desc="读取波段"): band = dataset.GetRasterBand(band_idx + 1) band_data = band.ReadAsArray() self.data[:, :, band_idx] = band_data.astype(np.float32) # 获取地理信息 self.geotransform = dataset.GetGeoTransform() self.projection = dataset.GetProjection() dataset = None print("高光谱数据读取完成") def parse_roi_xml(self, xml_path): """ 解析XML ROI文件 Args: xml_path (str): XML文件路径 """ print(f"正在解析ROI文件: {xml_path}") tree = ET.parse(xml_path) root = tree.getroot() self.roi_data = {} self.class_names = [] for region in root.findall('Region'): name = region.get('name') color = region.get('color') # 解析几何信息 geom_def = region.find('GeometryDef') if geom_def is None: continue coord_sys = geom_def.find('CoordSysStr') polygon = geom_def.find('Polygon') if polygon is None: continue exterior = polygon.find('Exterior') if exterior is None: continue linear_ring = exterior.find('LinearRing') if linear_ring is None: continue coordinates_elem = linear_ring.find('Coordinates') if coordinates_elem is None: continue # 解析坐标 coords_text = coordinates_elem.text.strip() coords = [float(x) for x in coords_text.split()] # 转换为(x, y)坐标对 polygon_coords = [] for i in range(0, len(coords), 2): x, y = coords[i], coords[i+1] polygon_coords.append((x, y)) self.roi_data[name] = { 'color': color, 'coordinates': polygon_coords } self.class_names.append(name) print(f"解析完成,共找到 {len(self.roi_data)} 个ROI区域") print(f"类别名称: {self.class_names}") def _point_in_polygon(self, x, y, polygon): """ 判断点是否在多边形内(射线法) Args: x, y: 点坐标 polygon: 多边形坐标列表 [(x1,y1), (x2,y2), ...] Returns: bool: 是否在多边形内 """ n = len(polygon) inside = False p1x, p1y = polygon[0] for i in range(1, n + 1): p2x, p2y = polygon[i % n] if y > min(p1y, p2y): if y <= max(p1y, p2y): if x <= max(p1x, p2x): if p1y != p2y: xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x if p1x == p2x or x <= xinters: inside = not inside p1x, p1y = p2x, p2y return inside def extract_training_samples(self): """ 从ROI区域提取训练样本 Returns: tuple: (X_train, y_train) 训练数据和标签 """ print("正在提取训练样本...") X_samples = [] y_samples = [] rows, cols, bands = self.data.shape for class_name, roi_info in self.roi_data.items(): print(f"处理类别: {class_name}") coords = roi_info['coordinates'] # 找到ROI的边界框 min_x = min(x for x, y in coords) max_x = max(x for x, y in coords) min_y = min(y for x, y in coords) max_y = max(y for x, y in coords) # 转换为像素坐标(假设坐标是像素坐标) min_col = max(0, int(min_x)) max_col = min(cols, int(max_x) + 1) min_row = max(0, int(min_y)) max_row = min(rows, int(max_y) + 1) class_samples = [] # 在边界框内检查每个像素 for row in range(min_row, max_row): for col in range(min_col, max_col): if self._point_in_polygon(col, row, coords): # 提取该像素的光谱数据 spectrum = self.data[row, col, :] if not np.all(np.isnan(spectrum)): # 跳过NaN值 class_samples.append(spectrum) print(f"类别 {class_name} 提取到 {len(class_samples)} 个样本") X_samples.extend(class_samples) y_samples.extend([class_name] * len(class_samples)) X = np.array(X_samples) y = np.array(y_samples) print(f"总训练样本数: {len(X)}") # 编码标签 y_encoded = self.label_encoder.fit_transform(y) # 标准化特征 X_scaled = self.scaler.fit_transform(X) return X_scaled, y_encoded def train_model(self, X, y, model_type='svm', use_grid_search=True, **kwargs): """ 训练分类模型 Args: X: 训练特征 y: 训练标签 model_type: 模型类型 use_grid_search: 是否使用网格搜索调参 **kwargs: 模型参数 """ print(f"正在训练{model_type.upper()}模型...") # 检查模型类型是否支持 if model_type not in self.classifiers: raise ValueError(f"不支持的模型类型: {model_type}。可用类型: {list(self.classifiers.keys())}") # 分割训练和验证集 X_train, X_val, y_train, y_val = train_test_split( X, y, test_size=self.config.test_size, random_state=42, stratify=y ) # 获取分类器实例 if use_grid_search: # 为每种分类器定义参数网格 param_grids = { 'svm': { 'C': [0.1, 1, 10, 100], 'gamma': ['scale', 'auto', 0.001, 0.01, 0.1, 1], 'kernel': ['rbf', 'linear'] }, 'random_forest': { 'n_estimators': [50, 100, 200], 'max_depth': [None, 10, 20, 30], 'min_samples_split': [2, 5, 10], 'min_samples_leaf': [1, 2, 4] }, 'logistic_regression': { 'C': [0.001, 0.01, 0.1, 1, 10, 100], 'penalty': ['l1', 'l2'], 'solver': ['liblinear', 'saga'] }, 'knn': { 'n_neighbors': [3, 5, 7, 9, 11], 'weights': ['uniform', 'distance'], 'metric': ['euclidean', 'manhattan', 'minkowski'] }, 'linear_discriminant': { 'solver': ['svd', 'lsqr', 'eigen'], 'shrinkage': [None, 'auto', 0.1, 0.5, 0.9] }, 'quadratic_discriminant': { 'reg_param': [0.0, 0.1, 0.5, 1.0], 'store_covariance': [True, False] }, 'plsda': { 'n_components': [2, 3, 5, 10], }, 'spectral_unmix': { 'n_endmembers': [5, 10, 15, 20] }, 'adaboost': { 'n_estimators': [50, 100, 200], 'learning_rate': [0.01, 0.1, 1.0] }, 'xgboost': { 'n_estimators': [100, 200, 300], 'max_depth': [3, 6, 9], 'learning_rate': [0.01, 0.1, 0.2], 'subsample': [0.8, 1.0] }, 'lightgbm': { 'n_estimators': [100, 200, 300], 'max_depth': [3, 6, 9, -1], 'learning_rate': [0.01, 0.1, 0.2], 'num_leaves': [31, 50, 100] }, 'catboost': { 'iterations': [100, 200, 300], 'depth': [4, 6, 8, 10], 'learning_rate': [0.01, 0.1, 0.2], 'l2_leaf_reg': [1, 3, 5, 7] } } param_grid = param_grids.get(model_type, {}) if param_grid: base_model = self.classifiers[model_type] self.model = GridSearchCV( base_model, param_grid, cv=3, scoring='accuracy', verbose=1, n_jobs=-1 ) else: # 对于没有参数网格的分类器,直接使用 self.model = self.classifiers[model_type] else: # 不使用网格搜索,直接使用分类器 self.model = self.classifiers[model_type] # 应用自定义参数 if kwargs: self.model.set_params(**kwargs) # 训练模型 self.model.fit(X_train, y_train) # 输出最佳参数(如果使用网格搜索) if use_grid_search and hasattr(self.model, 'best_params_'): print(f"最佳参数: {self.model.best_params_}") # 在验证集上评估 y_pred = self.model.predict(X_val) print("验证集评估结果:") print(classification_report(y_val, y_pred, target_names=self.label_encoder.classes_)) # 保存模型 model_filename = f'trained_model_{model_type}.pkl' joblib.dump({ 'model': self.model, 'scaler': self.scaler, 'label_encoder': self.label_encoder, 'model_type': model_type }, model_filename) # 生成混淆矩阵图和评价指标CSV self._generate_evaluation_outputs(y_val, y_pred, model_type) print(f"模型训练完成并保存为: {model_filename}") return self.model def _evaluate_and_visualize_model(self, X_scaled, y, model_type): """ 评估模型并生成可视化输出 Args: X_scaled: 标准化后的特征 y: 标签 model_type: 模型类型 """ try: # 分割训练和验证集 X_train, X_val, y_train, y_val = train_test_split( X_scaled, y, test_size=self.config.test_size, random_state=42, stratify=y ) # 在验证集上进行预测 y_pred = self.model.predict(X_val) # 生成混淆矩阵图和评价指标CSV self._generate_evaluation_outputs(y_val, y_pred, model_type) except Exception as e: print(f"模型评估时出错: {e}") def _generate_evaluation_outputs(self, y_true, y_pred, model_type): """ 生成混淆矩阵图和评价指标CSV文件 Args: y_true: 真实标签 y_pred: 预测标签 model_type: 模型类型 """ try: # 生成混淆矩阵图 self._plot_confusion_matrix(y_true, y_pred, model_type) # 保存评价指标到CSV self._save_metrics_to_csv(y_true, y_pred, model_type) except Exception as e: print(f"生成评价输出时出错: {e}") def _plot_confusion_matrix(self, y_true, y_pred, model_type): """ 绘制混淆矩阵图 Args: y_true: 真实标签 y_pred: 预测标签 model_type: 模型类型 """ # 计算混淆矩阵 cm = confusion_matrix(y_true, y_pred) # 创建图表 plt.figure(figsize=(10, 8)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=self.label_encoder.classes_, yticklabels=self.label_encoder.classes_) plt.title(f'Confusion Matrix - {model_type.upper()}', fontsize=16, fontweight='bold') plt.ylabel('True Label', fontsize=12) plt.xlabel('Predicted Label', fontsize=12) # 保存图表 filename = f'confusion_matrix_{model_type}.png' plt.savefig(filename, dpi=300, bbox_inches='tight') plt.close() print(f"混淆矩阵图已保存: {filename}") def _save_metrics_to_csv(self, y_true, y_pred, model_type): """ 保存评价指标到CSV文件 Args: y_true: 真实标签 y_pred: 预测标签 model_type: 模型类型 """ # 计算各种评价指标 accuracy = accuracy_score(y_true, y_pred) report = classification_report(y_true, y_pred, target_names=self.label_encoder.classes_, output_dict=True) # 计算每类的精确率、召回率、F1分数 precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred, average=None) # 准备数据 metrics_data = { 'model_type': [model_type] * len(self.label_encoder.classes_), 'class': list(self.label_encoder.classes_), 'precision': precision, 'recall': recall, 'f1_score': f1, 'support': support, 'accuracy': [accuracy] * len(self.label_encoder.classes_) } # 添加总体指标 metrics_data['model_type'].extend([model_type, model_type, model_type]) metrics_data['class'].extend(['macro_avg', 'weighted_avg', 'overall_accuracy']) metrics_data['precision'].extend([ report['macro avg']['precision'], report['weighted avg']['precision'], accuracy ]) metrics_data['recall'].extend([ report['macro avg']['recall'], report['weighted avg']['recall'], accuracy ]) metrics_data['f1_score'].extend([ report['macro avg']['f1-score'], report['weighted avg']['f1-score'], accuracy ]) metrics_data['support'].extend([ report['macro avg']['support'], report['weighted avg']['support'], len(y_true) ]) metrics_data['accuracy'].extend([accuracy, accuracy, accuracy]) # 创建DataFrame并保存 import pandas as pd df_metrics = pd.DataFrame(metrics_data) filename = f'model_metrics_{model_type}.csv' df_metrics.to_csv(filename, index=False) print(f"模型评价指标已保存: {filename}") def predict_image(self, output_path, batch_size=10000): """ 对整个图像进行预测 Args: output_path (str): 输出文件路径 batch_size (int): 批处理大小,用于内存优化 """ print("正在对整个图像进行预测...") rows, cols, bands = self.data.shape prediction_map = np.zeros((rows, cols), dtype=np.uint8) # 将数据重塑为2D数组进行预测 X_image = self.data.reshape(-1, bands) X_image_scaled = self.scaler.transform(X_image) # 分批预测以节省内存 predictions = [] for i in tqdm(range(0, X_image_scaled.shape[0], batch_size), desc="预测"): batch_X = X_image_scaled[i:i+batch_size] batch_pred = self.model.predict(batch_X) predictions.extend(batch_pred) prediction_map = np.array(predictions).reshape(rows, cols) # 保存为BIL格式 self._save_bil_format(prediction_map, output_path) print(f"预测完成,结果保存至: {output_path}") return prediction_map def _save_bil_format(self, prediction_map, output_path): """ 保存预测结果为单波段BIL格式,每个像素值对应类别标签 Args: prediction_map: 预测结果 output_path: 输出路径 """ # 确保输出目录存在 output_dir = os.path.dirname(output_path) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) print(f"创建输出目录: {output_dir}") rows, cols = prediction_map.shape # 创建单波段BIL格式文件 driver = gdal.GetDriverByName('ENVI') output_file = driver.Create( output_path, cols, rows, 1, # 单波段 gdal.GDT_Byte, options=['INTERLEAVE=BIL'] ) if output_file is None: raise ValueError(f"无法创建输出文件: {output_path}") # 保存预测结果,每个像素的值直接是类别标签 output_file.GetRasterBand(1).WriteArray(prediction_map.astype(np.uint8)) output_file.GetRasterBand(1).SetDescription("Classification Result") # 设置地理信息 if self.geotransform: output_file.SetGeoTransform(self.geotransform) if self.projection: output_file.SetProjection(self.projection) output_file.FlushCache() output_file = None # 创建头文件 self._create_hdr_file(output_path) def _create_hdr_file(self, bil_path): """ 创建ENVI头文件 Args: bil_path: BIL文件路径 """ hdr_path = bil_path + '.hdr' rows, cols = self.data.shape[:2] class_names = self.label_encoder.classes_ with open(hdr_path, 'w') as f: f.write("ENVI\n") f.write("description = {\n") f.write(" 高光谱分类结果 - 单波段格式\n") f.write("}\n") f.write(f"samples = {cols}\n") f.write(f"lines = {rows}\n") f.write("bands = 1\n") # 单波段 f.write("header offset = 0\n") f.write("file type = ENVI Standard\n") f.write("data type = 1\n") f.write("interleave = bil\n") f.write("sensor type = Unknown\n") f.write("byte order = 0\n") # 波段名称 f.write("band names = {\n") f.write(" \"Classification\"\n") f.write("}\n") # 添加类别信息和颜色表 f.write(f"classes = {len(class_names)}\n") f.write("class names = {\n") for i, class_name in enumerate(class_names): f.write(f' "{class_name}"') if i < len(class_names) - 1: f.write(",") f.write("\n") f.write("}\n") # 创建颜色表 (RGB) f.write("class lookup = {\n") for i, class_name in enumerate(class_names): # 生成不同的颜色 (这里使用简单的颜色方案) r = (i * 137) % 256 # 使用黄金角度近似值生成不同颜色 g = (i * 157) % 256 b = (i * 173) % 256 f.write(f" {r}, {g}, {b}") if i < len(class_names) - 1: f.write(",") f.write("\n") f.write("}\n") print(f"头文件创建完成: {hdr_path}") def evaluate_model(self, X_test, y_test): """ 评估模型性能 Args: X_test: 测试特征 y_test: 测试标签 Returns: dict: 评估指标 """ if self.model is None: raise ValueError("模型未训练") # 预测 y_pred = self.model.predict(X_test) # 计算各项指标 accuracy = accuracy_score(y_test, y_pred) report = classification_report(y_test, y_pred, target_names=self.label_encoder.classes_, output_dict=True) conf_matrix = confusion_matrix(y_test, y_pred) return { 'accuracy': accuracy, 'classification_report': report, 'confusion_matrix': conf_matrix, 'predictions': y_pred } def parse_model_params(param_str): """ 解析模型参数字符串 Args: param_str (str): 参数字符串,如 "C=1.0,kernel=rbf" 或 "{'C': 1.0, 'kernel': 'rbf'}" Returns: dict: 解析后的参数字典 """ if not param_str: return {} try: # 尝试作为JSON解析 if param_str.strip().startswith('{'): return json.loads(param_str) # 尝试作为Python字典字面量解析 elif '=' in param_str: # 解析 key=value,key=value 格式 params = {} pairs = param_str.split(',') for pair in pairs: if '=' not in pair: continue key, value = pair.split('=', 1) key = key.strip() value = value.strip() # 尝试转换数值类型 if value.lower() in ('true', 'false'): value = value.lower() == 'true' elif value.replace('.', '').replace('-', '').isdigit(): if '.' in value: value = float(value) else: value = int(value) elif value.startswith('[') and value.endswith(']'): # 列表解析 value = ast.literal_eval(value) elif value.startswith('(') and value.endswith(')'): # 元组解析 value = ast.literal_eval(value) params[key] = value return params else: # 尝试直接解析为Python字典 return ast.literal_eval(param_str) except (json.JSONDecodeError, ValueError, SyntaxError) as e: print(f"警告: 参数解析失败 '{param_str}': {e}") print("使用默认参数") return {} def show_model_params_help(): """显示模型参数帮助信息""" print("\n模型参数使用说明:") print("-" * 50) print("格式1 (key=value): C=1.0,kernel=rbf,n_estimators=100") print("格式2 (JSON): {\"C\": 1.0, \"kernel\": \"rbf\", \"n_estimators\": 100}") print("格式3 (Python): {'C': 1.0, 'kernel': 'rbf', 'n_estimators': 100}") print() print("常用模型参数:") print("- SVM: C, kernel, gamma, degree") print("- Random Forest: n_estimators, max_depth, min_samples_split, min_samples_leaf") print("- Logistic Regression: C, penalty, solver, max_iter") print("- KNN: n_neighbors, weights, metric, p") print("- Linear Discriminant: solver, shrinkage, n_components") print("- Quadratic Discriminant: reg_param, store_covariance") print("- PLS-DA: n_components") print("- Spectral Unmix: n_endmembers") print("- Euclidean/Mahalanobis/Spectral Angle: 无参数") print() def list_supported_models(): """列出所有支持的模型""" classifier = HyperspectralClassifier() models = list(classifier.classifiers.keys()) print("支持的分类模型:") print("-" * 40) for i, model in enumerate(models, 1): print(f"{i:2}. {model}") print("-" * 40) print() return models def main(): parser = argparse.ArgumentParser(description='高光谱图像分类工具') parser.add_argument('--hyperspectral', required=True, help='高光谱文件路径') parser.add_argument('--roi', required=True, help='ROI XML文件路径') parser.add_argument('--model', default='svm', help='分类模型类型') parser.add_argument('--output', required=True, help='输出文件路径') parser.add_argument('--no-grid-search', action='store_true', help='禁用自动超参数调优') parser.add_argument('--model-params', type=str, default='', help='模型参数 (格式: key=value,key=value 或 JSON)') parser.add_argument('--list-models', action='store_true', help='列出所有支持的模型并退出') parser.add_argument('--help-params', action='store_true', help='显示模型参数使用帮助并退出') parser.add_argument('--batch-size', type=int, default=10000, help='预测时的批处理大小 (默认: 10000)') parser.add_argument('--no-standardization', action='store_true', help='禁用数据标准化') args = parser.parse_args() if args.list_models: list_supported_models() return if args.help_params: show_model_params_help() return # 解析模型参数 model_params = parse_model_params(args.model_params) # 创建配置对象 config = ClassificationConfig( hyperspectral_path=args.hyperspectral, roi_path=args.roi, output_path=args.output, model_type=args.model, use_grid_search=not args.no_grid_search, model_params=model_params, batch_size=args.batch_size, use_standardization=not args.no_standardization ) # 创建并运行分类器 classifier = ConfigurableHyperspectralClassifier(config) classifier.run_classification() # 便捷函数,用于简化使用 def classify_hyperspectral_image(config: ClassificationConfig) -> None: """ 使用配置对象进行高光谱图像分类的主要接口函数 Args: config: 分类配置对象 """ classifier = ConfigurableHyperspectralClassifier(config) classifier.run_classification() def classify_with_params(hyperspectral_path: str, roi_path: str, output_path: str, model_type: str = 'svm', use_grid_search: bool = True, model_params: Optional[Dict[str, Any]] = None, batch_size: int = 10000, use_standardization: bool = True) -> None: """ 使用参数进行高光谱图像分类的便捷函数 Args: hyperspectral_path: 高光谱文件路径 roi_path: ROI文件路径 output_path: 输出文件路径 model_type: 模型类型 use_grid_search: 是否使用网格搜索 model_params: 模型参数字典 batch_size: 批处理大小 use_standardization: 是否使用标准化 """ config = ClassificationConfig( hyperspectral_path=hyperspectral_path, roi_path=roi_path, output_path=output_path, model_type=model_type, use_grid_search=use_grid_search, model_params=model_params or {}, batch_size=batch_size, use_standardization=use_standardization ) classify_hyperspectral_image(config) def batch_classify_all_methods(): """批量测试所有可用的分类方法""" print("="*60) print("批量测试所有分类方法") print("="*60) # 检查输入文件是否存在 hyperspectral = r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip" roi = r"E:\code\spectronon\single_classsfication\data\roi.xml" output_dir = r"E:\code\spectronon\single_classsfication\output" # 确保输出目录存在 import os os.makedirs(output_dir, exist_ok=True) if not os.path.exists(hyperspectral): print(f"错误: 高光谱文件不存在: {hyperspectral}") return if not os.path.exists(roi): print(f"错误: ROI文件不存在: {roi}") return # 定义所有可用的分类器 available_classifiers = [ 'euclidean', 'mahalanobis', 'linear_discriminant', 'quadratic_discriminant', 'plsda', 'spectral_angle_mapper', 'spectral_unmix', 'random_forest', 'svm', 'logistic_regression', 'knn', 'adaboost' ] # 条件添加梯度提升模型 if XGBOOST_AVAILABLE: available_classifiers.append('xgboost') if LIGHTGBM_AVAILABLE: available_classifiers.append('lightgbm') if CATBOOST_AVAILABLE: available_classifiers.append('catboost') results = {} timing_results = {} print(f"将测试 {len(available_classifiers)} 种分类方法:") for i, method in enumerate(available_classifiers, 1): print(f"{i:2d}. {method}") print() for i, method in enumerate(available_classifiers, 1): start_time = time.time() print("-" * 50) print(f"[{i}/{len(available_classifiers)}] 测试方法: {method.upper()}") print("-" * 50) try: # 为每种方法设置不同的输出路径 output_path = os.path.join(output_dir, f'classification_{method}.bil') print(f"输入文件: {os.path.basename(hyperspectral)}") print(f"ROI文件: {os.path.basename(roi)}") print(f"输出文件: {os.path.basename(output_path)}") print(f"模型类型: {method}") # 临时切换到输出目录,确保混淆矩阵和CSV文件保存在正确位置 original_cwd = os.getcwd() os.chdir(output_dir) try: # 使用便捷函数进行分类 classify_with_params( hyperspectral_path=hyperspectral, roi_path=roi, output_path=output_path, model_type=method, use_grid_search=False, # 为了速度,关闭网格搜索 model_params={}, # 使用默认参数 batch_size=5000 # 减小批处理大小以节省内存 ) finally: # 恢复原始工作目录 os.chdir(original_cwd) end_time = time.time() elapsed_time = end_time - start_time timing_results[method] = elapsed_time results[method] = "成功" print(f"✅ {method} 分类完成,结果保存至: {output_path}") print(f"⏱️ 耗时: {elapsed_time:.2f}秒") except Exception as e: end_time = time.time() elapsed_time = end_time - start_time timing_results[method] = elapsed_time error_msg = str(e) results[method] = f"失败: {error_msg}" print(f"❌ {method} 分类失败: {error_msg}") print(f"⏱️ 耗时: {elapsed_time:.2f}秒") # 继续下一个方法,不中断整个流程 continue print() # 空行分隔不同方法的结果 # 输出总结报告 print("\n" + "="*60) print("批量分类测试完成报告") print("="*60) successful = sum(1 for result in results.values() if result == "成功") total = len(results) print(f"总方法数: {total}") print(f"成功方法数: {successful}") print(f"失败方法数: {total - successful}") print(".1f") print("\n详细结果:") for method, result in results.items(): status = "✅" if result == "成功" else "❌" print(f"{status} {method}: {result}") # 保存结果到文件 try: report_path = os.path.join(output_dir, 'batch_classification_report.txt') with open(report_path, 'w', encoding='utf-8') as f: f.write("批量分类测试报告\n") f.write("="*60 + "\n") f.write(f"测试时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n") f.write(f"输入文件: {hyperspectral}\n") f.write(f"ROI文件: {roi}\n") f.write(f"输出目录: {output_dir}\n\n") f.write(f"统计信息:\n") f.write(f"- 总方法数: {total}\n") f.write(f"- 成功方法数: {successful}\n") f.write(f"- 失败方法数: {total - successful}\n") f.write(f"- 成功率: {successful/total*100:.1f}%\n\n") f.write("详细结果:\n") f.write("-" * 40 + "\n") # 成功的和失败的分开显示 successful_methods = [(m, r) for m, r in results.items() if r == "成功"] failed_methods = [(m, r) for m, r in results.items() if r != "成功"] if successful_methods: f.write("✅ 成功的方法:\n") for method, result in successful_methods: elapsed = timing_results.get(method, 0) f.write(f" • {method}: {result}\n") f.write(f" 输出文件: classification_{method}.bil\n") f.write(f" 处理时间: {elapsed:.2f}秒\n") f.write("\n") # 显示性能统计 if timing_results: successful_times = [timing_results[m] for m, _ in successful_methods] f.write("性能统计:\n") f.write(f" • 最快方法: {min(successful_times):.2f}秒\n") f.write(f" • 最慢方法: {max(successful_times):.2f}秒\n") f.write(f" • 平均时间: {sum(successful_times)/len(successful_times):.2f}秒\n") f.write("\n") if failed_methods: f.write("❌ 失败的方法:\n") for method, result in failed_methods: elapsed = timing_results.get(method, 0) f.write(f" • {method}: {result}\n") f.write(f" 处理时间: {elapsed:.2f}秒\n") f.write("\n") f.write("算法说明:\n") f.write("-" * 40 + "\n") f.write("传统机器学习:\n") f.write(" • svm: 支持向量机\n") f.write(" • random_forest: 随机森林\n") f.write(" • logistic_regression: 逻辑回归\n") f.write(" • knn: K近邻\n") f.write(" • adaboost: AdaBoost\n") f.write(" • linear_discriminant: 线性判别分析\n") f.write(" • quadratic_discriminant: 二次判别分析\n\n") f.write("高光谱专用算法:\n") f.write(" • euclidean: 欧几里得距离分类\n") f.write(" • mahalanobis: 马哈拉诺比斯距离分类\n") f.write(" • spectral_angle_mapper: 光谱角度映射\n") f.write(" • spectral_unmix: 光谱解混\n") f.write(" • plsda: 偏最小二乘判别分析\n\n") f.write("梯度提升算法:\n") f.write(" • xgboost: XGBoost\n") f.write(" • lightgbm: LightGBM\n") f.write(" • catboost: CatBoost\n") print(f"\n📄 详细报告已保存至: {report_path}") # 显示成功的输出文件 if successful_methods: print("\n📁 生成的分类结果文件:") for method, _ in successful_methods: bil_file = os.path.join(output_dir, f'classification_{method}.bil') png_file = f'confusion_matrix_{method}.png' csv_file = f'model_metrics_{method}.csv' if os.path.exists(bil_file): print(f" • 分类结果: {bil_file}") if os.path.exists(png_file): print(f" • 混淆矩阵: {png_file}") if os.path.exists(csv_file): print(f" • 评价指标: {csv_file}") print() except Exception as e: print(f"保存报告失败: {e}") if __name__ == '__main__': # 批量测试所有分类方法 batch_classify_all_methods() # 单个方法测试示例(注释掉) """ # 示例用法 - 使用单个模型 hyperspectral = r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip" roi = r"E:\code\spectronon\single_classsfication\data\roi.xml" model = 'svm' # 可以改为其他模型 output = r'E:\code\spectronon\single_classsfication\output\test.bil' try: classify_with_params( hyperspectral_path=hyperspectral, roi_path=roi, output_path=output, model_type=model, use_grid_search=False, model_params={} ) except Exception as e: print(f"错误: {str(e)}") sys.exit(1) """