# Listing metadata: 2084 lines, 72 KiB, Python
import os
|
||
import sys
|
||
import argparse
|
||
import json
|
||
import ast
|
||
import time
|
||
import numpy as np
|
||
from osgeo import gdal
|
||
import xml.etree.ElementTree as ET
|
||
from sklearn.model_selection import GridSearchCV, train_test_split
|
||
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
||
from sklearn.svm import SVC
|
||
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
|
||
from sklearn.linear_model import LogisticRegression
|
||
from sklearn.neighbors import KNeighborsClassifier
|
||
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
|
||
from sklearn.cross_decomposition import PLSRegression
|
||
from sklearn.decomposition import PCA
|
||
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
|
||
from sklearn.cluster import KMeans
|
||
from scipy.spatial.distance import cdist
|
||
from scipy.optimize import nnls
|
||
import joblib
|
||
from tqdm import tqdm
|
||
import warnings
|
||
from typing import Optional, Dict, Any, Tuple, List
|
||
from dataclasses import dataclass, field
|
||
|
||
# 可视化和数据保存导入
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
from sklearn.metrics import precision_recall_fscore_support
|
||
|
||
# 导入梯度提升模型
|
||
try:
|
||
from xgboost import XGBClassifier
|
||
XGBOOST_AVAILABLE = True
|
||
except ImportError:
|
||
XGBOOST_AVAILABLE = False
|
||
print("警告: XGBoost未安装,将跳过XGBoost模型")
|
||
|
||
try:
|
||
from lightgbm import LGBMClassifier
|
||
LIGHTGBM_AVAILABLE = True
|
||
except ImportError:
|
||
LIGHTGBM_AVAILABLE = False
|
||
print("警告: LightGBM未安装,将跳过LightGBM模型")
|
||
|
||
try:
|
||
from catboost import CatBoostClassifier
|
||
CATBOOST_AVAILABLE = True
|
||
except ImportError:
|
||
CATBOOST_AVAILABLE = False
|
||
print("警告: CatBoost未安装,将跳过CatBoost模型")
|
||
warnings.filterwarnings('ignore')
|
||
|
||
|
||
@dataclass
class ClassificationConfig:
    """Configuration for one hyperspectral image classification run.

    The instance validates itself right after construction (see
    ``__post_init__``) and then fills ``model_params`` with defaults for the
    selected ``model_type``. Values supplied by the caller are never
    overwritten.
    """

    # Input files
    hyperspectral_path: Optional[str] = None
    roi_path: Optional[str] = None

    # Output
    output_path: Optional[str] = None

    # Model selection
    model_type: str = 'svm'
    use_grid_search: bool = True
    model_params: Dict[str, Any] = field(default_factory=dict)

    # Prediction: batch size used when predicting the image
    batch_size: int = 10000

    # Preprocessing
    use_standardization: bool = True

    # Train/validation split
    test_size: float = 0.2

    def __post_init__(self):
        """Validate the configuration and apply per-model defaults.

        Raises:
            ValueError: If a required path is missing, ``batch_size`` is not
                a positive integer, or ``model_type`` is not supported.
            FileNotFoundError: If the hyperspectral or ROI file does not
                exist.
        """
        import numbers

        # Required paths
        if not self.hyperspectral_path:
            raise ValueError("必须指定高光谱文件路径(hyperspectral_path)")

        if not self.roi_path:
            raise ValueError("必须指定ROI文件路径(roi_path)")

        if not self.output_path:
            raise ValueError("必须指定输出文件路径(output_path)")

        # Input files must exist
        if not os.path.exists(self.hyperspectral_path):
            raise FileNotFoundError(f"高光谱文件不存在: {self.hyperspectral_path}")

        if not os.path.exists(self.roi_path):
            raise FileNotFoundError(f"ROI文件不存在: {self.roi_path}")

        # Fail fast: a non-positive batch size would otherwise only surface
        # as an error during prediction, after the model has been trained.
        if not isinstance(self.batch_size, numbers.Integral) or self.batch_size <= 0:
            raise ValueError(f"batch_size必须为正整数,当前值: {self.batch_size}")

        # Model type
        supported_models = self._get_supported_models()
        if self.model_type not in supported_models:
            raise ValueError(f"不支持的模型类型: {self.model_type}。支持的模型: {list(supported_models.keys())}")

        # Per-model defaults
        self._set_model_default_params()

    def _get_supported_models(self) -> Dict[str, str]:
        """Return a mapping of supported model keys to display names.

        The gradient boosting entries are only present when the matching
        optional package was importable.
        """
        models = {
            'euclidean': 'Euclidean Distance',
            'mahalanobis': 'Mahalanobis Distance',
            'linear_discriminant': 'Linear Discriminant Analysis',
            'quadratic_discriminant': 'Quadratic Discriminant Analysis',
            'plsda': 'PLS Discriminant Analysis',
            'spectral_angle_mapper': 'Spectral Angle Mapper',
            'spectral_unmix': 'Spectral Unmixing',
            'random_forest': 'Random Forest',
            'svm': 'Support Vector Machine',
            'logistic_regression': 'Logistic Regression',
            'knn': 'K-Nearest Neighbors',
            'adaboost': 'AdaBoost',
        }

        # Optional gradient boosting backends
        if XGBOOST_AVAILABLE:
            models['xgboost'] = 'XGBoost'
        if LIGHTGBM_AVAILABLE:
            models['lightgbm'] = 'LightGBM'
        if CATBOOST_AVAILABLE:
            models['catboost'] = 'CatBoost'

        return models

    def _set_model_default_params(self):
        """Fill ``model_params`` with defaults for the selected model.

        Only missing keys are added (``dict.setdefault``), so caller-supplied
        values win. Models without an entry below get no defaults.
        """
        defaults_by_model = {
            'svm': {'C': 1.0, 'kernel': 'rbf', 'gamma': 'scale'},
            'random_forest': {
                'n_estimators': 100,
                'max_depth': None,
                'min_samples_split': 2,
            },
            'logistic_regression': {'C': 1.0, 'max_iter': 1000},
            'knn': {'n_neighbors': 5, 'weights': 'uniform'},
            'spectral_unmix': {'n_endmembers': 10},
            'adaboost': {'n_estimators': 50, 'learning_rate': 1.0},
        }
        defaults = defaults_by_model.get(self.model_type, {})

        # Gradient boosting defaults apply only when the backend is installed.
        if self.model_type == 'xgboost' and XGBOOST_AVAILABLE:
            defaults = {'n_estimators': 100, 'max_depth': 6, 'learning_rate': 0.1}
        elif self.model_type == 'lightgbm' and LIGHTGBM_AVAILABLE:
            defaults = {'n_estimators': 100, 'max_depth': -1, 'learning_rate': 0.1}
        elif self.model_type == 'catboost' and CATBOOST_AVAILABLE:
            defaults = {'iterations': 100, 'depth': 6, 'learning_rate': 0.1}

        for key, value in defaults.items():
            self.model_params.setdefault(key, value)
|
||
|
||
|
||
class ConfigurableHyperspectralClassifier:
|
||
"""可配置的高光谱图像分类器"""
|
||
|
||
def __init__(self, config: ClassificationConfig):
    """Create empty pipeline state and the registry of classifiers.

    Args:
        config: Run configuration; stored as ``self.config``.
    """
    self.config = config

    # Image cube and georeferencing, populated by read_hyperspectral_data()
    self.data = None
    self.geotransform = None
    self.projection = None
    self.roi_data = {}
    # The scaler only exists when standardization is enabled
    self.scaler = StandardScaler() if config.use_standardization else None
    self.label_encoder = LabelEncoder()
    self.model = None
    self.class_names = []

    # Registry of every available classifier, keyed by config.model_type
    self.classifiers = {
        'euclidean': EuclideanDistanceClassifier(),
        'mahalanobis': MahalanobisDistanceClassifier(),
        'linear_discriminant': LinearDiscriminantAnalysis(),
        'quadratic_discriminant': QuadraticDiscriminantAnalysis(),
        'plsda': PLSDAClassifier(),
        'spectral_angle_mapper': SpectralAngleMapper(),
        'spectral_unmix': SpectralUnmixClassifier(),
        'random_forest': RandomForestClassifier(random_state=42),
        'svm': SVC(random_state=42, probability=True),
        'logistic_regression': LogisticRegression(random_state=42, max_iter=1000),
        'knn': KNeighborsClassifier(),
        'adaboost': AdaBoostClassifier(random_state=42),
    }

    # Optional gradient boosting backends
    if XGBOOST_AVAILABLE:
        # config.model_params is deliberately NOT forwarded here: it holds
        # the parameters of whichever model was selected (e.g. SVM's
        # 'kernel'), and a 'random_state'/'eval_metric' key would collide
        # with the explicit keywords. train_model() applies the parameters
        # to the selected model through set_params().
        self.classifiers['xgboost'] = XGBClassifier(
            random_state=42, use_label_encoder=False, eval_metric='logloss'
        )
    if LIGHTGBM_AVAILABLE:
        self.classifiers['lightgbm'] = LGBMClassifier(random_state=42, verbose=-1)
    if CATBOOST_AVAILABLE:
        self.classifiers['catboost'] = CatBoostClassifier(random_state=42, verbose=False)
|
||
|
||
def read_hyperspectral_data(self) -> None:
    """Load the image named by ``config.hyperspectral_path`` with GDAL.

    Fills ``self.data`` with a float32 array of shape (rows, cols, bands)
    and stores the dataset's geotransform and projection.

    Raises:
        ValueError: If GDAL cannot open the file.
    """
    source_path = self.config.hyperspectral_path
    print(f"正在读取高光谱文件: {source_path}")

    dataset = gdal.Open(source_path)
    if dataset is None:
        raise ValueError(f"无法打开文件: {source_path}")

    n_cols = dataset.RasterXSize
    n_rows = dataset.RasterYSize
    n_bands = dataset.RasterCount
    print(f"图像尺寸: {n_cols} x {n_rows} x {n_bands}")

    # Allocate the full cube up front and fill it one band at a time
    cube = np.zeros((n_rows, n_cols, n_bands), dtype=np.float32)
    self.data = cube

    # GDAL band numbers are 1-based
    for band_number in tqdm(range(1, n_bands + 1), desc="读取波段"):
        layer = dataset.GetRasterBand(band_number).ReadAsArray()
        cube[:, :, band_number - 1] = layer.astype(np.float32)

    # Georeferencing
    self.geotransform = dataset.GetGeoTransform()
    self.projection = dataset.GetProjection()

    # Drop the local reference to the dataset
    dataset = None
    print("高光谱数据读取完成")
|
||
|
||
def parse_roi_xml(self) -> None:
    """Parse the ROI XML file named by ``config.roi_path``.

    Every ``Region`` element becomes one entry of ``self.roi_data``
    (``{'coords': [(x, y), ...], 'color': ...}``) keyed by its ``name``
    attribute. Two coordinate layouts are understood:

    1. ``Point`` elements carrying ``x`` / ``y`` attributes;
    2. a ``Coordinates`` element whose text is a whitespace separated list
       ``x1 y1 x2 y2 ...`` (only used when layout 1 yields nothing).

    Regions without usable coordinates are reported and skipped. When a
    region name occurs more than once the later definition replaces the
    earlier one and the name is listed only once in ``self.class_names``.
    """
    print(f"正在解析ROI文件: {self.config.roi_path}")

    root = ET.parse(self.config.roi_path).getroot()

    self.roi_data = {}
    self.class_names = []

    for region in root.findall('.//Region'):
        name = region.get('name')
        color = region.get('color')

        # Layout 1: individual Point elements
        coords = []
        for point in region.findall('.//Point'):
            x_attr, y_attr = point.get('x'), point.get('y')
            if x_attr is None or y_attr is None:
                continue
            try:
                coords.append((float(x_attr), float(y_attr)))
            except ValueError:
                # Best effort: skip the malformed point, keep the others
                continue

        # Layout 2: flat coordinate text, all-or-nothing
        if not coords:
            coords_element = region.find('.//Coordinates')
            if coords_element is not None and coords_element.text:
                tokens = coords_element.text.split()
                if len(tokens) % 2 == 0:
                    try:
                        values = [float(token) for token in tokens]
                    except ValueError:
                        values = []  # one bad number invalidates the list
                    coords = list(zip(values[0::2], values[1::2]))

        if not coords:
            print(f"警告: 区域 '{name}' 没有有效的坐标数据")
            continue

        if name in self.roi_data:
            # Keep class_names free of duplicates; it is used to size the
            # colour table of the output raster.
            print(f"警告: 区域名称 '{name}' 重复,后出现的区域将覆盖之前的定义")
        else:
            self.class_names.append(name)

        self.roi_data[name] = {
            'coords': coords,
            'color': color,
        }
        print(f"成功解析区域 '{name}',包含 {len(coords)} 个坐标点")

    if self.roi_data:
        print(f"解析完成,共找到 {len(self.class_names)} 个区域: {self.class_names}")
    else:
        print("警告: 没有找到任何有效的ROI区域数据")
|
||
|
||
def extract_training_samples(self) -> Tuple[np.ndarray, np.ndarray]:
    """Collect labelled spectra from the parsed ROI regions.

    ROI coordinates are converted to pixel indices. When a geotransform is
    present only its origin and pixel-size terms (indices 0, 1, 3, 5) are
    used; the rotation terms are ignored. Coordinates falling outside the
    image are dropped.

    A region with three or more remaining vertices contributes every pixel
    of the vertices' bounding box (not just the polygon interior). A region
    with one or two vertices contributes exactly those pixels.

    Returns:
        X_train: Spectra with shape (n_samples, n_bands).
        y_train: Integer labels produced by ``self.label_encoder``.

    Raises:
        ValueError: If the image or ROI data is missing, or no sample could
            be extracted.
    """
    print("正在提取训练样本...")

    if self.data is None:
        raise ValueError("请先读取高光谱数据")

    if not self.roi_data:
        raise ValueError("请先解析ROI文件")

    rows, cols, bands = self.data.shape
    gt = self.geotransform
    X_samples = []
    y_samples = []

    for class_name, region_data in self.roi_data.items():
        print(f"处理类别 '{class_name}'...")

        pixel_coords = []
        for x, y in region_data['coords']:
            if gt:
                # GeoTransform: [top_left_x, pixel_width, 0, top_left_y, 0, pixel_height]
                col_pos = (x - gt[0]) / gt[1]
                row_pos = (y - gt[3]) / gt[5]
            else:
                # Coordinates are taken to be (column, row) pixel positions
                col_pos, row_pos = x, y

            # floor(), not int(): int() truncates toward zero, which would
            # map positions in (-1, 0) -- outside the image -- onto pixel 0.
            col = int(np.floor(col_pos))
            row = int(np.floor(row_pos))

            if 0 <= row < rows and 0 <= col < cols:
                pixel_coords.append((row, col))

        # Drop the repeated closing vertex of a closed ring. The length
        # check keeps a single-point ROI from being discarded entirely.
        if len(pixel_coords) > 1 and pixel_coords[0] == pixel_coords[-1]:
            pixel_coords = pixel_coords[:-1]

        if not pixel_coords:
            print(f"  警告: 类别 '{class_name}' 没有有效的像素坐标")
            continue

        point_rows, point_cols = zip(*pixel_coords)
        if len(pixel_coords) >= 3:
            # Polygon: take the whole bounding box, in row-major order
            window = self.data[
                min(point_rows):max(point_rows) + 1,
                min(point_cols):max(point_cols) + 1,
                :,
            ]
            class_pixels = window.reshape(-1, bands)
        else:
            # One or two points: take exactly those pixels
            class_pixels = self.data[list(point_rows), list(point_cols), :]

        X_samples.append(class_pixels)
        y_samples.extend([class_name] * len(class_pixels))
        print(f"  类别 '{class_name}': {len(class_pixels)} 个样本")

    if not X_samples:
        raise ValueError("没有找到有效的训练样本")

    X_train = np.concatenate(X_samples, axis=0)
    y_train = np.array(y_samples)

    # Encode class names as integers
    y_train_encoded = self.label_encoder.fit_transform(y_train)

    print(f"训练样本提取完成: {X_train.shape[0]} 个样本, {X_train.shape[1]} 个波段, {len(np.unique(y_train_encoded))} 个类别")

    return X_train, y_train_encoded
|
||
|
||
def train_model(self, X: np.ndarray, y: np.ndarray) -> None:
    """Fit the configured classifier and store it in ``self.model``.

    The features are standardized first when enabled, ``config.model_params``
    is applied through ``set_params`` when the estimator offers it, and the
    evaluation outputs are generated afterwards.

    Args:
        X: Training features.
        y: Training labels.

    Raises:
        ValueError: If ``config.model_type`` is not a registered classifier.
    """
    model_type = self.config.model_type
    print(f"正在训练 {model_type} 模型...")

    # Optional standardization (the scaler is fitted here)
    features = X
    if self.config.use_standardization and self.scaler is not None:
        features = self.scaler.fit_transform(X)
        print("应用数据标准化")

    estimator = self.classifiers.get(model_type)
    if estimator is None:
        raise ValueError(f"不支持的模型类型: {model_type}")

    if hasattr(estimator, 'set_params'):
        estimator.set_params(**self.config.model_params)

    # NOTE(review): none of the classifiers visible in this file defines
    # _get_param_grid, so the grid-search branch looks unreachable -- confirm.
    wants_search = self.config.use_grid_search and hasattr(estimator, '_get_param_grid')
    if not wants_search:
        # Plain fit
        self.model = estimator
        self.model.fit(features, y)
        print("模型训练完成")
    else:
        print("执行网格搜索调参...")
        search = GridSearchCV(
            estimator,
            estimator._get_param_grid(),
            cv=3,
            scoring='accuracy',
            n_jobs=-1,
        )
        search.fit(features, y)

        self.model = search.best_estimator_
        print(f"最佳参数: {search.best_params_}")
        print(f"最佳得分: {search.best_score_:.4f}")

    # NOTE(review): the evaluation receives the same samples the model was
    # just fitted on.
    self._evaluate_and_visualize_model(features, y, model_type)
|
||
|
||
def predict_image(self) -> None:
    """Classify every pixel of the image and save the label raster.

    Pixels are processed in batches of ``config.batch_size``. When
    standardization is enabled each batch is scaled right before it is
    predicted, so no standardized copy of the whole image is ever held in
    memory.

    Raises:
        ValueError: If the image has not been read, the model has not been
            trained, or ``config.batch_size`` is not positive.
    """
    print("正在对整个图像进行预测...")

    if self.data is None:
        raise ValueError("请先读取高光谱数据")

    if self.model is None:
        raise ValueError("请先训练模型")

    batch_size = self.config.batch_size
    if batch_size <= 0:
        raise ValueError(f"batch_size必须为正整数,当前值: {batch_size}")

    rows, cols, bands = self.data.shape

    print("预处理图像数据...")
    # One row per pixel, one column per band
    image_data = self.data.reshape(-1, bands)
    apply_scaler = self.config.use_standardization and self.scaler is not None
    n_pixels = image_data.shape[0]

    # Keep per-batch arrays and join them once, instead of growing a Python
    # list with one element per pixel.
    batch_results = []
    for start in tqdm(range(0, n_pixels, batch_size), desc="预测进度"):
        batch = image_data[start:start + batch_size]
        if apply_scaler:
            batch = self.scaler.transform(batch)
        batch_results.append(np.asarray(self.model.predict(batch)))

    predictions = np.concatenate(batch_results) if batch_results else np.array([])
    prediction_image = predictions.reshape(rows, cols)

    self._save_prediction_result(prediction_image)
    print(f"预测结果已保存到: {self.config.output_path}")
|
||
|
||
def _save_prediction_result(self, prediction_image: np.ndarray) -> None:
    """Write the label image to ``config.output_path`` as an ENVI raster.

    The output is a single byte band. The geotransform and projection read
    from the input image are copied over when available, and a colour table
    is attached.

    Args:
        prediction_image: Label image of shape (rows, cols); it is cast to
            ``uint8`` before writing.

    Raises:
        ValueError: If the ENVI driver is unavailable or the output file
            cannot be created.
    """
    output_path = self.config.output_path

    # Make sure the output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"创建输出目录: {output_dir}")

    driver = gdal.GetDriverByName('ENVI')
    if driver is None:
        raise ValueError("无法获取GDAL ENVI驱动")

    rows, cols = prediction_image.shape

    # Single-band classification result
    out_dataset = driver.Create(output_path, cols, rows, 1, gdal.GDT_Byte)
    if out_dataset is None:
        # Without this check a failed Create() surfaces as an opaque
        # AttributeError on None further down.
        raise ValueError(f"无法创建输出文件: {output_path}")

    # Georeferencing
    if self.geotransform:
        out_dataset.SetGeoTransform(self.geotransform)
    if self.projection:
        out_dataset.SetProjection(self.projection)

    # Pixel data
    out_band = out_dataset.GetRasterBand(1)
    out_band.WriteArray(prediction_image.astype(np.uint8))

    # Colour table (optional)
    self._set_color_table(out_band)

    # Release the band before the dataset that owns it
    out_band = None
    out_dataset = None
|
||
|
||
def _set_color_table(self, band) -> None:
    """Attach a colour table with one entry per class to ``band``.

    Entry ``i`` colours pixel value ``i``. The training labels come from
    ``label_encoder.fit_transform`` and are therefore numbered from 0, so
    the entries must start at 0 as well; starting at 1 would shift every
    colour by one class.

    Args:
        band: GDAL raster band that receives the colour table.
    """
    color_table = gdal.ColorTable()

    for label_index in range(len(self.class_names)):
        # Multiples of primes spread the colours. Seeding with index + 1
        # keeps the first class from being rendered pure black.
        seed = label_index + 1
        rgba = ((seed * 137) % 256, (seed * 157) % 256, (seed * 173) % 256, 255)
        color_table.SetColorEntry(label_index, rgba)

    band.SetColorTable(color_table)
|
||
|
||
def run_classification(self) -> None:
    """Run the whole pipeline: read, parse ROI, sample, train, predict.

    Any exception is reported on stdout and then re-raised unchanged.
    """
    try:
        # Inputs
        self.read_hyperspectral_data()
        self.parse_roi_xml()

        # Training data and model
        features, labels = self.extract_training_samples()
        self.train_model(features, labels)

        # Full-image prediction
        self.predict_image()

        print("分类任务完成!")
    except Exception as error:
        print(f"分类过程中发生错误: {error}")
        raise
|
||
|
||
def _evaluate_and_visualize_model(self, X_scaled, y, model_type):
    """Score ``self.model`` on a stratified hold-out split and emit reports.

    Failures are reported on stdout and swallowed, so a problem here never
    aborts the caller.

    Args:
        X_scaled: Features as passed to the model.
        y: Labels.
        model_type: Model key forwarded to the report writers.
    """
    # NOTE(review): train_model() fits on all of X_scaled before calling
    # this, so the hold-out part below was seen during training.
    try:
        split = train_test_split(
            X_scaled,
            y,
            test_size=self.config.test_size,
            random_state=42,
            stratify=y,
        )
        # train_test_split returns (X_train, X_test, y_train, y_test)
        held_out_X, held_out_y = split[1], split[3]

        predicted = self.model.predict(held_out_X)

        # Confusion matrix figure and metrics CSV
        self._generate_evaluation_outputs(held_out_y, predicted, model_type)
    except Exception as error:
        print(f"模型评估时出错: {error}")
|
||
|
||
def _generate_evaluation_outputs(self, y_true, y_pred, model_type):
    """Write the confusion matrix figure, then the metrics CSV.

    An exception raised by either step is reported on stdout and swallowed;
    if the first step fails the second one is not attempted.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        model_type: Model key forwarded to both writers.
    """
    try:
        # Methods are looked up one at a time, in call order
        for step in ('_plot_confusion_matrix', '_save_metrics_to_csv'):
            getattr(self, step)(y_true, y_pred, model_type)
    except Exception as error:
        print(f"生成评价输出时出错: {error}")
|
||
|
||
def _plot_confusion_matrix(self, y_true, y_pred, model_type):
    """Save a confusion matrix heat map as ``confusion_matrix_<model_type>.png``.

    The file name is relative, so the figure lands in the current working
    directory.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        model_type: Model key used in the title and the file name.
    """
    # Use the union of true and predicted labels for both the matrix and
    # the tick labels. Deriving the tick labels from y_true alone breaks
    # when the model predicts a class absent from y_true.
    labels = np.unique(np.concatenate([np.ravel(y_true), np.ravel(y_pred)]))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    class_names = [f'Class {i}' for i in labels]

    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names)

        plt.title(f'Confusion Matrix - {model_type.upper()}', fontsize=16, fontweight='bold')
        plt.ylabel('True Label', fontsize=12)
        plt.xlabel('Predicted Label', fontsize=12)

        filename = f'confusion_matrix_{model_type}.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even when drawing or saving fails
        plt.close(fig)

    print(f"混淆矩阵图已保存: {filename}")
|
||
|
||
def _save_metrics_to_csv(self, y_true, y_pred, model_type):
    """Write per-class and summary metrics to ``model_metrics_<model_type>.csv``.

    The table has one row per class followed by three summary rows
    (``macro_avg``, ``weighted_avg``, ``overall_accuracy``). The file name
    is relative, so the CSV lands in the current working directory.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        model_type: Model key used in the table and the file name.
    """
    import pandas as pd

    accuracy = accuracy_score(y_true, y_pred)

    # Same label set for every metric call, covering true and predicted
    # labels, so the per-class arrays always match class_names in length.
    labels = np.unique(np.concatenate([np.ravel(y_true), np.ravel(y_pred)]))
    class_names = [f'Class {i}' for i in labels]

    report = classification_report(
        y_true, y_pred, labels=labels, target_names=class_names, output_dict=True
    )
    macro = report['macro avg']
    weighted = report['weighted avg']

    # Per-class precision, recall, F1 and support (numpy arrays)
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, average=None
    )

    n_rows = len(class_names) + 3
    # The arrays are converted to lists before the summary values are
    # appended: ndarray has no extend(), which previously raised
    # AttributeError and prevented the CSV from ever being written.
    metrics_data = {
        'model_type': [model_type] * n_rows,
        'class': class_names + ['macro_avg', 'weighted_avg', 'overall_accuracy'],
        'precision': precision.tolist() + [macro['precision'], weighted['precision'], accuracy],
        'recall': recall.tolist() + [macro['recall'], weighted['recall'], accuracy],
        'f1_score': f1.tolist() + [macro['f1-score'], weighted['f1-score'], accuracy],
        'support': support.tolist() + [macro['support'], weighted['support'], len(y_true)],
        'accuracy': [accuracy] * n_rows,
    }

    df_metrics = pd.DataFrame(metrics_data)
    filename = f'model_metrics_{model_type}.csv'
    df_metrics.to_csv(filename, index=False)

    print(f"模型评价指标已保存: {filename}")
|
||
|
||
class BaseClassifier:
    """Minimal base class for the custom classifiers in this module.

    Subclasses override ``fit`` and ``predict``; the implementations here
    are no-op placeholders that return ``None``.
    """

    def __init__(self, name):
        """Store the display ``name`` and an empty ``model`` slot."""
        self.name = name
        self.model = None

    def fit(self, X, y):
        """Train on features ``X`` and labels ``y`` (placeholder)."""
        pass

    def predict(self, X):
        """Return predicted labels for ``X`` (placeholder)."""
        pass

    def predict_proba(self, X):
        """Return class probabilities for ``X`` (placeholder)."""
        pass

    def get_params(self, deep=True):
        """Return the parameter dictionary (empty for the base class).

        ``deep`` is accepted for signature compatibility with the
        subclasses' overrides and with scikit-learn style callers; it is
        not used.
        """
        return {}

    def set_params(self, **params):
        """Set parameters that already exist as attributes; ignore others.

        Returns:
            ``self``, so calls can be chained.
        """
        for param, value in params.items():
            if hasattr(self, param):
                setattr(self, param, value)
        return self
|
||
|
||
|
||
class EuclideanDistanceClassifier(BaseClassifier):
    """Nearest-centroid classifier using Euclidean distance."""

    def __init__(self):
        super().__init__('Euclidean Distance')
        self.class_centroids = {}
        self.class_labels = []

    def fit(self, X, y):
        """Compute the mean spectrum (centroid) of every class.

        Centroids from a previous ``fit`` are discarded, so classes that
        are absent from the new labels cannot linger.

        Returns:
            ``self``.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        self.class_labels = np.unique(y)
        self.class_centroids = {
            label: np.mean(X[y == label], axis=0) for label in self.class_labels
        }
        return self

    def predict(self, X):
        """Assign every sample to the class with the nearest centroid.

        Ties go to the class that comes first in ``class_centroids``.

        Raises:
            ValueError: If called with samples before ``fit``.
        """
        X = np.asarray(X)
        if len(X) == 0:
            return np.array([])
        if not self.class_centroids:
            raise ValueError("模型尚未训练,请先调用fit")

        labels = np.array(list(self.class_centroids.keys()))
        centroids = np.vstack(list(self.class_centroids.values()))

        # One vectorized distance matrix (n_samples, n_classes) instead of
        # a Python loop over every pixel.
        nearest = cdist(X, centroids).argmin(axis=1)
        return labels[nearest]
|
||
|
||
|
||
class MahalanobisDistanceClassifier(BaseClassifier):
    """Minimum Mahalanobis distance classifier with per-class covariance."""

    def __init__(self):
        super().__init__('Mahalanobis Distance')
        self.class_means = {}
        self.class_covariances = {}
        self.class_inv_covariances = {}
        self.class_labels = []

    def fit(self, X, y):
        """Estimate the mean and (regularized) covariance of every class.

        State from a previous ``fit`` is discarded first. A class with a
        single sample has no sample covariance; it gets a zero matrix, so
        after regularization its metric is a scaled Euclidean distance.

        Returns:
            ``self``.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        n_features = X.shape[1]

        self.class_labels = np.unique(y)
        self.class_means = {}
        self.class_covariances = {}
        self.class_inv_covariances = {}

        for label in self.class_labels:
            class_samples = X[y == label]
            mean = np.mean(class_samples, axis=0)

            if class_samples.shape[0] > 1:
                # atleast_2d: np.cov returns a 0-d array for a single band
                covariance = np.atleast_2d(np.cov(class_samples.T))
            else:
                covariance = np.zeros((n_features, n_features))

            # Small ridge on the diagonal to avoid singular matrices
            covariance = covariance + np.eye(n_features) * 1e-6

            self.class_means[label] = mean
            self.class_covariances[label] = covariance

            try:
                self.class_inv_covariances[label] = np.linalg.inv(covariance)
            except np.linalg.LinAlgError:
                # Not invertible: fall back to the pseudo-inverse
                self.class_inv_covariances[label] = np.linalg.pinv(covariance)

        return self

    def mahalanobis_distance(self, x, mean, inv_cov):
        """Return the Mahalanobis distance between ``x`` and ``mean``."""
        diff = x - mean
        squared = np.dot(np.dot(diff.T, inv_cov), diff)
        # Round-off can push the quadratic form slightly below zero
        return np.sqrt(np.maximum(squared, 0.0))

    def predict(self, X):
        """Assign every sample to the class with the smallest distance.

        Squared distances are compared; the square root is monotonic, so
        the winning class is the same. Ties go to the class that comes
        first in ``class_means``.

        Raises:
            ValueError: If called with samples before ``fit``.
        """
        X = np.asarray(X)
        if len(X) == 0:
            return np.array([])
        if not self.class_means:
            raise ValueError("模型尚未训练,请先调用fit")

        labels = np.array(list(self.class_means.keys()))
        squared = np.empty((X.shape[0], len(labels)))
        for column, (label, mean) in enumerate(self.class_means.items()):
            diff = X - mean
            inv_cov = self.class_inv_covariances[label]
            # Row-wise quadratic form diff_i^T * inv_cov * diff_i
            squared[:, column] = np.einsum('ij,ij->i', diff @ inv_cov, diff)

        return labels[squared.argmin(axis=1)]
|
||
|
||
|
||
class SpectralAngleMapper(BaseClassifier):
    """Spectral Angle Mapper: pick the class whose mean spectrum forms the
    smallest angle with the sample."""

    def __init__(self):
        super().__init__('Spectral Angle Mapper')
        self.class_spectra = {}
        self.class_labels = []

    def fit(self, X, y):
        """Compute the mean (reference) spectrum of every class.

        Reference spectra from a previous ``fit`` are discarded.

        Returns:
            ``self``.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        self.class_labels = np.unique(y)
        self.class_spectra = {
            label: np.mean(X[y == label], axis=0) for label in self.class_labels
        }
        return self

    def spectral_angle(self, spectrum1, spectrum2):
        """Return the angle in radians between two spectra.

        A zero-length spectrum yields pi/2 (90 degrees).
        """
        dot_product = np.dot(spectrum1, spectrum2)
        norm1 = np.linalg.norm(spectrum1)
        norm2 = np.linalg.norm(spectrum2)

        if norm1 == 0 or norm2 == 0:
            return np.pi / 2

        # Clip guards arccos against round-off outside [-1, 1]
        cos_angle = np.clip(dot_product / (norm1 * norm2), -1.0, 1.0)
        return np.arccos(cos_angle)

    def predict(self, X):
        """Assign every sample to the class with the smallest spectral angle.

        Ties go to the class that comes first in ``class_spectra``.

        Raises:
            ValueError: If called with samples before ``fit``.
        """
        X = np.asarray(X, dtype=float)
        if len(X) == 0:
            return np.array([])
        if not self.class_spectra:
            raise ValueError("模型尚未训练,请先调用fit")

        labels = np.array(list(self.class_spectra.keys()))
        references = np.vstack(list(self.class_spectra.values())).astype(float)

        # All sample/reference pairs at once: shape (n_samples, n_classes)
        dots = X @ references.T
        norm_products = (
            np.linalg.norm(X, axis=1)[:, None]
            * np.linalg.norm(references, axis=1)[None, :]
        )

        # Pairs involving a zero-length spectrum keep cos = 0, i.e. an angle
        # of pi/2, matching spectral_angle().
        cosines = np.zeros_like(dots)
        np.divide(dots, norm_products, out=cosines, where=norm_products > 0)
        angles = np.arccos(np.clip(cosines, -1.0, 1.0))

        return labels[angles.argmin(axis=1)]
|
||
|
||
|
||
class SpectralUnmixClassifier(BaseClassifier):
    """Classifier based on linear spectral unmixing.

    Endmembers are the K-means cluster centres of the training spectra. A
    sample is unmixed with non-negative least squares and labelled through
    its dominant endmember.
    """

    def __init__(self, n_endmembers=10):
        super().__init__('Spectral Unmix')
        self.n_endmembers = n_endmembers
        self.endmembers = None
        self.kmeans = None
        # class label -> endmember most of its training samples fall into
        self.class_endmember_map = {}
        # endmember -> majority class of the training samples assigned to it
        self.endmember_class_map = {}

    def get_params(self, deep=True):
        """Return the parameter dictionary."""
        return {'n_endmembers': self.n_endmembers}

    def set_params(self, **params):
        """Set ``n_endmembers`` when given; other keys are ignored."""
        if 'n_endmembers' in params:
            self.n_endmembers = params['n_endmembers']
        return self

    def fit(self, X, y):
        """Extract endmembers with K-means and relate them to the classes.

        The number of clusters is capped at the number of training samples,
        because K-means cannot form more clusters than there are samples.

        Returns:
            ``self``.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        unique_labels = np.unique(y)

        n_clusters = self.n_endmembers
        if len(X) < n_clusters:
            n_clusters = len(X)
            print(f"警告: 训练样本数({len(X)})少于端元数({self.n_endmembers}),端元数已调整为 {n_clusters}")

        self.kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        self.kmeans.fit(X)
        self.endmembers = self.kmeans.cluster_centers_

        # Endmember index of every training sample
        assignments = self.kmeans.predict(X)

        # Class -> its most frequent endmember
        self.class_endmember_map = {}
        for label in unique_labels:
            member_ids = assignments[y == label]
            if len(member_ids) > 0:
                self.class_endmember_map[label] = np.argmax(np.bincount(member_ids))

        # Endmember -> majority class; used when no class claims an
        # endmember as its most frequent one.
        self.endmember_class_map = {}
        for endmember in np.unique(assignments):
            values, counts = np.unique(y[assignments == endmember], return_counts=True)
            self.endmember_class_map[int(endmember)] = values[np.argmax(counts)]

        return self

    def spectral_unmix(self, spectrum):
        """Return the abundances of ``spectrum`` over the endmembers.

        Solves ``spectrum = endmembers.T @ abundances`` with
        ``abundances >= 0`` (NNLS), then rescales the result to sum to 1
        when the sum is positive.
        """
        abundances, _ = nnls(self.endmembers.T, spectrum)

        total = np.sum(abundances)
        if total > 0:
            abundances = abundances / total

        return abundances

    def predict(self, X):
        """Label every sample through its dominant endmember.

        Lookup order for the dominant endmember:

        1. the first class whose most frequent endmember it is;
        2. otherwise the majority class of the training samples assigned
           to that endmember;
        3. otherwise the first known class.
        """
        # Invert class -> endmember once; the first class wins on clashes
        primary = {}
        for label, endmember in self.class_endmember_map.items():
            primary.setdefault(int(endmember), label)
        # getattr: objects created before this attribute existed lack it
        majority = getattr(self, 'endmember_class_map', {})
        last_resort = next(iter(self.class_endmember_map), None)

        predictions = []
        for sample in X:
            dominant = int(np.argmax(self.spectral_unmix(sample)))
            if dominant in primary:
                predicted_label = primary[dominant]
            elif dominant in majority:
                predicted_label = majority[dominant]
            else:
                predicted_label = last_resort
            predictions.append(predicted_label)

        return np.array(predictions)
|
||
|
||
|
||
class PLSDAClassifier(BaseClassifier):
    """Partial least squares discriminant analysis (PLS-DA) classifier."""

    def __init__(self, n_components=2):
        super().__init__('PLS-DA')
        self.n_components = n_components
        self.pls = None
        self.class_labels = []

    def get_params(self, deep=True):
        """Return the parameter dictionary."""
        return dict(n_components=self.n_components)

    def set_params(self, **params):
        """Set ``n_components`` when given; other keys are ignored."""
        for key in params:
            if key == 'n_components':
                self.n_components = params[key]
        return self

    def fit(self, X, y):
        """Fit a PLS regression against a one-hot encoding of ``y``."""
        classes = np.unique(y)
        self.class_labels = classes

        # One indicator column per class, in the order of ``classes``
        indicator = np.zeros((len(y), len(classes)))
        for column in range(len(classes)):
            indicator[y == classes[column], column] = 1

        self.pls = PLSRegression(n_components=self.n_components)
        self.pls.fit(X, indicator)

    def predict(self, X):
        """Return, for every sample, the class with the highest PLS score."""
        scores = self.pls.predict(X)
        winners = np.argmax(scores, axis=1)

        # Column index -> original label
        return np.array([self.class_labels[index] for index in winners])
|
||
|
||
|
||
class HyperspectralClassifier:
|
||
"""高光谱图像分类器"""
|
||
|
||
def __init__(self):
    """Create empty pipeline state and the registry of classifiers."""
    # Image cube and georeferencing
    self.data = None
    self.geotransform = None
    self.projection = None
    self.roi_data = {}
    self.scaler = StandardScaler()
    self.label_encoder = LabelEncoder()
    self.model = None
    self.class_names = []

    # Registry of every available classifier
    self.classifiers = dict(
        euclidean=EuclideanDistanceClassifier(),
        mahalanobis=MahalanobisDistanceClassifier(),
        linear_discriminant=LinearDiscriminantAnalysis(),
        quadratic_discriminant=QuadraticDiscriminantAnalysis(),
        plsda=PLSDAClassifier(),
        spectral_angle_mapper=SpectralAngleMapper(),
        spectral_unmix=SpectralUnmixClassifier(),
        random_forest=RandomForestClassifier(random_state=42),
        svm=SVC(random_state=42, probability=True),
        logistic_regression=LogisticRegression(random_state=42, max_iter=1000),
        knn=KNeighborsClassifier(),
        adaboost=AdaBoostClassifier(random_state=42),
    )

    # Optional gradient boosting backends. The factories are only called
    # when the matching package was importable.
    optional_backends = (
        (XGBOOST_AVAILABLE, 'xgboost',
         lambda: XGBClassifier(random_state=42, use_label_encoder=False, eval_metric='logloss')),
        (LIGHTGBM_AVAILABLE, 'lightgbm',
         lambda: LGBMClassifier(random_state=42, verbose=-1)),
        (CATBOOST_AVAILABLE, 'catboost',
         lambda: CatBoostClassifier(random_state=42, verbose=False)),
    )
    for available, key, build in optional_backends:
        if available:
            self.classifiers[key] = build()
|
||
|
||
def read_hyperspectral_data(self, file_path):
    """Load a hyperspectral image with GDAL.

    Fills ``self.data`` with a float32 array of shape (rows, cols, bands)
    and stores the dataset's geotransform and projection.

    Args:
        file_path (str): Path of the hyperspectral file.

    Raises:
        ValueError: If GDAL cannot open the file.
    """
    print(f"正在读取高光谱文件: {file_path}")

    dataset = gdal.Open(file_path)
    if dataset is None:
        raise ValueError(f"无法打开文件: {file_path}")

    n_cols = dataset.RasterXSize
    n_rows = dataset.RasterYSize
    n_bands = dataset.RasterCount
    print(f"图像尺寸: {n_cols} x {n_rows} x {n_bands}")

    # Allocate the full cube up front and fill it one band at a time
    cube = np.zeros((n_rows, n_cols, n_bands), dtype=np.float32)
    self.data = cube

    # GDAL band numbers are 1-based
    for band_number in tqdm(range(1, n_bands + 1), desc="读取波段"):
        layer = dataset.GetRasterBand(band_number).ReadAsArray()
        cube[:, :, band_number - 1] = layer.astype(np.float32)

    # Georeferencing
    self.geotransform = dataset.GetGeoTransform()
    self.projection = dataset.GetProjection()

    # Drop the local reference to the dataset
    dataset = None
    print("高光谱数据读取完成")
|
||
|
||
def parse_roi_xml(self, xml_path):
    """Parse an ROI XML file into ``self.roi_data`` and ``self.class_names``.

    Only ``Region`` elements directly below the root are read. The polygon
    is taken from the first
    ``GeometryDef/Polygon/Exterior/LinearRing/Coordinates`` element, whose
    text is a whitespace separated list ``x1 y1 x2 y2 ...``.

    Regions without that element are skipped silently. Regions whose
    coordinate text is empty, non-numeric or has an odd number of values
    are reported and skipped instead of aborting the whole parse. When a
    region name repeats, the later definition replaces the earlier one and
    the name is listed only once.

    Args:
        xml_path (str): Path of the XML file.
    """
    print(f"正在解析ROI文件: {xml_path}")

    root = ET.parse(xml_path).getroot()

    self.roi_data = {}
    self.class_names = []

    for region in root.findall('Region'):
        name = region.get('name')
        color = region.get('color')

        coordinates_elem = region.find(
            'GeometryDef/Polygon/Exterior/LinearRing/Coordinates'
        )
        if coordinates_elem is None:
            continue

        # The element text is None for an empty element
        tokens = (coordinates_elem.text or '').split()
        try:
            values = [float(token) for token in tokens]
        except ValueError:
            values = []

        if not values or len(values) % 2 != 0:
            print(f"警告: 区域 '{name}' 没有有效的坐标数据")
            continue

        # Pair the flat list up as (x, y) tuples
        polygon_coords = list(zip(values[0::2], values[1::2]))

        if name not in self.roi_data:
            self.class_names.append(name)
        self.roi_data[name] = {
            'color': color,
            'coordinates': polygon_coords,
        }

    print(f"解析完成,共找到 {len(self.roi_data)} 个ROI区域")
    print(f"类别名称: {self.class_names}")
|
||
|
||
def _point_in_polygon(self, x, y, polygon):
    """
    Test whether a point lies inside a polygon (ray-casting method).

    Args:
        x, y: Point coordinates.
        polygon: Sequence of vertices ``[(x1, y1), (x2, y2), ...]``.

    Returns:
        bool: True when the point is inside the polygon. An empty polygon
        contains no point.
    """
    n = len(polygon)
    if n == 0:
        return False

    inside = False
    p1x, p1y = polygon[0]
    for i in range(1, n + 1):
        p2x, p2y = polygon[i % n]
        # The edge must straddle the horizontal line through y. This also
        # guarantees p1y != p2y, so the division below is safe.
        if min(p1y, p2y) < y <= max(p1y, p2y) and x <= max(p1x, p2x):
            xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
            if p1x == p2x or x <= xinters:
                inside = not inside
        p1x, p1y = p2x, p2y

    return inside
|
||
|
||
def extract_training_samples(self):
    """
    Collect labelled training spectra from the parsed ROI polygons.

    For each entry of ``self.roi_data`` the polygon's bounding box is
    clipped to the image, every pixel in the box is tested with
    ``self._point_in_polygon`` and the spectrum ``self.data[row, col, :]``
    is kept unless all of its values are NaN. ROI coordinates are treated
    as pixel coordinates (x = column, y = row).

    Returns:
        tuple: ``(X_scaled, y_encoded)`` - the output of
        ``self.scaler.fit_transform`` on the spectra and of
        ``self.label_encoder.fit_transform`` on the class names.

    Raises:
        ValueError: If no pixel was collected from any ROI.
    """
    print("正在提取训练样本...")

    X_samples = []
    y_samples = []

    rows, cols, _ = self.data.shape

    for class_name, roi_info in self.roi_data.items():
        print(f"处理类别: {class_name}")
        coords = roi_info['coordinates']

        xs = [x for x, _ in coords]
        ys = [y for _, y in coords]

        # Bounding box of the ROI, clipped to the image extent.
        min_col = max(0, int(min(xs)))
        max_col = min(cols, int(max(xs)) + 1)
        min_row = max(0, int(min(ys)))
        max_row = min(rows, int(max(ys)) + 1)

        class_samples = []
        for row in range(min_row, max_row):
            for col in range(min_col, max_col):
                if not self._point_in_polygon(col, row, coords):
                    continue
                spectrum = self.data[row, col, :]
                # Skip pixels that carry no data at all.
                if not np.all(np.isnan(spectrum)):
                    class_samples.append(spectrum)

        print(f"类别 {class_name} 提取到 {len(class_samples)} 个样本")
        X_samples.extend(class_samples)
        y_samples.extend([class_name] * len(class_samples))

    if not X_samples:
        # Fail here with a clear message rather than inside the scaler.
        raise ValueError("未提取到任何训练样本,请检查ROI坐标是否位于图像范围内")

    X = np.array(X_samples)
    y = np.array(y_samples)

    print(f"总训练样本数: {len(X)}")

    y_encoded = self.label_encoder.fit_transform(y)
    X_scaled = self.scaler.fit_transform(X)

    return X_scaled, y_encoded
|
||
|
||
def train_model(self, X, y, model_type='svm', use_grid_search=True, **kwargs):
    """
    Train a classifier, report validation metrics and save it to disk.

    The data is split with ``train_test_split`` (stratified, fixed seed,
    ``self.config.test_size``). The estimator registered under
    ``model_type`` in ``self.classifiers`` is fitted either directly or
    wrapped in a 3-fold ``GridSearchCV``. The fitted model, scaler and
    label encoder are dumped to ``trained_model_<model_type>.pkl`` in the
    current working directory.

    Args:
        X: Training features.
        y: Encoded training labels.
        model_type: Key into ``self.classifiers``.
        use_grid_search: Whether to tune hyper-parameters by grid search.
        **kwargs: Estimator parameters. They are set on the estimator
            itself, and parameters given here are excluded from the search
            grid so the caller's value is not overridden.

    Returns:
        The fitted model (a ``GridSearchCV`` when a search was run).

    Raises:
        ValueError: If ``model_type`` is not a registered classifier.
    """
    print(f"正在训练{model_type.upper()}模型...")

    if model_type not in self.classifiers:
        raise ValueError(f"不支持的模型类型: {model_type}。可用类型: {list(self.classifiers.keys())}")

    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=self.config.test_size, random_state=42, stratify=y
    )

    estimator = self.classifiers[model_type]

    # Custom parameters belong to the estimator, not to a GridSearchCV
    # wrapper (which would reject them as unknown parameters).
    if kwargs:
        estimator.set_params(**kwargs)

    search_space = {}
    if use_grid_search:
        param_grids = {
            'svm': {
                'C': [0.1, 1, 10, 100],
                'gamma': ['scale', 'auto', 0.001, 0.01, 0.1, 1],
                'kernel': ['rbf', 'linear']
            },
            'random_forest': {
                'n_estimators': [50, 100, 200],
                'max_depth': [None, 10, 20, 30],
                'min_samples_split': [2, 5, 10],
                'min_samples_leaf': [1, 2, 4]
            },
            'logistic_regression': {
                'C': [0.001, 0.01, 0.1, 1, 10, 100],
                'penalty': ['l1', 'l2'],
                'solver': ['liblinear', 'saga']
            },
            'knn': {
                'n_neighbors': [3, 5, 7, 9, 11],
                'weights': ['uniform', 'distance'],
                'metric': ['euclidean', 'manhattan', 'minkowski']
            },
            'linear_discriminant': {
                'solver': ['svd', 'lsqr', 'eigen'],
                'shrinkage': [None, 'auto', 0.1, 0.5, 0.9]
            },
            'quadratic_discriminant': {
                'reg_param': [0.0, 0.1, 0.5, 1.0],
                'store_covariance': [True, False]
            },
            'plsda': {
                'n_components': [2, 3, 5, 10],
            },
            'spectral_unmix': {
                'n_endmembers': [5, 10, 15, 20]
            },
            'adaboost': {
                'n_estimators': [50, 100, 200],
                'learning_rate': [0.01, 0.1, 1.0]
            },
            'xgboost': {
                'n_estimators': [100, 200, 300],
                'max_depth': [3, 6, 9],
                'learning_rate': [0.01, 0.1, 0.2],
                'subsample': [0.8, 1.0]
            },
            'lightgbm': {
                'n_estimators': [100, 200, 300],
                'max_depth': [3, 6, 9, -1],
                'learning_rate': [0.01, 0.1, 0.2],
                'num_leaves': [31, 50, 100]
            },
            'catboost': {
                'iterations': [100, 200, 300],
                'depth': [4, 6, 8, 10],
                'learning_rate': [0.01, 0.1, 0.2],
                'l2_leaf_reg': [1, 3, 5, 7]
            }
        }
        # Parameters pinned through kwargs are not searched.
        search_space = {
            key: values
            for key, values in param_grids.get(model_type, {}).items()
            if key not in kwargs
        }

    if search_space:
        self.model = GridSearchCV(
            estimator, search_space, cv=3, scoring='accuracy',
            verbose=1, n_jobs=-1
        )
    else:
        # No grid for this classifier, or grid search disabled.
        self.model = estimator

    self.model.fit(X_train, y_train)

    if use_grid_search and hasattr(self.model, 'best_params_'):
        print(f"最佳参数: {self.model.best_params_}")

    y_pred = self.model.predict(X_val)
    print("验证集评估结果:")
    class_names = self.label_encoder.classes_
    # Explicit labels keep the report aligned with target_names even when
    # a class is absent from the validation split.
    print(classification_report(
        y_val, y_pred,
        labels=np.arange(len(class_names)),
        target_names=class_names
    ))

    model_filename = f'trained_model_{model_type}.pkl'
    joblib.dump({
        'model': self.model,
        'scaler': self.scaler,
        'label_encoder': self.label_encoder,
        'model_type': model_type
    }, model_filename)

    # Confusion-matrix figure and metrics CSV.
    self._generate_evaluation_outputs(y_val, y_pred, model_type)

    print(f"模型训练完成并保存为: {model_filename}")
    return self.model
|
||
|
||
def _evaluate_and_visualize_model(self, X_scaled, y, model_type):
    """
    Evaluate ``self.model`` on a held-out split and write the evaluation files.

    The split uses the same settings as training (``self.config.test_size``,
    ``random_state=42``, stratified), the validation part is predicted with
    ``self.model`` and the result is passed to
    ``self._generate_evaluation_outputs``. Any exception is reported on
    stdout and not re-raised.

    Args:
        X_scaled: Standardised features.
        y: Labels.
        model_type: Model type, used in the output file names.
    """
    try:
        _, X_val, _, y_val = train_test_split(
            X_scaled, y, test_size=self.config.test_size, random_state=42, stratify=y
        )

        y_pred = self.model.predict(X_val)

        self._generate_evaluation_outputs(y_val, y_pred, model_type)

    except Exception as e:
        print(f"模型评估时出错: {e}")
|
||
|
||
def _generate_evaluation_outputs(self, y_true, y_pred, model_type):
    """
    Produce the confusion-matrix figure and the metrics CSV.

    The two outputs are generated independently: a failure in one (for
    example a plotting backend problem) is reported on stdout and does not
    prevent the other from being written. Exceptions are never re-raised.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        model_type: Model type, used in the output file names.
    """
    for step in (self._plot_confusion_matrix, self._save_metrics_to_csv):
        try:
            step(y_true, y_pred, model_type)
        except Exception as e:
            print(f"生成评价输出时出错: {e}")
|
||
|
||
def _plot_confusion_matrix(self, y_true, y_pred, model_type):
    """
    Draw the confusion matrix as a heat map and save it as a PNG.

    The figure is written to ``confusion_matrix_<model_type>.png`` in the
    current working directory.

    Args:
        y_true: True (encoded) labels.
        y_pred: Predicted (encoded) labels.
        model_type: Model type, used in the title and file name.
    """
    class_names = self.label_encoder.classes_

    # Explicit labels give one row/column per known class, so the matrix
    # always matches the tick labels even if a class is missing from
    # y_true / y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

    plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names)

        plt.title(f'Confusion Matrix - {model_type.upper()}', fontsize=16, fontweight='bold')
        plt.ylabel('True Label', fontsize=12)
        plt.xlabel('Predicted Label', fontsize=12)

        filename = f'confusion_matrix_{model_type}.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close()

    print(f"混淆矩阵图已保存: {filename}")
|
||
|
||
def _save_metrics_to_csv(self, y_true, y_pred, model_type):
    """
    Write per-class and aggregate evaluation metrics to a CSV file.

    One row is written per class, followed by ``macro_avg``,
    ``weighted_avg`` and ``overall_accuracy`` rows. In the
    ``overall_accuracy`` row the precision/recall/f1 columns all hold the
    accuracy and ``support`` holds the number of samples. The file is
    ``model_metrics_<model_type>.csv`` in the current working directory.

    Args:
        y_true: True (encoded) labels.
        y_pred: Predicted (encoded) labels.
        model_type: Model type, used in the rows and the file name.
    """
    import pandas as pd

    class_names = list(self.label_encoder.classes_)
    # Explicit labels keep every per-class array the same length as
    # class_names even when a class does not occur in y_true / y_pred.
    labels = np.arange(len(class_names))

    accuracy = accuracy_score(y_true, y_pred)
    report = classification_report(
        y_true, y_pred, labels=labels, target_names=class_names, output_dict=True
    )
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, average=None
    )

    macro = report['macro avg']
    weighted = report['weighted avg']
    summary_rows = ['macro_avg', 'weighted_avg', 'overall_accuracy']
    n_rows = len(class_names) + len(summary_rows)

    # The per-class metrics are NumPy arrays, which have no ``extend``;
    # convert them to lists before appending the summary values.
    metrics_data = {
        'model_type': [model_type] * n_rows,
        'class': class_names + summary_rows,
        'precision': precision.tolist() + [macro['precision'], weighted['precision'], accuracy],
        'recall': recall.tolist() + [macro['recall'], weighted['recall'], accuracy],
        'f1_score': f1.tolist() + [macro['f1-score'], weighted['f1-score'], accuracy],
        'support': support.tolist() + [macro['support'], weighted['support'], len(y_true)],
        'accuracy': [accuracy] * n_rows,
    }

    filename = f'model_metrics_{model_type}.csv'
    pd.DataFrame(metrics_data).to_csv(filename, index=False)

    print(f"模型评价指标已保存: {filename}")
|
||
|
||
|
||
def predict_image(self, output_path, batch_size=10000):
    """
    Classify every pixel of ``self.data`` and save the result.

    The image is flattened to ``(pixels, bands)`` and processed in batches.
    Each batch is scaled with ``self.scaler`` and predicted with
    ``self.model``, so no scaled copy of the whole image is ever held in
    memory. The class map is written through ``self._save_bil_format``.

    Args:
        output_path (str): Output file path.
        batch_size (int): Number of pixels per batch.

    Returns:
        numpy.ndarray: Predicted class map of shape ``(rows, cols)``.
    """
    print("正在对整个图像进行预测...")

    rows, cols, bands = self.data.shape
    X_image = self.data.reshape(-1, bands)

    batch_predictions = []
    for start in tqdm(range(0, X_image.shape[0], batch_size), desc="预测"):
        batch_scaled = self.scaler.transform(X_image[start:start + batch_size])
        batch_predictions.append(np.asarray(self.model.predict(batch_scaled)))

    if batch_predictions:
        flat_predictions = np.concatenate(batch_predictions)
    else:
        # Image without pixels: nothing to predict.
        flat_predictions = np.zeros(0, dtype=np.uint8)

    prediction_map = flat_predictions.reshape(rows, cols)

    self._save_bil_format(prediction_map, output_path)

    print(f"预测完成,结果保存至: {output_path}")
    return prediction_map
|
||
|
||
def _save_bil_format(self, prediction_map, output_path):
    """
    Save the class map as a single-band ENVI file with BIL interleave.

    Each pixel value is the class label cast to ``uint8``. The geotransform
    and projection stored on ``self`` are copied to the output when set,
    and ``self._create_hdr_file`` is called afterwards.

    Args:
        prediction_map: 2-D array of predicted labels.
        output_path: Output file path; its directory is created if missing.

    Raises:
        ValueError: If the ENVI driver is unavailable or the file cannot
            be created.
    """
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"创建输出目录: {output_dir}")

    rows, cols = prediction_map.shape

    driver = gdal.GetDriverByName('ENVI')
    output_file = None
    if driver is not None:
        output_file = driver.Create(
            output_path, cols, rows, 1,  # single band
            gdal.GDT_Byte, options=['INTERLEAVE=BIL']
        )

    if output_file is None:
        raise ValueError(f"无法创建输出文件: {output_path}")

    band = output_file.GetRasterBand(1)
    band.WriteArray(prediction_map.astype(np.uint8))
    band.SetDescription("Classification Result")

    if self.geotransform:
        output_file.SetGeoTransform(self.geotransform)
    if self.projection:
        output_file.SetProjection(self.projection)

    output_file.FlushCache()
    # Drop the band handle before the dataset that owns it.
    band = None
    output_file = None

    self._create_hdr_file(output_path)
|
||
|
||
def _create_hdr_file(self, bil_path):
    """
    Write an ENVI header next to the classification file.

    The header is written to ``bil_path + '.hdr'``. It describes a
    single-band byte image with the dimensions of ``self.data`` and lists
    the class names of ``self.label_encoder`` together with a generated
    RGB colour per class.

    Args:
        bil_path: Path of the BIL file the header belongs to.
    """
    hdr_path = bil_path + '.hdr'

    rows, cols = self.data.shape[:2]
    class_names = self.label_encoder.classes_
    last = len(class_names) - 1

    def _entries(items):
        # One entry per line, comma-separated, no comma after the last.
        return [
            f"{item}," + "\n" if i < last else f"{item}\n"
            for i, item in enumerate(items)
        ]

    name_entries = [f' "{class_name}"' for class_name in class_names]
    # Distinct colours from fixed per-channel strides (simple scheme).
    colour_entries = [
        f" {(i * 137) % 256}, {(i * 157) % 256}, {(i * 173) % 256}"
        for i in range(len(class_names))
    ]

    lines = [
        "ENVI\n",
        "description = {\n",
        " 高光谱分类结果 - 单波段格式\n",
        "}\n",
        f"samples = {cols}\n",
        f"lines = {rows}\n",
        "bands = 1\n",  # single band
        "header offset = 0\n",
        "file type = ENVI Standard\n",
        "data type = 1\n",
        "interleave = bil\n",
        "sensor type = Unknown\n",
        "byte order = 0\n",
        "band names = {\n",
        " \"Classification\"\n",
        "}\n",
        f"classes = {len(class_names)}\n",
        "class names = {\n",
        *_entries(name_entries),
        "}\n",
        "class lookup = {\n",
        *_entries(colour_entries),
        "}\n",
    ]

    # The description is not ASCII; do not rely on the platform default
    # encoding, which may be unable to represent it.
    with open(hdr_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)

    print(f"头文件创建完成: {hdr_path}")
|
||
|
||
def evaluate_model(self, X_test, y_test):
    """
    Evaluate the trained model on a test set.

    Args:
        X_test: Test features.
        y_test: Test (encoded) labels.

    Returns:
        dict: ``accuracy``, ``classification_report`` (as a dict),
        ``confusion_matrix`` and ``predictions``.

    Raises:
        ValueError: If no model has been trained yet.
    """
    if self.model is None:
        raise ValueError("模型未训练")

    y_pred = self.model.predict(X_test)

    class_names = self.label_encoder.classes_
    # Explicit labels keep report and matrix aligned with the known
    # classes even if the test set lacks some of them.
    labels = np.arange(len(class_names))

    return {
        'accuracy': accuracy_score(y_test, y_pred),
        'classification_report': classification_report(
            y_test, y_pred, labels=labels, target_names=class_names, output_dict=True
        ),
        'confusion_matrix': confusion_matrix(y_test, y_pred, labels=labels),
        'predictions': y_pred
    }
|
||
|
||
|
||
def parse_model_params(param_str):
    """
    Parse a model-parameter string into a dictionary.

    Three notations are accepted:

    * ``key=value,key=value`` - e.g. ``"C=1.0,kernel=rbf"``. Values may be
      booleans (``true``/``false``), ``None``, numbers, quoted strings,
      lists or tuples; anything else is kept as a plain string. Commas
      inside brackets or quotes do not split pairs.
    * JSON - e.g. ``'{"C": 1.0, "kernel": "rbf"}'``.
    * Python dict literal - e.g. ``"{'C': 1.0, 'kernel': 'rbf'}"``.

    Args:
        param_str (str): Parameter string.

    Returns:
        dict: Parsed parameters. An empty dict is returned for empty input
        and, after printing a warning, for input that cannot be parsed or
        does not describe a dictionary.
    """
    if not param_str or not param_str.strip():
        return {}

    text = param_str.strip()

    try:
        if text.startswith('{'):
            try:
                params = json.loads(text)
            except json.JSONDecodeError:
                # Not JSON (e.g. single quotes): try a Python dict literal.
                params = ast.literal_eval(text)
        elif '=' in text:
            params = {}
            for pair in _split_top_level_commas(text):
                if '=' not in pair:
                    continue
                key, value = pair.split('=', 1)
                params[key.strip()] = _coerce_param_value(value.strip())
        else:
            params = ast.literal_eval(text)

        if not isinstance(params, dict):
            raise ValueError("parameters must form a dictionary")
        return params

    except (json.JSONDecodeError, ValueError, SyntaxError) as e:
        print(f"警告: 参数解析失败 '{param_str}': {e}")
        print("使用默认参数")
        return {}


def _split_top_level_commas(text):
    """Split ``text`` on commas that are outside brackets and quotes."""
    parts = []
    current = []
    depth = 0
    quote = None
    for ch in text:
        if quote:
            if ch == quote:
                quote = None
        elif ch in ('"', "'"):
            quote = ch
        elif ch in '[({':
            depth += 1
        elif ch in '])}':
            depth = max(0, depth - 1)
        elif ch == ',' and depth == 0:
            parts.append(''.join(current))
            current = []
            continue
        current.append(ch)
    parts.append(''.join(current))
    return parts


def _coerce_param_value(value):
    """Convert one ``key=value`` value string to a Python object."""
    lowered = value.lower()
    if lowered in ('true', 'false'):
        return lowered == 'true'
    if lowered in ('none', 'null'):
        return None
    try:
        # Numbers (including 1e-3), quoted strings, lists, tuples, dicts.
        return ast.literal_eval(value)
    except (ValueError, SyntaxError):
        pass
    # Integers literal_eval rejects, such as ones with leading zeros.
    if value.lstrip('+-').isdigit():
        return int(value)
    return value
|
||
|
||
|
||
def show_model_params_help():
    """Print usage help for the model-parameter string and common parameters."""
    help_lines = (
        "\n模型参数使用说明:",
        "-" * 50,
        "格式1 (key=value): C=1.0,kernel=rbf,n_estimators=100",
        "格式2 (JSON): {\"C\": 1.0, \"kernel\": \"rbf\", \"n_estimators\": 100}",
        "格式3 (Python): {'C': 1.0, 'kernel': 'rbf', 'n_estimators': 100}",
        "",
        "常用模型参数:",
        "- SVM: C, kernel, gamma, degree",
        "- Random Forest: n_estimators, max_depth, min_samples_split, min_samples_leaf",
        "- Logistic Regression: C, penalty, solver, max_iter",
        "- KNN: n_neighbors, weights, metric, p",
        "- Linear Discriminant: solver, shrinkage, n_components",
        "- Quadratic Discriminant: reg_param, store_covariance",
        "- PLS-DA: n_components",
        "- Spectral Unmix: n_endmembers",
        "- Euclidean/Mahalanobis/Spectral Angle: 无参数",
        "",
    )
    for line in help_lines:
        print(line)
|
||
|
||
|
||
def list_supported_models():
    """
    Print and return the names of all supported classification models.

    The names are the keys of the ``classifiers`` mapping of a freshly
    constructed ``HyperspectralClassifier``.

    Returns:
        list: Model names in registration order.
    """
    models = list(HyperspectralClassifier().classifiers.keys())
    rule = "-" * 40

    print("支持的分类模型:")
    print(rule)
    for index, model in enumerate(models, 1):
        print(f"{index:2}. {model}")
    print(rule)
    print()

    return models
|
||
|
||
|
||
def main():
    """
    Command-line entry point of the hyperspectral classification tool.

    ``--list-models`` and ``--help-params`` print their information and
    exit without needing any other option. Otherwise ``--hyperspectral``,
    ``--roi`` and ``--output`` are mandatory; a missing one ends the
    program through ``parser.error``.
    """
    parser = argparse.ArgumentParser(description='高光谱图像分类工具')
    # These three are validated after the informational flags, so they
    # cannot be declared ``required=True`` here.
    parser.add_argument('--hyperspectral', default=None, help='高光谱文件路径')
    parser.add_argument('--roi', default=None, help='ROI XML文件路径')
    parser.add_argument('--model', default='svm', help='分类模型类型')
    parser.add_argument('--output', default=None, help='输出文件路径')
    parser.add_argument('--no-grid-search', action='store_true',
                        help='禁用自动超参数调优')
    parser.add_argument('--model-params', type=str, default='',
                        help='模型参数 (格式: key=value,key=value 或 JSON)')
    parser.add_argument('--list-models', action='store_true',
                        help='列出所有支持的模型并退出')
    parser.add_argument('--help-params', action='store_true',
                        help='显示模型参数使用帮助并退出')
    parser.add_argument('--batch-size', type=int, default=10000,
                        help='预测时的批处理大小 (默认: 10000)')
    parser.add_argument('--no-standardization', action='store_true',
                        help='禁用数据标准化')

    args = parser.parse_args()

    if args.list_models:
        list_supported_models()
        return

    if args.help_params:
        show_model_params_help()
        return

    missing = [
        flag
        for flag, value in (
            ('--hyperspectral', args.hyperspectral),
            ('--roi', args.roi),
            ('--output', args.output),
        )
        if not value
    ]
    if missing:
        parser.error(f"the following arguments are required: {', '.join(missing)}")

    model_params = parse_model_params(args.model_params)

    config = ClassificationConfig(
        hyperspectral_path=args.hyperspectral,
        roi_path=args.roi,
        output_path=args.output,
        model_type=args.model,
        use_grid_search=not args.no_grid_search,
        model_params=model_params,
        batch_size=args.batch_size,
        use_standardization=not args.no_standardization
    )

    classifier = ConfigurableHyperspectralClassifier(config)
    classifier.run_classification()
|
||
|
||
|
||
# 便捷函数,用于简化使用
|
||
def classify_hyperspectral_image(config: ClassificationConfig) -> None:
    """
    Run a classification described by a configuration object.

    Builds a ``ConfigurableHyperspectralClassifier`` from ``config`` and
    calls its ``run_classification`` method.

    Args:
        config: Classification configuration.
    """
    ConfigurableHyperspectralClassifier(config).run_classification()
|
||
|
||
|
||
def classify_with_params(hyperspectral_path: str, roi_path: str, output_path: str,
                         model_type: str = 'svm', use_grid_search: bool = True,
                         model_params: Optional[Dict[str, Any]] = None,
                         batch_size: int = 10000, use_standardization: bool = True) -> None:
    """
    Classify a hyperspectral image from plain keyword parameters.

    The arguments are packed into a ``ClassificationConfig`` and passed to
    ``classify_hyperspectral_image``.

    Args:
        hyperspectral_path: Path of the hyperspectral file.
        roi_path: Path of the ROI file.
        output_path: Path of the output file.
        model_type: Model type.
        use_grid_search: Whether to use grid search.
        model_params: Model parameters; ``None`` means no custom parameters.
        batch_size: Batch size used for prediction.
        use_standardization: Whether to standardise the data.
    """
    classify_hyperspectral_image(
        ClassificationConfig(
            hyperspectral_path=hyperspectral_path,
            roi_path=roi_path,
            output_path=output_path,
            model_type=model_type,
            use_grid_search=use_grid_search,
            model_params=model_params or {},
            batch_size=batch_size,
            use_standardization=use_standardization
        )
    )
|
||
|
||
|
||
def batch_classify_all_methods(hyperspectral=None, roi=None, output_dir=None):
    """
    Run every available classification method on one image and report.

    Each method is run through ``classify_with_params`` with grid search
    disabled, while the working directory is temporarily switched to
    ``output_dir`` so the confusion-matrix PNG and metrics CSV land there.
    A failing method is recorded and the batch continues. A summary is
    printed and written to ``batch_classification_report.txt``.

    Args:
        hyperspectral: Path of the hyperspectral file. Defaults to the
            built-in sample path.
        roi: Path of the ROI XML file. Defaults to the built-in sample path.
        output_dir: Directory for all outputs (created if missing).
            Defaults to the built-in sample output directory.
    """
    print("="*60)
    print("批量测试所有分类方法")
    print("="*60)

    if hyperspectral is None:
        hyperspectral = r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip"
    if roi is None:
        roi = r"E:\code\spectronon\single_classsfication\data\roi.xml"
    if output_dir is None:
        output_dir = r"E:\code\spectronon\single_classsfication\output"

    # The working directory changes below; relative paths would then
    # point somewhere else, so pin everything to absolute paths first.
    hyperspectral = os.path.abspath(hyperspectral)
    roi = os.path.abspath(roi)
    output_dir = os.path.abspath(output_dir)

    os.makedirs(output_dir, exist_ok=True)

    if not os.path.exists(hyperspectral):
        print(f"错误: 高光谱文件不存在: {hyperspectral}")
        return

    if not os.path.exists(roi):
        print(f"错误: ROI文件不存在: {roi}")
        return

    available_classifiers = [
        'euclidean', 'mahalanobis', 'linear_discriminant', 'quadratic_discriminant',
        'plsda', 'spectral_angle_mapper', 'spectral_unmix',
        'random_forest', 'svm', 'logistic_regression', 'knn', 'adaboost'
    ]

    # Gradient-boosting models are optional dependencies.
    if XGBOOST_AVAILABLE:
        available_classifiers.append('xgboost')
    if LIGHTGBM_AVAILABLE:
        available_classifiers.append('lightgbm')
    if CATBOOST_AVAILABLE:
        available_classifiers.append('catboost')

    results = {}
    timing_results = {}

    print(f"将测试 {len(available_classifiers)} 种分类方法:")
    for i, method in enumerate(available_classifiers, 1):
        print(f"{i:2d}. {method}")
    print()

    for i, method in enumerate(available_classifiers, 1):
        start_time = time.time()
        print("-" * 50)
        print(f"[{i}/{len(available_classifiers)}] 测试方法: {method.upper()}")
        print("-" * 50)

        output_path = os.path.join(output_dir, f'classification_{method}.bil')

        try:
            print(f"输入文件: {os.path.basename(hyperspectral)}")
            print(f"ROI文件: {os.path.basename(roi)}")
            print(f"输出文件: {os.path.basename(output_path)}")
            print(f"模型类型: {method}")

            # Evaluation files are written to the working directory, so
            # switch to the output directory for the duration of the run.
            original_cwd = os.getcwd()
            os.chdir(output_dir)
            try:
                classify_with_params(
                    hyperspectral_path=hyperspectral,
                    roi_path=roi,
                    output_path=output_path,
                    model_type=method,
                    use_grid_search=False,  # off for speed
                    model_params={},  # default parameters
                    batch_size=5000  # smaller batches to save memory
                )
            finally:
                os.chdir(original_cwd)

        except Exception as e:
            elapsed_time = time.time() - start_time
            timing_results[method] = elapsed_time

            error_msg = str(e)
            results[method] = f"失败: {error_msg}"
            print(f"❌ {method} 分类失败: {error_msg}")
            print(f"⏱️ 耗时: {elapsed_time:.2f}秒")
            # Keep going with the next method.
            continue

        elapsed_time = time.time() - start_time
        timing_results[method] = elapsed_time

        results[method] = "成功"
        print(f"✅ {method} 分类完成,结果保存至: {output_path}")
        print(f"⏱️ 耗时: {elapsed_time:.2f}秒")

        print()  # blank line between methods

    print("\n" + "="*60)
    print("批量分类测试完成报告")
    print("="*60)

    successful = sum(1 for result in results.values() if result == "成功")
    total = len(results)
    success_rate = successful / total * 100 if total else 0.0

    print(f"总方法数: {total}")
    print(f"成功方法数: {successful}")
    print(f"失败方法数: {total - successful}")
    print(f"成功率: {success_rate:.1f}%")

    print("\n详细结果:")
    for method, result in results.items():
        status = "✅" if result == "成功" else "❌"
        print(f"{status} {method}: {result}")

    try:
        report_path = os.path.join(output_dir, 'batch_classification_report.txt')
        _write_batch_report(report_path, hyperspectral, roi, output_dir,
                            results, timing_results)

        print(f"\n📄 详细报告已保存至: {report_path}")

        successful_methods = [m for m, r in results.items() if r == "成功"]
        if successful_methods:
            print("\n📁 生成的分类结果文件:")
            for method in successful_methods:
                bil_file = os.path.join(output_dir, f'classification_{method}.bil')
                # The PNG and CSV were written inside output_dir as well.
                png_file = os.path.join(output_dir, f'confusion_matrix_{method}.png')
                csv_file = os.path.join(output_dir, f'model_metrics_{method}.csv')

                if os.path.exists(bil_file):
                    print(f" • 分类结果: {bil_file}")
                if os.path.exists(png_file):
                    print(f" • 混淆矩阵: {png_file}")
                if os.path.exists(csv_file):
                    print(f" • 评价指标: {csv_file}")
                print()

    except Exception as e:
        print(f"保存报告失败: {e}")


def _write_batch_report(report_path, hyperspectral, roi, output_dir,
                        results, timing_results):
    """Write the text report of a batch run to ``report_path`` (UTF-8)."""
    successful_methods = [(m, r) for m, r in results.items() if r == "成功"]
    failed_methods = [(m, r) for m, r in results.items() if r != "成功"]
    total = len(results)
    successful = len(successful_methods)
    success_rate = successful / total * 100 if total else 0.0

    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("批量分类测试报告\n")
        f.write("="*60 + "\n")
        f.write(f"测试时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"输入文件: {hyperspectral}\n")
        f.write(f"ROI文件: {roi}\n")
        f.write(f"输出目录: {output_dir}\n\n")

        f.write(f"统计信息:\n")
        f.write(f"- 总方法数: {total}\n")
        f.write(f"- 成功方法数: {successful}\n")
        f.write(f"- 失败方法数: {total - successful}\n")
        f.write(f"- 成功率: {success_rate:.1f}%\n\n")

        f.write("详细结果:\n")
        f.write("-" * 40 + "\n")

        if successful_methods:
            f.write("✅ 成功的方法:\n")
            for method, result in successful_methods:
                elapsed = timing_results.get(method, 0)
                f.write(f" • {method}: {result}\n")
                f.write(f" 输出文件: classification_{method}.bil\n")
                f.write(f" 处理时间: {elapsed:.2f}秒\n")
            f.write("\n")

        # Timing statistics only make sense with at least one success;
        # min()/max() would raise on an empty list.
        successful_times = [timing_results[m] for m, _ in successful_methods
                            if m in timing_results]
        if successful_times:
            f.write("性能统计:\n")
            f.write(f" • 最快方法: {min(successful_times):.2f}秒\n")
            f.write(f" • 最慢方法: {max(successful_times):.2f}秒\n")
            f.write(f" • 平均时间: {sum(successful_times)/len(successful_times):.2f}秒\n")
            f.write("\n")

        if failed_methods:
            f.write("❌ 失败的方法:\n")
            for method, result in failed_methods:
                elapsed = timing_results.get(method, 0)
                f.write(f" • {method}: {result}\n")
                f.write(f" 处理时间: {elapsed:.2f}秒\n")
            f.write("\n")

        f.write("算法说明:\n")
        f.write("-" * 40 + "\n")
        f.writelines([
            "传统机器学习:\n",
            " • svm: 支持向量机\n",
            " • random_forest: 随机森林\n",
            " • logistic_regression: 逻辑回归\n",
            " • knn: K近邻\n",
            " • adaboost: AdaBoost\n",
            " • linear_discriminant: 线性判别分析\n",
            " • quadratic_discriminant: 二次判别分析\n\n",
            "高光谱专用算法:\n",
            " • euclidean: 欧几里得距离分类\n",
            " • mahalanobis: 马哈拉诺比斯距离分类\n",
            " • spectral_angle_mapper: 光谱角度映射\n",
            " • spectral_unmix: 光谱解混\n",
            " • plsda: 偏最小二乘判别分析\n\n",
            "梯度提升算法:\n",
            " • xgboost: XGBoost\n",
            " • lightgbm: LightGBM\n",
            " • catboost: CatBoost\n",
        ])
|
||
|
||
|
||
if __name__ == '__main__':
    # Run the batch test over all classification methods.
    batch_classify_all_methods()

    # Single-model usage example, kept for reference (not executed).
    # It is written as comments rather than a string literal because the
    # Windows paths contain backslash sequences that are invalid escapes
    # inside a normal string.
    #
    #     hyperspectral = r"D:\resonon\Intro_to_data_processing\GreenPaintChipsSmall.bip"
    #     roi = r"E:\code\spectronon\single_classsfication\data\roi.xml"
    #     model = 'svm'  # any other supported model works too
    #     output = r'E:\code\spectronon\single_classsfication\output\test.bil'
    #
    #     try:
    #         classify_with_params(
    #             hyperspectral_path=hyperspectral,
    #             roi_path=roi,
    #             output_path=output,
    #             model_type=model,
    #             use_grid_search=False,
    #             model_params={}
    #         )
    #     except Exception as e:
    #         print(f"错误: {str(e)}")
    #         sys.exit(1)