Files
HSI/registry.py

2318 lines
110 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
高光谱分析工具包统一注册表
提供所有功能模块的注册信息和参数构建器
"""
import importlib
from typing import Dict, Any, Callable, List, Optional, Union
from dataclasses import dataclass
import argparse
@dataclass
class ModuleInfo:
"""模块信息类"""
category: str
module_path: str
callable_name: str
config_class: Optional[str] = None
description: str = ""
requires_data_file: bool = True
requires_roi_file: bool = False
output_types: List[str] = None
def __post_init__(self):
if self.output_types is None:
self.output_types = ["file", "summary"]
class ArgsBuilder:
"""参数构建器基类"""
def add_common_args(self, subparser: argparse.ArgumentParser):
"""添加通用参数"""
subparser.add_argument('--input', '-i', required=True,
help='输入文件路径 (ENVI格式: .hdr/.dat 或 CSV)')
subparser.add_argument('--output-dir', '-o', default='./results',
help='输出目录路径 (默认: ./results)')
subparser.add_argument('--output-prefix', default='result',
help='输出文件前缀 (默认: result)')
subparser.add_argument('--output-file', '-f',
help='直接指定输出文件名 (优先级高于 --output-dir 和 --output-prefix)')
def add_module_specific_args(self, subparser: argparse.ArgumentParser):
"""子类实现:添加模块特定参数"""
pass
def build_config(self, args) -> Any:
"""子类实现:构建配置对象"""
pass
# 各模块的参数构建器
class DimensionalityReductionArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
# CSV文件参数
subparser.add_argument('--label-col', '-l', help='CSV文件的标签列名或索引')
subparser.add_argument('--spectral-start', '-s', help='CSV文件的谱段起始列名或索引')
subparser.add_argument('--spectral-end', '-e', help='CSV文件的谱段结束列名或索引')
subparser.add_argument('--method', '-m', default='pca',
choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'],
help='降维方法')
subparser.add_argument('--n-components', '-n', type=int, default=3,
help='降维后的维度数')
subparser.add_argument('--batch-methods', nargs='+',
choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'],
help='批量处理的降维方法列表')
subparser.add_argument('--batch-components', nargs='+', type=int,
help='对应批量方法的维度数列表')
# 数据预处理参数
subparser.add_argument('--use-standardization', action='store_true', default=True,
help='是否使用数据标准化')
subparser.add_argument('--generate-plots', action='store_true', default=True,
help='是否生成可视化图表')
subparser.add_argument('--color-by-label', action='store_true', default=True,
help='是否按标签着色散点图')
# PCA专用参数
subparser.add_argument('--pca-whiten', action='store_true',
help='是否白化数据 (PCA方法)')
subparser.add_argument('--pca-svd-solver', choices=['auto', 'full', 'arpack', 'randomized'], default='auto',
help='SVD求解器 (PCA方法)')
# ICA专用参数
subparser.add_argument('--ica-algorithm', choices=['parallel', 'deflation'], default='parallel',
help='ICA算法 (ICA方法)')
subparser.add_argument('--ica-fun', choices=['logcosh', 'exp', 'cube'], default='logcosh',
help='非线性函数 (ICA方法)')
subparser.add_argument('--ica-max-iter', type=int, default=200,
help='最大迭代次数 (ICA方法)')
# FA专用参数
subparser.add_argument('--fa-max-iter', type=int, default=1000,
help='最大迭代次数 (FA方法)')
subparser.add_argument('--fa-tol', type=float, default=1e-2,
help='收敛容差 (FA方法)')
# MDS专用参数
subparser.add_argument('--mds-metric', action='store_true', default=True,
help='是否使用度量MDS (MDS方法)')
subparser.add_argument('--mds-max-iter', type=int, default=300,
help='最大迭代次数 (MDS方法)')
subparser.add_argument('--mds-n-init', type=int, default=4,
help='初始化次数 (MDS方法)')
# Isomap专用参数
subparser.add_argument('--isomap-n-neighbors', type=int, default=5,
help='邻居数量 (Isomap方法)')
subparser.add_argument('--isomap-eigen-solver', choices=['auto', 'arpack', 'dense'], default='auto',
help='特征值求解器 (Isomap方法)')
# LLE专用参数
subparser.add_argument('--lle-n-neighbors', type=int, default=5,
help='邻居数量 (LLE方法)')
subparser.add_argument('--lle-method', choices=['standard', 'hessian', 'modified', 'ltsa'], default='standard',
help='LLE方法类型 (LLE方法)')
# t-SNE专用参数
subparser.add_argument('--tsne-perplexity', type=float, default=30.0,
help='困惑度 (t-SNE方法)')
subparser.add_argument('--tsne-early-exaggeration', type=float, default=12.0,
help='早期夸张系数 (t-SNE方法)')
subparser.add_argument('--tsne-learning-rate', type=float, default=200.0,
help='学习率 (t-SNE方法)')
subparser.add_argument('--tsne-max-iter', type=int, default=1000,
help='最大迭代次数 (t-SNE方法)')
def build_config(self, args):
from Dimensionality_Reduction_method.dimensionality_reduction import DimensionalityReductionConfig
# 检查是否是批量处理
is_batch = hasattr(args, 'batch_methods') and args.batch_methods is not None
# 构建方法参数
method_params = {}
if not is_batch and hasattr(args, 'method'):
if args.method == 'pca':
if hasattr(args, 'pca_whiten'):
method_params['whiten'] = args.pca_whiten
if hasattr(args, 'pca_svd_solver'):
method_params['svd_solver'] = args.pca_svd_solver
elif args.method == 'ica':
if hasattr(args, 'ica_algorithm'):
method_params['algorithm'] = args.ica_algorithm
if hasattr(args, 'ica_fun'):
method_params['fun'] = args.ica_fun
if hasattr(args, 'ica_max_iter'):
method_params['max_iter'] = args.ica_max_iter
elif args.method == 'fa':
if hasattr(args, 'fa_max_iter'):
method_params['max_iter'] = args.fa_max_iter
if hasattr(args, 'fa_tol'):
method_params['tol'] = args.fa_tol
elif args.method == 'mds':
if hasattr(args, 'mds_metric'):
method_params['metric'] = args.mds_metric
if hasattr(args, 'mds_max_iter'):
method_params['max_iter'] = args.mds_max_iter
if hasattr(args, 'mds_n_init'):
method_params['n_init'] = args.mds_n_init
elif args.method == 'isomap':
if hasattr(args, 'isomap_n_neighbors'):
method_params['n_neighbors'] = args.isomap_n_neighbors
if hasattr(args, 'isomap_eigen_solver'):
method_params['eigen_solver'] = args.isomap_eigen_solver
elif args.method == 'lle':
if hasattr(args, 'lle_n_neighbors'):
method_params['n_neighbors'] = args.lle_n_neighbors
if hasattr(args, 'lle_method'):
method_params['method'] = args.lle_method
elif args.method == 't-sne':
if hasattr(args, 'tsne_perplexity'):
method_params['perplexity'] = args.tsne_perplexity
if hasattr(args, 'tsne_early_exaggeration'):
method_params['early_exaggeration'] = args.tsne_early_exaggeration
if hasattr(args, 'tsne_learning_rate'):
method_params['learning_rate'] = args.tsne_learning_rate
if hasattr(args, 'tsne_max_iter'):
method_params['max_iter'] = args.tsne_max_iter
# 构建配置参数
config_kwargs = {
'input_path': args.input,
'output_dir': args.output_dir,
'label_col': getattr(args, 'label_col', None),
'spectral_start': getattr(args, 'spectral_start', None),
'spectral_end': getattr(args, 'spectral_end', None),
'method_params': method_params,
'use_standardization': getattr(args, 'use_standardization', True),
'generate_plots': getattr(args, 'generate_plots', True),
'color_by_label': getattr(args, 'color_by_label', True),
'save_plots': getattr(args, 'save_plots', True)
}
if is_batch:
# 批量处理
if not hasattr(args, 'batch_components') or not args.batch_components:
batch_components = [3] * len(args.batch_methods)
else:
batch_components = args.batch_components
if len(batch_components) != len(args.batch_methods):
raise ValueError(f"批量方法数量 ({len(args.batch_methods)}) 与维度数量 ({len(batch_components)}) 不匹配")
config_kwargs.update({
'batch_methods': args.batch_methods,
'batch_components': batch_components
})
else:
# 单方法处理
config_kwargs.update({
'method': args.method,
'n_components': args.n_components
})
return DimensionalityReductionConfig(**config_kwargs)
class SegmentationArgsBuilder(ArgsBuilder):
def add_common_args(self, subparser):
"""重写通用参数,设置分割专用的默认输出目录"""
subparser.add_argument('--input', '-i', required=True,
help='输入文件路径 (ENVI格式: .hdr/.dat 或图像文件)')
subparser.add_argument('--output-dir', '-o', default='./segmentation_results',
help='输出目录路径 (默认: ./segmentation_results)')
subparser.add_argument('--output-prefix', default='result',
help='输出文件前缀 (默认: result)')
def add_module_specific_args(self, subparser):
# 基本分割参数
subparser.add_argument('--method', '-m', default='otsu',
choices=['fixed', 'bimodal', 'iterative', 'otsu', 'isodata', 'adaptive'],
help='分割方法 (默认: otsu)')
subparser.add_argument('--band-index', '-b', type=int, default=0,
help='使用的波段索引 (默认: 0)')
# 阈值参数
subparser.add_argument('--threshold', '-t', type=float,
help='固定阈值 (仅用于fixed方法)')
# 批量处理参数
subparser.add_argument('--batch-methods', nargs='+',
choices=['fixed', 'bimodal', 'iterative', 'otsu', 'isodata', 'adaptive'],
help='批量处理的分割方法列表')
subparser.add_argument('--batch-bands', type=int, nargs='+',
help='批量处理的波段索引列表')
# 自适应阈值参数
subparser.add_argument('--adaptive-block-size', type=int, default=11,
help='自适应阈值块大小,必须为奇数 (默认: 11)')
subparser.add_argument('--adaptive-c', type=float, default=2.0,
help='自适应阈值常数 (默认: 2.0)')
subparser.add_argument('--adaptive-method', choices=['gaussian', 'mean'], default='gaussian',
help='自适应阈值方法 (默认: gaussian)')
# 迭代法参数
subparser.add_argument('--max-iterations', type=int, default=100,
help='迭代法最大迭代次数 (默认: 100)')
subparser.add_argument('--convergence-threshold', type=float, default=0.01,
help='迭代法收敛阈值 (默认: 0.01)')
# 直方图双峰法参数
subparser.add_argument('--histogram-bins', type=int, default=256,
help='直方图bin数量 (默认: 256)')
subparser.add_argument('--peak-min-distance', type=int, default=10,
help='峰值最小距离 (默认: 10)')
# Otsu参数
subparser.add_argument('--otsu-normalize', action='store_true', default=True,
help='Otsu方法是否归一化图像到0-255范围 (默认: True)')
def build_config(self, args):
from segment_method.threshold_Segment import ThresholdSegmentationConfig
# 处理批量参数
batch_methods = getattr(args, 'batch_methods', None)
batch_bands = getattr(args, 'batch_bands', None)
if batch_methods is None:
batch_methods = [args.method]
if batch_bands is None:
batch_bands = [args.band_index]
return ThresholdSegmentationConfig(
input_path=args.input,
output_dir=args.output_dir,
method=args.method,
band_index=args.band_index,
threshold_value=getattr(args, 'threshold', None),
batch_methods=batch_methods,
batch_bands=batch_bands,
adaptive_block_size=getattr(args, 'adaptive_block_size', 11),
adaptive_c=getattr(args, 'adaptive_c', 2.0),
adaptive_method=getattr(args, 'adaptive_method', 'gaussian'),
max_iterations=getattr(args, 'max_iterations', 100),
convergence_threshold=getattr(args, 'convergence_threshold', 0.01),
histogram_bins=getattr(args, 'histogram_bins', 256),
peak_min_distance=getattr(args, 'peak_min_distance', 10),
otsu_normalize=getattr(args, 'otsu_normalize', True)
)
class EdgeDetectionArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
# 基本参数
subparser.add_argument('--method', '-m', default='canny',
choices=['sobel', 'scharr', 'laplacian', 'log', 'canny'],
help='边缘检测方法')
subparser.add_argument('--edge-band-index', type=int, default=0,
help='使用的波段索引')
subparser.add_argument('--batch', action='store_true',
help='批量运行多种方法')
subparser.add_argument('--batch-bands', nargs='+', type=int,
help='批量处理的波段索引列表')
# Sobel和Scharr参数
subparser.add_argument('--sobel-dx', type=int, default=1, choices=[0, 1],
help='Sobel x方向导数阶数必须与sobel-dy之和等于1 (默认: 1)')
subparser.add_argument('--sobel-dy', type=int, default=0, choices=[0, 1],
help='Sobel y方向导数阶数必须与sobel-dx之和等于1 (默认: 0)')
subparser.add_argument('--sobel-ksize', type=int, default=-1,
help='Sobel核大小-1表示3x3 (默认: -1)')
# Laplacian参数
subparser.add_argument('--laplacian-ksize', type=int, default=1,
help='Laplacian核大小 (默认: 1)')
subparser.add_argument('--laplacian-scale', type=float, default=1.0,
help='Laplacian尺度参数 (默认: 1.0)')
subparser.add_argument('--laplacian-delta', type=float, default=0.0,
help='Laplacian delta参数 (默认: 0.0)')
# LoG参数
subparser.add_argument('--log-sigma', type=float, default=1.0,
help='LoG高斯标准差 (默认: 1.0)')
subparser.add_argument('--log-threshold', type=float, default=0.0,
help='LoG阈值0表示自动阈值 (默认: 0.0)')
subparser.add_argument('--log-mode', default='reflect',
choices=['reflect', 'constant', 'nearest', 'mirror', 'wrap'],
help='LoG边界模式 (默认: reflect)')
# 梯度算法阈值参数
subparser.add_argument('--gradient-threshold', type=float, default=0.0,
help='梯度阈值 (0-255)0表示不使用阈值 (默认: 0.0)')
subparser.add_argument('--gradient-use-otsu', action='store_true',
help='对梯度图使用Otsu自动阈值')
# Canny参数
subparser.add_argument('--canny-min-threshold', type=int, default=100,
help='Canny最小阈值 (默认: 100)')
subparser.add_argument('--canny-max-threshold', type=int, default=200,
help='Canny最大阈值 (默认: 200)')
subparser.add_argument('--canny-aperture-size', type=int, default=3, choices=[3, 5, 7],
help='Canny Sobel算子孔径大小 (默认: 3)')
subparser.add_argument('--canny-l2gradient', action='store_true',
help='Canny使用L2范数计算梯度')
# 通用参数
subparser.add_argument('--normalize-output', action='store_true',
help='将输出归一化到0-255范围')
subparser.add_argument('--use-blur', action='store_true',
help='预先进行高斯模糊')
subparser.add_argument('--blur-kernel-size', type=int, default=3,
help='模糊核大小,必须是奇数 (默认: 3)')
subparser.add_argument('--blur-sigma', type=float, default=1.0,
help='模糊标准差 (默认: 1.0)')
def build_config(self, args):
from edge_detect_method.edge_detect import EdgeDetectionConfig
# 确定输出路径
if hasattr(args, 'output_file') and args.output_file:
output_path = args.output_file
output_dir = None
else:
output_path = None
output_dir = args.output_dir
# 构建批量处理参数
batch_methods = [args.method]
batch_bands = [args.edge_band_index]
if hasattr(args, 'batch') and args.batch:
# 如果是批量模式,使用默认的多种方法
batch_methods = ['sobel', 'scharr', 'laplacian', 'log', 'canny']
if hasattr(args, 'batch_bands') and args.batch_bands:
if len(args.batch_bands) == len(batch_methods):
batch_bands = args.batch_bands
else:
batch_bands = args.batch_bands * (len(batch_methods) // len(args.batch_bands) + 1)
batch_bands = batch_bands[:len(batch_methods)]
else:
batch_bands = [args.edge_band_index] * len(batch_methods)
return EdgeDetectionConfig(
input_path=args.input,
method=args.method,
band_index=args.edge_band_index,
output_path=output_path,
output_dir=output_dir,
batch_methods=batch_methods,
batch_bands=batch_bands,
# Sobel参数
sobel_dx=getattr(args, 'sobel_dx', 1),
sobel_dy=getattr(args, 'sobel_dy', 1),
sobel_ksize=getattr(args, 'sobel_ksize', -1),
# Laplacian参数
laplacian_ksize=getattr(args, 'laplacian_ksize', 1),
laplacian_scale=getattr(args, 'laplacian_scale', 1.0),
laplacian_delta=getattr(args, 'laplacian_delta', 0.0),
# LoG参数
log_sigma=getattr(args, 'log_sigma', 1.0),
log_threshold=getattr(args, 'log_threshold', 0.0),
log_mode=getattr(args, 'log_mode', 'reflect'),
# 梯度参数
gradient_threshold=getattr(args, 'gradient_threshold', 0.0),
gradient_use_otsu=getattr(args, 'gradient_use_otsu', False),
# Canny参数
canny_min_threshold=getattr(args, 'canny_min_threshold', 100),
canny_max_threshold=getattr(args, 'canny_max_threshold', 200),
canny_aperture_size=getattr(args, 'canny_aperture_size', 3),
canny_l2gradient=getattr(args, 'canny_l2gradient', False),
# 通用参数
normalize_output=getattr(args, 'normalize_output', False),
use_blur=getattr(args, 'use_blur', False),
blur_kernel_size=getattr(args, 'blur_kernel_size', 3),
blur_sigma=getattr(args, 'blur_sigma', 1.0)
)
class AnomalyDetectionArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
subparser.add_argument('--method', '-m', default='covariance',
choices=['covariance', 'one-class-svm', 'rx', 'squared-loss-probability'],
help='异常检测方法')
# 通用参数(所有方法都可能用到)
subparser.add_argument('--contamination', '-c', type=float, default=0.1,
help='异常样本比例 (用于自动确定阈值)')
# RX方法专用参数
subparser.add_argument('--background', help='背景数据路径 (RX方法custom模式时必需)')
subparser.add_argument('--label-col', help='标签列名 (CSV文件)')
subparser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
subparser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')
subparser.add_argument('--background-model', choices=['global', 'local_row', 'local_square', 'custom'],
default='global', help='背景模型类型 (RX方法默认global)')
subparser.add_argument('--window-size', type=int, default=5,
help='局部窗口大小 (RX方法默认5)')
subparser.add_argument('--row-window', type=int, default=3,
help='行窗口大小 (RX方法默认3)')
subparser.add_argument('--rx-threshold', type=float,
help='异常检测阈值 (RX方法不指定则只输出距离图)')
subparser.add_argument('--robust-covariance', action='store_true',
help='使用稳健协方差估计 (RX方法)')
# One-Class SVM专用参数
subparser.add_argument('--ocsvm-kernel', choices=['linear', 'rbf', 'poly'], default='rbf',
help='SVM核函数 (One-Class SVM方法默认rbf)')
subparser.add_argument('--ocsvm-nu', type=float, default=0.1,
help='训练误差比例 (One-Class SVM方法默认0.1)')
subparser.add_argument('--ocsvm-gamma', type=float,
help='核函数参数gamma (One-Class SVM方法不指定则自动选择)')
subparser.add_argument('--ocsvm-degree', type=int, default=3,
help='多项式核的阶数 (One-Class SVM方法默认3)')
subparser.add_argument('--ocsvm-coef0', type=float, default=0.0,
help='核函数常数项 (One-Class SVM方法默认0.0)')
subparser.add_argument('--ocsvm-tol', type=float, default=1e-3,
help='停止准则公差 (One-Class SVM方法默认1e-3)')
subparser.add_argument('--ocsvm-shrinking', action='store_true', default=True,
help='是否使用shrinking heuristic (One-Class SVM方法默认True)')
subparser.add_argument('--ocsvm-cache-size', type=float, default=200,
help='核缓存大小(MB) (One-Class SVM方法默认200)')
subparser.add_argument('--ocsvm-max-iter', type=int, default=-1,
help='最大迭代次数 (One-Class SVM方法默认-1表示无限制)')
subparser.add_argument('--ocsvm-use-scaling', action='store_true', default=True,
help='是否使用标准化 (One-Class SVM方法默认True)')
subparser.add_argument('--ocsvm-use-pca', action='store_true', default=True,
help='是否使用PCA降维 (One-Class SVM方法默认True)')
subparser.add_argument('--ocsvm-n-components', type=int,
help='PCA降维维度 (One-Class SVM方法默认自动选择)')
subparser.add_argument('--ocsvm-use-grid-search', action='store_true',
help='是否使用网格搜索调优参数 (One-Class SVM方法默认False)')
subparser.add_argument('--ocsvm-cv-folds', type=int, default=3,
help='交叉验证折数 (One-Class SVM方法默认3)')
# Squared Loss Probability专用参数
subparser.add_argument('--reconstruction-method', choices=['linear', 'pca'], default='linear',
help='重构方法 (Squared Loss方法默认linear)')
subparser.add_argument('--slp-n-components', type=int,
help='PCA降维维度 (Squared Loss方法使用PCA时必需)')
subparser.add_argument('--target-band', type=int,
help='目标波段索引 (Squared Loss方法默认使用最后一个波段)')
subparser.add_argument('--slp-threshold', type=float, default=0.8,
help='分类阈值 (Squared Loss方法默认0.8)')
subparser.add_argument('--probability-mode', action='store_true', default=True,
help='使用概率模式 (Squared Loss方法默认True)')
subparser.add_argument('--slp-use-scaling', action='store_true', default=True,
help='是否使用标准化 (Squared Loss方法默认True)')
subparser.add_argument('--scaler-type', choices=['standard', 'robust'], default='robust',
help='标准化类型 (Squared Loss方法默认robust)')
subparser.add_argument('--fit-intercept', action='store_true', default=True,
help='是否拟合截距 (Squared Loss方法默认True)')
def build_config(self, args):
if args.method == 'covariance':
from Anomaly_method.Covariance import CovarianceAnomalyConfig
config_kwargs = {
'input_path': args.input,
'output_dir': args.output_dir,
'contamination': getattr(args, 'contamination', 0.1)
}
return CovarianceAnomalyConfig(**config_kwargs)
elif args.method == 'one-class-svm':
from Anomaly_method.One_Class_SVM import OneClassSVMConfig
# 获取默认配置的值(不创建完整的配置对象以避免验证)
config_kwargs = {
'input_path': args.input,
'output_dir': args.output_dir,
'kernel': getattr(args, 'ocsvm_kernel', 'rbf'),
'nu': getattr(args, 'ocsvm_nu', 0.1),
'gamma': getattr(args, 'ocsvm_gamma', None),
'degree': getattr(args, 'ocsvm_degree', 3),
'coef0': getattr(args, 'ocsvm_coef0', 0.0),
'tol': getattr(args, 'ocsvm_tol', 1e-3),
'shrinking': getattr(args, 'ocsvm_shrinking', True),
'cache_size': getattr(args, 'ocsvm_cache_size', 200.0),
'max_iter': getattr(args, 'ocsvm_max_iter', -1),
'use_scaling': getattr(args, 'ocsvm_use_scaling', True),
'use_pca': getattr(args, 'ocsvm_use_pca', True),
'n_components': getattr(args, 'ocsvm_n_components', None),
'use_grid_search': getattr(args, 'ocsvm_use_grid_search', False),
'cv_folds': getattr(args, 'ocsvm_cv_folds', 3)
}
return OneClassSVMConfig(**config_kwargs)
elif args.method == 'rx':
from Anomaly_method.RX import RXConfig
config_kwargs = {
'input_path': args.input,
'background_path': getattr(args, 'background', None),
'label_col': getattr(args, 'label_col', None),
'spectral_start': getattr(args, 'spectral_start', None),
'spectral_end': getattr(args, 'spectral_end', None),
'output_dir': args.output_dir,
'background_model': getattr(args, 'background_model', 'global'),
'window_size': getattr(args, 'window_size', 5),
'row_window': getattr(args, 'row_window', 3),
'threshold': getattr(args, 'rx_threshold', None),
'contamination': getattr(args, 'contamination', 0.1),
'use_robust_covariance': getattr(args, 'robust_covariance', False)
}
return RXConfig(**config_kwargs)
elif args.method == 'squared-loss-probability':
from Anomaly_method.squared_loss_probability import SquaredLossAnomalyConfig
config_kwargs = {
'input_path': args.input,
'output_dir': args.output_dir,
'reconstruction_method': getattr(args, 'reconstruction_method', 'linear'),
'n_components': getattr(args, 'slp_n_components', None),
'target_band': getattr(args, 'target_band', None),
'probability_flag': getattr(args, 'probability_mode', True),
'threshold': getattr(args, 'slp_threshold', 0.8),
'use_scaling': getattr(args, 'slp_use_scaling', True),
'scaler_type': getattr(args, 'scaler_type', 'robust'),
'fit_intercept': getattr(args, 'fit_intercept', True)
}
return SquaredLossAnomalyConfig(**config_kwargs)
else:
raise ValueError(f"不支持的异常检测方法: {args.method}")
class ClassificationArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
subparser.add_argument('--method', '-m', default='svm',
choices=['svm', 'random_forest', 'knn', 'logistic_regression', 'linear_discriminant', 'quadratic_discriminant', 'plsda', 'xgboost', 'lightgbm', 'catboost'],
help='分类方法')
subparser.add_argument('--roi-file', '-r', required=True,
help='ROI标注文件路径 (.xml)')
subparser.add_argument('--test-size', '-ts', type=float, default=0.3,
help='测试集比例')
# 数据预处理参数
subparser.add_argument('--use-standardization', action='store_true', default=True,
help='是否使用标准化')
subparser.add_argument('--batch-size', type=int, default=10000,
help='批处理大小')
# SVM专用参数
subparser.add_argument('--svm-c', type=float, default=1.0,
help='SVM正则化参数C (SVM方法)')
subparser.add_argument('--svm-kernel', choices=['linear', 'rbf', 'poly', 'sigmoid'], default='rbf',
help='SVM核函数 (SVM方法)')
subparser.add_argument('--svm-gamma',
help='SVM核函数参数gamma: "scale", "auto" 或数值 (SVM方法默认scale)')
# 随机森林专用参数
subparser.add_argument('--rf-n-estimators', type=int, default=100,
help='树的数量 (随机森林方法)')
subparser.add_argument('--rf-max-depth', type=int,
help='树的最大深度 (随机森林方法)')
subparser.add_argument('--rf-min-samples-split', type=int, default=2,
help='内部节点再划分所需最小样本数 (随机森林方法)')
# KNN专用参数
subparser.add_argument('--knn-n-neighbors', type=int, default=5,
help='邻居数量 (KNN方法)')
subparser.add_argument('--knn-weights', choices=['uniform', 'distance'], default='uniform',
help='权重函数 (KNN方法)')
subparser.add_argument('--knn-algorithm', choices=['auto', 'ball_tree', 'kd_tree', 'brute'], default='auto',
help='算法 (KNN方法)')
# 逻辑回归专用参数
subparser.add_argument('--lr-c', type=float, default=1.0,
help='正则化强度 (逻辑回归方法)')
subparser.add_argument('--lr-max-iter', type=int, default=100,
help='最大迭代次数 (逻辑回归方法)')
# PLS-DA专用参数
subparser.add_argument('--pls-n-components', type=int, default=10,
help='PLS成分数量 (PLS-DA方法)')
# XGBoost专用参数
subparser.add_argument('--xgb-n-estimators', type=int, default=100,
help='树的数量 (XGBoost方法)')
subparser.add_argument('--xgb-max-depth', type=int, default=6,
help='树的最大深度 (XGBoost方法)')
subparser.add_argument('--xgb-learning-rate', type=float, default=0.3,
help='学习率 (XGBoost方法)')
# LightGBM专用参数
subparser.add_argument('--lgb-n-estimators', type=int, default=100,
help='树的数量 (LightGBM方法)')
subparser.add_argument('--lgb-max-depth', type=int, default=-1,
help='树的最大深度 (LightGBM方法-1表示无限制)')
subparser.add_argument('--lgb-learning-rate', type=float, default=0.1,
help='学习率 (LightGBM方法)')
# CatBoost专用参数
subparser.add_argument('--cb-n-estimators', type=int, default=100,
help='树的数量 (CatBoost方法)')
subparser.add_argument('--cb-depth', type=int, default=6,
help='树的最大深度 (CatBoost方法)')
subparser.add_argument('--cb-learning-rate', type=float, default=0.03,
help='学习率 (CatBoost方法)')
def build_config(self, args):
from classfication_method.classfication import ClassificationConfig
model_params = {}
if args.method == 'svm':
model_params['C'] = getattr(args, 'svm_c', 1.0)
model_params['kernel'] = getattr(args, 'svm_kernel', 'rbf')
# 处理gamma参数如果未指定使用默认值'scale'
if hasattr(args, 'svm_gamma') and args.svm_gamma is not None:
# 如果提供了值,尝试转换为适当的类型
gamma_str = str(args.svm_gamma).lower()
if gamma_str in ['scale', 'auto']:
model_params['gamma'] = gamma_str
else:
try:
model_params['gamma'] = float(args.svm_gamma)
except ValueError:
model_params['gamma'] = 'scale' # 如果转换失败,使用默认值
else:
model_params['gamma'] = 'scale' # 默认值
elif args.method == 'rf':
model_params['n_estimators'] = getattr(args, 'rf_n_estimators', 100)
model_params['max_depth'] = getattr(args, 'rf_max_depth', None)
model_params['min_samples_split'] = getattr(args, 'rf_min_samples_split', 2)
elif args.method == 'knn':
model_params['n_neighbors'] = getattr(args, 'knn_n_neighbors', 5)
model_params['weights'] = getattr(args, 'knn_weights', 'uniform')
model_params['algorithm'] = getattr(args, 'knn_algorithm', 'auto')
elif args.method == 'lr':
model_params['C'] = getattr(args, 'lr_c', 1.0)
model_params['max_iter'] = getattr(args, 'lr_max_iter', 100)
elif args.method == 'pls-da':
model_params['n_components'] = getattr(args, 'pls_n_components', 10)
elif args.method == 'xgboost':
model_params['n_estimators'] = getattr(args, 'xgb_n_estimators', 100)
model_params['max_depth'] = getattr(args, 'xgb_max_depth', 6)
model_params['learning_rate'] = getattr(args, 'xgb_learning_rate', 0.3)
elif args.method == 'lightgbm':
model_params['n_estimators'] = getattr(args, 'lgb_n_estimators', 100)
model_params['max_depth'] = getattr(args, 'lgb_max_depth', -1)
model_params['learning_rate'] = getattr(args, 'lgb_learning_rate', 0.1)
elif args.method == 'catboost':
model_params['iterations'] = getattr(args, 'cb_n_estimators', 100)
model_params['depth'] = getattr(args, 'cb_depth', 6)
model_params['learning_rate'] = getattr(args, 'cb_learning_rate', 0.03)
# 确定输出路径
if hasattr(args, 'output_file') and args.output_file:
output_path = args.output_file
else:
# 如果没有指定output_file基于output_dir和output_prefix生成
import os
output_prefix = getattr(args, 'output_prefix', 'result')
output_path = os.path.join(args.output_dir, f"{output_prefix}.dat")
# 模型类型映射从命令行参数映射到ClassificationConfig期望的名称
model_type_mapping = {
'svm': 'svm',
'random_forest': 'random_forest',
'knn': 'knn',
'logistic_regression': 'logistic_regression',
'linear_discriminant': 'linear_discriminant',
'quadratic_discriminant': 'quadratic_discriminant',
'plsda': 'plsda',
'xgboost': 'xgboost',
'lightgbm': 'lightgbm',
'catboost': 'catboost'
}
config_model_type = model_type_mapping.get(args.method, args.method)
# 构建配置参数
config_kwargs = {
'hyperspectral_path': args.input,
'roi_path': args.roi_file,
'model_type': config_model_type,
'model_params': model_params,
'test_size': args.test_size,
'output_path': output_path,
'use_standardization': getattr(args, 'use_standardization', True),
'batch_size': getattr(args, 'batch_size', 10000)
}
return ClassificationConfig(**config_kwargs)
class ClusteringArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
subparser.add_argument('--method', '-m', default='kmeans',
choices=['kmeans', 'fuzzy-cmeans', 'gmm', 'hierarchical', 'dbscan', 'spectral', 'subspace', 'ensemble'],
help='聚类方法')
subparser.add_argument('--n-clusters', type=int, default=5,
help='聚类数量')
subparser.add_argument('--batch', action='store_true',
help='运行所有聚类方法')
# 数据预处理参数
subparser.add_argument('--use-scaling', action='store_true', default=True,
help='是否使用数据标准化')
subparser.add_argument('--scaler-type', choices=['standard', 'minmax'], default='standard',
help='标准化类型')
subparser.add_argument('--use-pca', action='store_true',
help='是否使用PCA降维')
subparser.add_argument('--pca-components', type=int,
help='PCA降维后的维度数')
# K-means专用参数
subparser.add_argument('--kmeans-init', choices=['k-means++', 'random'], default='k-means++',
help='初始化方法 (K-means方法)')
subparser.add_argument('--kmeans-n-init', type=int, default=10,
help='运行次数,选择最佳结果 (K-means方法)')
subparser.add_argument('--kmeans-max-iter', type=int, default=300,
help='最大迭代次数 (K-means方法)')
subparser.add_argument('--kmeans-batch-size', type=int, default=1000,
help='批处理大小 (K-means方法)')
# 模糊C均值专用参数
subparser.add_argument('--fcm-m', type=float, default=2.0,
help='模糊度参数 (模糊C均值方法默认2.0)')
subparser.add_argument('--fcm-error', type=float, default=0.005,
help='终止误差 (模糊C均值方法)')
subparser.add_argument('--fcm-maxiter', type=int, default=1000,
help='最大迭代次数 (模糊C均值方法)')
# GMM专用参数
subparser.add_argument('--gmm-covariance-type', choices=['full', 'tied', 'diag', 'spherical'], default='full',
help='协方差类型 (GMM方法)')
subparser.add_argument('--gmm-max-iter', type=int, default=100,
help='最大迭代次数 (GMM方法)')
# 层次聚类专用参数
subparser.add_argument('--hier-linkage', choices=['ward', 'complete', 'average', 'single'], default='ward',
help='连接方法 (层次聚类方法)')
subparser.add_argument('--hier-affinity', choices=['euclidean', 'l1', 'l2', 'manhattan', 'cosine'], default='euclidean',
help='距离度量 (层次聚类方法)')
# DBSCAN专用参数
subparser.add_argument('--dbscan-eps', type=float, default=0.5,
help='邻域半径 (DBSCAN方法)')
subparser.add_argument('--dbscan-min-samples', type=int, default=5,
help='核心点的最小样本数 (DBSCAN方法)')
subparser.add_argument('--dbscan-algorithm', choices=['auto', 'ball_tree', 'kd_tree', 'brute'], default='auto',
help='算法 (DBSCAN方法)')
# 谱聚类专用参数
subparser.add_argument('--spectral-affinity', choices=['rbf', 'nearest_neighbors'], default='rbf',
help='相似度矩阵构建方法 (谱聚类方法)')
subparser.add_argument('--spectral-gamma', type=float, default=1.0,
help='RBF核函数的gamma参数 (谱聚类方法)')
# 子空间聚类专用参数
subparser.add_argument('--subspace-n-subspaces', type=int,
help='子空间数量 (子空间聚类方法,默认等于聚类数)')
# 集成聚类专用参数
subparser.add_argument('--ensemble-voting', choices=['hard', 'soft'], default='hard',
help='投票方式 (集成聚类方法)')
def build_config(self, args):
from cluster_method.cluster import ClusteringConfig
# 如果是批量模式,使用默认配置
if hasattr(args, 'batch') and args.batch:
return ClusteringConfig(
input_path=args.input,
method='batch', # 特殊标记,表示批量处理
n_clusters=args.n_clusters,
output_dir=args.output_dir
)
# 构建基础配置
config_kwargs = {
'input_path': args.input,
'method': args.method,
'n_clusters': args.n_clusters,
'output_dir': args.output_dir,
'use_scaling': getattr(args, 'use_scaling', True),
'scaler_type': getattr(args, 'scaler_type', 'standard'),
'use_pca': getattr(args, 'use_pca', False),
'pca_components': getattr(args, 'pca_components', None)
}
# 构建方法特定参数
method_params = {}
if args.method == 'kmeans':
method_params['init'] = getattr(args, 'kmeans_init', 'k-means++')
method_params['n_init'] = getattr(args, 'kmeans_n_init', 10)
method_params['max_iter'] = getattr(args, 'kmeans_max_iter', 300)
method_params['batch_size'] = getattr(args, 'kmeans_batch_size', 1000)
elif args.method == 'fuzzy-cmeans':
method_params['fuzziness'] = getattr(args, 'fcm_m', 2.0)
method_params['error'] = getattr(args, 'fcm_error', 0.005)
method_params['max_iter'] = getattr(args, 'fcm_maxiter', 1000)
elif args.method == 'gmm':
method_params['covariance_type'] = getattr(args, 'gmm_covariance_type', 'full')
method_params['max_iter'] = getattr(args, 'gmm_max_iter', 200)
method_params['n_init_attempts'] = getattr(args, 'gmm_n_init_attempts', 10)
elif args.method == 'hierarchical':
method_params['linkage'] = getattr(args, 'hier_linkage', 'ward')
method_params['affinity'] = getattr(args, 'hier_affinity', 'euclidean')
elif args.method == 'dbscan':
method_params['eps_percentile'] = getattr(args, 'dbscan_eps_percentile', 50)
method_params['min_samples_factor'] = getattr(args, 'dbscan_min_samples_factor', 0.1)
method_params['n_neighbors'] = getattr(args, 'dbscan_n_neighbors', 20)
elif args.method == 'spectral':
method_params['affinity'] = getattr(args, 'spectral_affinity', 'nearest_neighbors')
method_params['n_neighbors'] = getattr(args, 'spectral_n_neighbors', 10)
method_params['large_dataset_threshold'] = getattr(args, 'spectral_large_dataset_threshold', 2000)
elif args.method == 'subspace':
method_params['n_components_factor'] = getattr(args, 'subspace_n_components_factor', 0.33)
method_params['max_iter'] = getattr(args, 'subspace_max_iter', 300)
elif args.method == 'ensemble':
method_params['voting'] = getattr(args, 'ensemble_voting', 'hard')
if method_params:
config_kwargs['method_params'] = method_params
return ClusteringConfig(**config_kwargs)
class FeatureSelectionArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
subparser.add_argument('--method', '-m', default='Spa',
choices=['Spa', 'Cars', 'Uve', 'GA', 'ReliefF', 'SiPLS', 'Lars', 'RandomFrog'],
help='特征选择方法')
subparser.add_argument('--label-column', '-l', required=True,
help='标签列名')
subparser.add_argument('--n-features', type=int, default=20,
help='选择的特征数量')
subparser.add_argument('--spectral-columns',
help='光谱数据列范围,如 "1:10""auto" (不指定则自动检测)')
# 输出和可视化参数
subparser.add_argument('--output-csv', action='store_true', default=True,
help='是否输出CSV结果文件')
subparser.add_argument('--save-plots', action='store_true', default=True,
help='是否保存可视化图表')
subparser.add_argument('--plot-prefix', default='feature_selection',
help='图表文件名前缀')
# SPA专用参数
subparser.add_argument('--spa-min-vars', type=int, default=2,
help='SPA最小变量数')
subparser.add_argument('--spa-max-vars', type=int, default=50,
help='SPA最大变量数')
subparser.add_argument('--spa-autoscaling', type=int, choices=[0, 1], default=1,
help='SPA自动缩放 (0=关闭, 1=开启)')
# CARS专用参数
subparser.add_argument('--cars-n', type=int, default=50,
help='CARS蒙特卡洛采样次数')
subparser.add_argument('--cars-cv', type=int, default=10,
help='CARS交叉验证折数')
# UVE专用参数
subparser.add_argument('--uve-ncomp', type=int, default=20,
help='UVE PLS成分数')
subparser.add_argument('--uve-cv', type=int, default=5,
help='UVE交叉验证折数')
# GA专用参数
subparser.add_argument('--ga-population-size', type=int, default=10,
help='GA种群大小')
subparser.add_argument('--ga-n-generations', type=int, default=50,
help='GA进化代数')
subparser.add_argument('--ga-crossover-rate', type=float, default=0.8,
help='GA交叉概率')
subparser.add_argument('--ga-mutation-rate', type=float, default=0.1,
help='GA变异概率')
# Relief-F专用参数
subparser.add_argument('--relief-k', type=int, default=10,
help='Relief-F最近邻数k')
subparser.add_argument('--relief-sample-size', type=int,
help='Relief-F采样大小 (默认使用全部样本)')
# SiPLS专用参数
subparser.add_argument('--sipls-interval-width', type=int, default=10,
help='SiPLS区间宽度')
subparser.add_argument('--sipls-cv', type=int, default=5,
help='SiPLS交叉验证折数')
# LAR专用参数
subparser.add_argument('--lar-cv', type=int, default=5,
help='LAR交叉验证折数')
subparser.add_argument('--lar-max-features', type=int,
help='LAR最大特征数 (默认无限制)')
# RandomFrog专用参数
subparser.add_argument('--random-frog-n-frogs', type=int, default=50,
help='RandomFrog青蛙数量')
subparser.add_argument('--random-frog-n-memeplexes', type=int, default=5,
help='RandomFrog模因数量')
subparser.add_argument('--random-frog-n-evolution-steps', type=int, default=10,
help='RandomFrog进化步数')
subparser.add_argument('--random-frog-n-shuffle-iterations', type=int, default=10,
help='RandomFrog重排迭代次数')
subparser.add_argument('--random-frog-cv', type=int, default=5,
help='RandomFrog交叉验证折数')
def _parse_spectral_columns(self, csv_file_path: str, spectral_columns_arg: str, label_column: str) -> list:
"""
解析光谱列参数,返回列名列表
Args:
csv_file_path: CSV文件路径
spectral_columns_arg: 光谱列参数字符串,如 "1:165""auto"
label_column: 标签列名
Returns:
光谱列名列表
"""
import pandas as pd
# 读取CSV文件的列名
df = pd.read_csv(csv_file_path, nrows=0)
all_columns = df.columns.tolist()
# 找到标签列的索引
if label_column not in all_columns:
raise ValueError(f"标签列 '{label_column}' 不存在于CSV文件中")
label_index = all_columns.index(label_column)
# 获取除标签列外的所有列
spectral_candidates = [col for i, col in enumerate(all_columns) if i != label_index]
if spectral_columns_arg is None or spectral_columns_arg.lower() == 'auto':
# 自动检测:使用除标签列外的所有列
return spectral_candidates
# 解析范围字符串
try:
columns = []
# 分割多个范围(用逗号分隔)
for part in spectral_columns_arg.split(','):
part = part.strip()
if ':' in part:
# 范围选择,如 "1:165"
start, end = part.split(':')
start = int(start.strip())
end = int(end.strip())
# 转换为从标签列开始的索引
# 注意用户指定的索引是从1开始的第一列光谱数据为1
if start < 1 or end < 1:
raise ValueError(f"列范围 {start}:{end} 无效必须从1开始")
# 检查范围是否有效
if end > len(spectral_candidates):
end = len(spectral_candidates)
if start > len(spectral_candidates):
raise ValueError(f"起始列 {start} 超出可用光谱列范围 [1, {len(spectral_candidates)}]")
# 获取指定范围的列名
selected_columns = spectral_candidates[start-1:end] # Python索引从0开始
columns.extend(selected_columns)
else:
# 单个索引
idx = int(part.strip())
if idx < 1 or idx > len(spectral_candidates):
raise ValueError(f"列索引 {idx} 超出范围 [1, {len(spectral_candidates)}]")
columns.append(spectral_candidates[idx-1])
return list(dict.fromkeys(columns)) # 去重并保持顺序
except ValueError as e:
raise ValueError(f"解析光谱列参数 '{spectral_columns_arg}' 时出错: {e}")
def build_config(self, args):
from Feature_Selection_method.feture_select import FeatureSelectionConfig
import pandas as pd
# 构建方法参数
method_params = {}
if args.method == 'Spa':
method_params['m_min'] = getattr(args, 'spa_min_vars', 2)
method_params['m_max'] = getattr(args, 'spa_max_vars', 50)
method_params['autoscaling'] = getattr(args, 'spa_autoscaling', 1)
elif args.method == 'Cars':
method_params['N'] = getattr(args, 'cars_n', 50)
method_params['f'] = getattr(args, 'n_features', 20)
method_params['cv'] = getattr(args, 'cars_cv', 10)
elif args.method == 'Uve':
method_params['ncomp'] = getattr(args, 'uve_ncomp', 20)
method_params['cv'] = getattr(args, 'uve_cv', 5)
elif args.method == 'GA':
method_params['population_size'] = getattr(args, 'ga_population_size', 10)
method_params['n_generations'] = getattr(args, 'ga_n_generations', 50)
method_params['crossover_rate'] = getattr(args, 'ga_crossover_rate', 0.8)
method_params['mutation_rate'] = getattr(args, 'ga_mutation_rate', 0.1)
elif args.method == 'ReliefF':
method_params['k'] = getattr(args, 'relief_k', 10)
method_params['sample_size'] = getattr(args, 'relief_sample_size', None)
elif args.method == 'SiPLS':
method_params['interval_width'] = getattr(args, 'sipls_interval_width', 10)
method_params['cv'] = getattr(args, 'sipls_cv', 5)
elif args.method == 'Lars':
method_params['cv'] = getattr(args, 'lar_cv', 5)
method_params['max_features'] = getattr(args, 'lar_max_features', None)
elif args.method == 'RandomFrog':
# RandomFrog参数设置
method_params['n_frogs'] = getattr(args, 'random_frog_n_frogs', 50)
method_params['n_memeplexes'] = getattr(args, 'random_frog_n_memeplexes', 5)
method_params['n_evolution_steps'] = getattr(args, 'random_frog_n_evolution_steps', 10)
method_params['n_shuffle_iterations'] = getattr(args, 'random_frog_n_shuffle_iterations', 10)
method_params['cv'] = getattr(args, 'random_frog_cv', 5)
# 解析光谱列
spectral_columns = self._parse_spectral_columns(args.input, getattr(args, 'spectral_columns', None), args.label_column)
# 构建配置参数
config_kwargs = {
'csv_file_path': args.input,
'label_column': args.label_column,
'method': args.method,
'output_dir': args.output_dir,
'method_params': method_params,
'spectral_columns': spectral_columns,
'output_csv': getattr(args, 'output_csv', True),
'save_plots': getattr(args, 'save_plots', True),
'plot_name_prefix': getattr(args, 'plot_prefix', 'feature_selection')
}
return FeatureSelectionConfig(**config_kwargs)
class SpectralIndexArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
subparser.add_argument('--indices', '-idx', nargs='+',
help='要计算的光谱指数列表 (默认: 全部)')
subparser.add_argument('--batch', action='store_true',
help='计算所有指数')
subparser.add_argument('--png', action='store_true',
help='生成PNG可视化')
# 数据格式参数
subparser.add_argument('--data-format', choices=['csv', 'envi', 'auto'], default='auto',
help='数据文件格式')
subparser.add_argument('--label-column',
help='标签列名 (CSV格式时需要)')
subparser.add_argument('--spectral-start',
help='光谱数据起始列名或波段 (CSV格式时需要)')
subparser.add_argument('--spectral-end',
help='光谱数据结束列名或波段 (CSV格式时需要)')
# 输出参数
subparser.add_argument('--spectral-output-format', choices=['csv', 'excel', 'both'], default='both',
help='输出文件格式')
subparser.add_argument('--spectral-save-individual', action='store_true',
help='是否保存单个指数结果')
subparser.add_argument('--spectral-output-prefix', default='spectral_indices',
help='输出文件名前缀')
# 可视化参数
subparser.add_argument('--plot-histogram', action='store_true', default=True,
help='是否绘制直方图')
subparser.add_argument('--plot-boxplot', action='store_true', default=True,
help='是否绘制箱线图')
subparser.add_argument('--plot-correlation', action='store_true',
help='是否绘制相关性热力图')
subparser.add_argument('--plot-wavelength', action='store_true',
help='是否按波长绘制指数分布')
# 统计参数
subparser.add_argument('--calculate-stats', action='store_true', default=True,
help='是否计算统计信息')
subparser.add_argument('--stats-percentiles', nargs='+', type=float,
help='要计算的百分位数列表 (默认: 25, 50, 75)')
# 指数计算参数
subparser.add_argument('--index-csv', help='自定义光谱指数定义CSV文件路径')
subparser.add_argument('--formula-csv', help='自定义公式定义CSV文件路径')
def build_config(self, args):
from spectral_feature_method.spectral_index import SpectralIndexConfig
# 创建基础配置
config = SpectralIndexConfig.create_quick_analysis(
data_file_path=args.input,
indices_to_calculate=args.indices
)
# 更新数据配置
config.data.data_format = getattr(args, 'data_format', 'auto')
config.data.label_column = getattr(args, 'label_column', None)
config.data.spectral_start = getattr(args, 'spectral_start', None)
config.data.spectral_end = getattr(args, 'spectral_end', None)
# 更新输出配置
config.output.output_format = getattr(args, 'spectral_output_format', 'both')
config.output.save_individual_indices = getattr(args, 'spectral_save_individual', False)
config.output.output_prefix = getattr(args, 'spectral_output_prefix', 'spectral_indices')
# 更新可视化配置
config.output.plot_histogram = getattr(args, 'plot_histogram', True)
config.output.plot_boxplot = getattr(args, 'plot_boxplot', True)
config.output.plot_correlation = getattr(args, 'plot_correlation', False)
config.output.plot_wavelength = getattr(args, 'plot_wavelength', False)
# 更新统计配置
config.output.calculate_statistics = getattr(args, 'calculate_stats', True)
if hasattr(args, 'stats_percentiles'):
config.output.statistics_percentiles = args.stats_percentiles
# 更新指数配置
if hasattr(args, 'index_csv'):
config.indices.spectral_index_csv = args.index_csv
if hasattr(args, 'formula_csv'):
config.indices.formula_csv = args.formula_csv
return config
class PreprocessingArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
subparser.add_argument('--method', '-m', default='SS',
choices=['MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave'],
help='预处理方法')
subparser.add_argument('--spectral-start-index', type=int, default=1,
help='CSV文件的谱段起始列索引 (从0开始默认: 1)')
subparser.add_argument('--handle-outliers', action='store_true',
help='处理异常值')
subparser.add_argument('--outlier-method', default='iqr',
choices=['iqr', 'isolation-forest', 'lof'],
help='异常值检测方法')
# CSV文件验证参数
subparser.add_argument('--validate-csv', action='store_true',
help='启用CSV文件验证')
subparser.add_argument('--min-rows', type=int, default=1,
help='CSV文件最小行数 (默认: 1)')
subparser.add_argument('--min-cols', type=int, default=2,
help='CSV文件最小列数 (默认: 2)')
subparser.add_argument('--check-missing-values', action='store_true',
help='检查缺失值')
subparser.add_argument('--check-data-types', action='store_true',
help='检查数据类型一致性')
subparser.add_argument('--wavelength-column',
help='波长列名 (用于验证波长数据)')
# MA方法参数
subparser.add_argument('--ma-window', type=int, default=11,
help='移动平均窗口大小 (默认: 11)')
# SG方法参数
subparser.add_argument('--sg-window', type=int, default=15,
help='Savitzky-Golay窗口大小 (默认: 15)')
subparser.add_argument('--sg-poly', type=int, default=2,
help='Savitzky-Golay多项式阶数 (默认: 2)')
def build_config(self, args):
from preprocessing_method.Preprocessing import PreprocessingConfig
return PreprocessingConfig(
input_path=args.input,
method=args.method,
spectral_start_index=getattr(args, 'spectral_start_index', 1),
handle_outliers=args.handle_outliers,
outlier_method=args.outlier_method,
output_dir=args.output_dir,
# CSV验证参数
validate_csv=getattr(args, 'validate_csv', False),
min_rows=getattr(args, 'min_rows', 1),
min_cols=getattr(args, 'min_cols', 2),
check_missing_values=getattr(args, 'check_missing_values', False),
check_data_types=getattr(args, 'check_data_types', False),
wavelength_column=getattr(args, 'wavelength_column', None),
# MA和SG方法参数
ma_window=getattr(args, 'ma_window', 11),
sg_window=getattr(args, 'sg_window', 15),
sg_poly=getattr(args, 'sg_poly', 2)
)
class ShapeFeatureArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
# 输入数据参数
subparser.add_argument('--input-type', default='dat',
choices=['dat', 'shp'],
help='输入类型 (默认: dat)')
subparser.add_argument('--hdr-file',
help='对应的HDR头文件路径仅在输入类型为dat时使用')
subparser.add_argument('--shape-band-index', type=int, default=0,
help='波段索引 (默认: 0)')
# 连通域处理参数
subparser.add_argument('--connectivity', type=int, choices=[4, 8], default=8,
help='连通性4或8 (默认: 8)')
subparser.add_argument('--min-area', type=int, default=10,
help='最小区域面积阈值 (默认: 10)')
# 分水岭算法参数
subparser.add_argument('--use-watershed', action='store_true',
help='使用分水岭算法分离相邻物体')
subparser.add_argument('--watershed-min-distance', type=int, default=10,
help='分水岭算法中局部最大值的最小距离 (默认: 10)')
# 输出设置参数
subparser.add_argument('--output-csv', action='store_true', default=True,
help='输出CSV文件 (默认: True)')
subparser.add_argument('--save-labeled-image', action='store_true',
help='保存标记后的图像')
def build_config(self, args):
from spatial_features_method.shape_feature import ShapeFeatureConfig
config = ShapeFeatureConfig()
config.input_type = args.input_type
config.dat_file_path = args.input if args.input_type == 'dat' else None
config.hdr_file_path = getattr(args, 'hdr_file', None)
config.shp_file_path = args.input if args.input_type == 'shp' else None
config.band_index = args.shape_band_index
config.connectivity = getattr(args, 'connectivity', 8)
config.min_area = args.min_area
config.use_watershed = args.use_watershed
config.watershed_min_distance = getattr(args, 'watershed_min_distance', 10)
config.output_dir = args.output_dir
config.output_csv = getattr(args, 'output_csv', True)
config.save_labeled_image = getattr(args, 'save_labeled_image', False)
return config
class GLCMArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
# GLCM参数
subparser.add_argument('--nbit', type=int, default=64,
help='灰度级数 (默认: 64)')
subparser.add_argument('--min-gray', type=int, default=0,
help='最小灰度值 (默认: 0)')
subparser.add_argument('--max-gray', type=int, default=255,
help='最大灰度值 (默认: 255)')
subparser.add_argument('--slide-window', type=int, default=7,
help='滑动窗口大小,必须为奇数 (默认: 7)')
# GLCM计算参数
subparser.add_argument('--step', type=int, nargs='+', default=[2],
help='步长距离列表 (默认: [2])')
subparser.add_argument('--angle', type=float, nargs='+', default=[0],
help='角度列表(弧度) (默认: [0])')
# 数据处理参数
subparser.add_argument('--band-index', '-b', type=int, default=25,
help='要处理的波段索引 (默认: 25)')
# 输出设置
subparser.add_argument('--save-dat', action='store_true', default=True,
help='保存为DAT文件 (默认: True)')
subparser.add_argument('--save-png', action='store_true',
help='保存为PNG文件')
def build_config(self, args):
from spatial_features_method.glcm import SpatialFeatureConfig
config = SpatialFeatureConfig()
config.nbit = args.nbit
config.mi = args.min_gray
config.ma = args.max_gray
config.slide_window = args.slide_window
config.step = args.step
config.angle = args.angle
config.band_index = args.band_index
config.image_path = args.input # 使用命令行输入的文件路径
config.output_dir = args.output_dir
config.save_dat = args.save_dat
config.save_png = args.save_png
return config
class ColorAnalysisArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
subparser.add_argument('--method', '-m', default='CIEDE2000',
choices=['CIE76', 'CIE94', 'CIEDE2000'],
help='色差计算方法')
subparser.add_argument('--standards-file', '-s', required=True,
help='标准颜色文件路径 (.csv)')
def build_config(self, args):
# 为 DeltaEApp 创建一个模拟的 args 对象
class MockArgs:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
return MockArgs(
input=args.input,
standards_file=args.standards_file,
method=args.method,
output_dir=args.output_dir,
mode='image' # 默认使用图像模式
)
class FilterArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
subparser.add_argument('--filter-type', default='mean',
choices=['mean', 'median', 'gaussian', 'bilateral', 'opening', 'closing', 'erosion', 'dilation', 'gradient', 'tophat', 'blackhat'],
help='滤波类型')
subparser.add_argument('--filter-band-index', type=int, default=50,
help='波段索引')
subparser.add_argument('--kernel-size', '-k', type=int, default=3,
help='核大小')
# Smooth_filter 特有参数
subparser.add_argument('--sigma', '-s', type=float, default=1.0,
help='高斯标准差 (仅用于gaussian滤波)')
subparser.add_argument('--sigma-color', type=float, default=75.0,
help='颜色空间标准差 (仅用于bilateral滤波)')
subparser.add_argument('--sigma-space', type=float, default=75.0,
help='坐标空间标准差 (仅用于bilateral滤波)')
# morphological_fliter 特有参数
subparser.add_argument('--se-shape', default='disk',
choices=['disk', 'square', 'rectangle', 'diamond'],
help='结构元素形状 (仅用于形态学操作)')
subparser.add_argument('--se-size', type=int, default=3,
help='结构元素大小 (仅用于形态学操作)')
subparser.add_argument('--format-type', default='auto',
choices=['auto', 'bil', 'bip', 'bsq', 'dat'],
help='图像格式 (仅用于形态学操作)')
def build_config(self, args):
# 根据滤波类型选择合适的配置类
if args.filter_type in ['mean', 'median', 'gaussian', 'bilateral']:
from fliter_method.Smooth_filter import SmoothFilterConfig
return SmoothFilterConfig(
input_path=args.input,
filter_type=args.filter_type,
band_index=args.filter_band_index,
kernel_size=args.kernel_size,
output_dir=args.output_dir,
sigma=getattr(args, 'sigma', 1.0),
sigma_color=getattr(args, 'sigma_color', 75.0),
sigma_space=getattr(args, 'sigma_space', 75.0)
)
else:
from fliter_method.morphological_fliter import MorphologicalFilterConfig
return MorphologicalFilterConfig(
input_path=args.input,
operation=args.filter_type,
band_index=args.filter_band_index,
kernel_size=args.kernel_size,
output_dir=args.output_dir,
se_shape=getattr(args, 'se_shape', 'disk'),
se_size=getattr(args, 'se_size', 3),
format_type=getattr(args, 'format_type', 'auto')
)
class RegressionArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
# 数据配置参数
subparser.add_argument('--label-column', '-l', required=True,
help='标签列名')
subparser.add_argument('--test-size', type=float, default=0.2,
help='测试集比例 (默认: 0.2)')
subparser.add_argument('--random-state', type=int, default=42,
help='随机种子 (默认: 42)')
subparser.add_argument('--scale-method', choices=['standard', 'minmax'], default='standard',
help='数据标准化方法 (默认: standard)')
# 模型配置参数
subparser.add_argument('--models', '-m', nargs='+',
help='回归模型列表 (默认: 全部)')
subparser.add_argument('--tune-params', action='store_true',
help='启用超参数调优')
subparser.add_argument('--tuning-method', choices=['grid', 'random'], default='grid',
help='调优方法 (默认: grid)')
subparser.add_argument('--cv-folds', type=int, default=5,
help='交叉验证折数 (默认: 5)')
subparser.add_argument('--random-search-iter', type=int, default=20,
help='随机搜索迭代次数 (默认: 20)')
# 训练配置参数
subparser.add_argument('--epochs', type=int, default=100,
help='训练轮数 (默认: 100)')
subparser.add_argument('--batch-size', type=int, default=32,
help='批大小 (默认: 32)')
subparser.add_argument('--learning-rate', type=float, default=0.001,
help='学习率 (默认: 0.001)')
# 输出配置参数
subparser.add_argument('--save-models', action='store_true', default=True,
help='保存训练好的模型')
subparser.add_argument('--plot-results', action='store_true', default=True,
help='生成结果可视化图表')
subparser.add_argument('--save-dir', default='models',
help='模型保存目录 (默认: models)')
subparser.add_argument('--plot-dir', default='plots',
help='图表保存目录 (默认: plots)')
def build_config(self, args):
from rgression_method.regression import RegressionConfig
# 创建基础配置
config = RegressionConfig.create_default(
csv_path=args.input,
label_column=args.label_column
)
# 设置模型名称
if getattr(args, 'models', None):
config.models.model_names = args.models
# 根据是否启用调参设置其他参数
if getattr(args, 'tune_params', False):
# 启用调参的配置
config.models.tune_hyperparams = True
config.models.tuning_method = getattr(args, 'tuning_method', 'grid')
config.models.cv_folds = getattr(args, 'cv_folds', 5)
config.models.random_search_iter = getattr(args, 'random_search_iter', 20)
config.training.epochs = getattr(args, 'epochs', 100)
config.training.batch_size = getattr(args, 'batch_size', 32)
config.training.learning_rate = getattr(args, 'learning_rate', 0.001)
else:
# 快速分析配置
config.models.tune_hyperparams = False
# 设置通用参数
config.data.test_size = getattr(args, 'test_size', 0.2)
config.data.random_state = getattr(args, 'random_state', 42)
config.data.scale_method = getattr(args, 'scale_method', 'standard')
config.output.save_models = getattr(args, 'save_models', True)
config.output.plot_results = getattr(args, 'plot_results', True)
config.output.save_dir = getattr(args, 'save_dir', 'models')
config.output.plot_dir = getattr(args, 'plot_dir', 'plots')
return config
class RegressionPredictionArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
# 必需参数
subparser.add_argument('--model-path', '-m', required=True,
help='模型文件路径(单个文件或目录)')
# 可选参数
subparser.add_argument('--mask-path',
help='遮罩文件路径')
subparser.add_argument('--use-mask', action='store_true', default=True,
help='是否使用遮罩 (默认: True)')
subparser.add_argument('--batch-mode', action='store_true',
help='批量处理模式')
subparser.add_argument('--colormap', default='viridis',
choices=['viridis', 'plasma', 'inferno', 'magma', 'cividis',
'Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds',
'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn'],
help='颜色映射 (默认: viridis)')
subparser.add_argument('--dpi', type=int, default=300,
help='输出图像DPI (默认: 300)')
subparser.add_argument('--save-individual', action='store_true', default=True,
help='保存单个模型预测结果')
def build_config(self, args):
from rgression_method.regression_predict import PredictionConfig
return PredictionConfig(
image_path=args.input,
mask_path=getattr(args, 'mask_path', None),
model_path=args.model_path,
output_dir=args.output_dir,
use_mask=getattr(args, 'use_mask', True),
batch_mode=getattr(args, 'batch_mode', False),
colormap=getattr(args, 'colormap', 'viridis'),
dpi=getattr(args, 'dpi', 300),
save_individual=getattr(args, 'save_individual', True)
)
class DeltaEArgsBuilder(ArgsBuilder):
def add_common_args(self, subparser):
"""DeltaE使用自己的参数命名"""
pass
def add_module_specific_args(self, subparser):
# DeltaE 特定的参数命名
subparser.add_argument('--mode', required=True, choices=['image', 'pairwise'],
help='计算模式: image (图像vs标准色) 或 pairwise (两两比较)')
subparser.add_argument('--input', required=True, help='输入文件路径 (ENVI格式LAB图像 或 CSV文件)')
subparser.add_argument('--standards', required=True, help='标准颜色CSV文件路径')
subparser.add_argument('--output-dir', default='./results', help='输出目录路径')
subparser.add_argument('--output-file', help='输出文件名 (可选,如果不指定则使用默认名称)')
subparser.add_argument('--create-heatmap', action='store_true', help='生成热图')
subparser.add_argument('--create-histogram', action='store_true', help='生成直方图')
subparser.add_argument('--method', default='CIEDE2000',
choices=['CIE76', 'CIE94', 'CIEDE2000'],
help='色差计算方法 (默认: CIEDE2000)')
subparser.add_argument('--standards-range', help='标准色范围 (e.g.: "0,1,2""0-5""all")')
subparser.add_argument('--reference-range', help='参考色范围 (pairwise模式)')
subparser.add_argument('--target-range', help='目标色范围 (pairwise模式)')
subparser.add_argument('--no-progress', action='store_true', help='禁用进度条显示')
def build_config(self, args):
# DeltaE直接使用args对象不需要转换
return args
class SpectralToColorArgsBuilder(ArgsBuilder):
def add_common_args(self, subparser):
"""SpectralToColor添加通用参数除了input因为它使用自己的input定义"""
# 添加除了 --input 之外的通用参数
subparser.add_argument('--output-dir', default='./results',
help='输出目录路径 (默认: ./results)')
subparser.add_argument('--output-prefix', default='result',
help='输出文件前缀 (默认: result)')
subparser.add_argument('--output-file',
help='直接指定输出文件名 (优先级高于 --output-dir 和 --output-prefix)')
def add_module_specific_args(self, subparser):
# SpectralToColor特定的参数与原始spectral2cie2.py一致
subparser.add_argument('--input', required=True, help='输入文件路径 (.hdr高光谱图像 或 .csv文件)')
subparser.add_argument('-c', '--color_space', default='Lab',
choices=['XYZ', 'xyY', 'Lab', 'LCH'],
help='输出颜色空间 (默认: Lab)')
subparser.add_argument('-o', '--output', required=True,
help='输出文件路径')
subparser.add_argument('-f', '--format', default='dat', choices=['csv', 'dat'],
help='输出格式 (默认: dat)')
subparser.add_argument('-i', '--illuminant', default='D65',
choices=['D65', 'D50', 'A', 'F2', 'F7', 'F11'],
help='标准光源 (默认: D65)')
subparser.add_argument('-b', '--observer', default='',
choices=['', '10°'],
help='观察者 (默认: 2°)')
subparser.add_argument('-w', '--wavelength_col',
help='CSV文件的波长列名 (默认: wavelength)')
subparser.add_argument('--no_normalize_xyz', action='store_true',
help='不归一化XYZ值')
subparser.add_argument('--use_numpy', action='store_true',
help='使用NumPy实现而不是colour库')
def build_config(self, args):
# SpectralToColor直接使用args对象但需要添加output_dir属性以保持兼容性
# 从输出路径中提取目录部分作为output_dir
if hasattr(args, 'output') and args.output:
import os
args.output_dir = os.path.dirname(args.output) or './results'
else:
args.output_dir = './results'
return args
class XYZ2RGBArgsBuilder(ArgsBuilder):
def add_common_args(self, subparser):
"""XYZ2RGB添加通用参数除了input因为它使用自己的input定义"""
# 添加除了 --input 之外的通用参数
subparser.add_argument('--output-dir', default='./results',
help='输出目录路径 (默认: ./results)')
subparser.add_argument('--output-prefix', default='result',
help='输出文件前缀 (默认: result)')
subparser.add_argument('--output-file',
help='直接指定输出文件名 (优先级高于 --output-dir 和 --output-prefix)')
def add_module_specific_args(self, subparser):
# XYZ2RGB特定的参数与原始XYZ2RGB.py一致
subparser.add_argument('--input', required=True,
help='输入XYZ图像文件路径')
subparser.add_argument('--output', required=True,
help='输出RGB图像文件路径')
subparser.add_argument('--rgb-space', default='sRGB',
choices=['sRGB', 'Adobe RGB (1998)', 'DCI-P3', 'ITU-R BT.709',
'ITU-R BT.2020', 'ACES2065-1', 'ACEScg', 'ProPhoto RGB',
'Apple RGB', 'PAL/SECAM', 'NTSC (1953)'],
help='RGB颜色空间 (默认: sRGB)')
subparser.add_argument('--gamma', default='sRGB',
choices=['none', 'sRGB', 'BT.709', 'gamma_2.2', 'gamma_1.8',
'gamma_2.4', 'L*', 'BT.1886', 'ST 2084', 'HLG', 'log'],
help='Gamma校正方法 (默认: sRGB)')
subparser.add_argument('--output-dtype', default='uint8',
choices=['uint8', 'uint10', 'uint12', 'uint16', 'int8',
'int16', 'float16', 'float32', 'float64'],
help='输出数据类型 (默认: uint8)')
def build_config(self, args):
# XYZ2RGB直接使用args对象但需要添加output_dir属性以保持兼容性
# 从输出路径中提取目录部分作为output_dir
if hasattr(args, 'output') and args.output:
import os
args.output_dir = os.path.dirname(args.output) or './results'
else:
args.output_dir = './results'
return args
class SupervisedClassificationArgsBuilder(ArgsBuilder):
def add_module_specific_args(self, subparser):
subparser.add_argument('--xml_file', required=True,
help='ENVI ROI XML文件路径')
subparser.add_argument('--method', '-m', default='all',
choices=['all', 'euclidean', 'cosine', 'correlation',
'information_divergence', 'jm_distance', 'sid_sa'],
help='分类距离度量方法 (默认: all使用所有方法)')
subparser.add_argument('--distance-params', '-p', type=str, default=None,
help='距离度量方法的超参数JSON格式字符串例如: {"jm_distance": {"alpha": 0.7}}')
subparser.add_argument('--visualize', '-v', action='store_true',
help='是否生成可视化结果')
def build_config(self, args):
import json
# 解析距离参数
distance_params = None
if getattr(args, 'distance_params', None):
try:
distance_params = json.loads(args.distance_params)
except json.JSONDecodeError:
print(f"警告: 无法解析距离参数 '{args.distance_params}',使用默认参数")
distance_params = None
# 返回配置字典,包含所有必要参数
config = {
'input_file': args.input,
'xml_file': args.xml_file, # 直接访问属性,不使用 getattr
'output_dir': args.output_dir,
'method': getattr(args, 'method', 'all'),
'distance_params': distance_params,
'visualize': getattr(args, 'visualize', False)
}
return config
class PROSAILGUIArgsBuilder(ArgsBuilder):
def add_common_args(self, subparser):
"""PROSAIL GUI 不需要通用参数,因为它是交互式的"""
pass
def add_module_specific_args(self, subparser):
# PROSAIL GUI 通常不需要额外的参数,因为它是交互式的
# 但可以添加一些选项
subparser.add_argument('--no-gui', action='store_true',
help='以命令行模式运行 (不启动GUI)')
subparser.add_argument('--save-default', action='store_true',
help='保存默认参数到文件')
def build_config(self, args):
# PROSAIL GUI 不需要复杂的配置,但需要添加一些属性以保持兼容性
# 给args对象添加output_dir属性避免后续访问错误
if not hasattr(args, 'output_dir'):
args.output_dir = './results'
return args
# 注册表
REGISTRY: Dict[str, Dict[str, Any]] = {
# 降维分析
"dim-reduction": {
"info": ModuleInfo(
category="降维分析",
module_path="Dimensionality_Reduction_method.dimensionality_reduction",
callable_name="HyperspectralDimReduction",
config_class="DimensionalityReductionConfig",
description="高光谱数据降维分析"
),
"args_builder": DimensionalityReductionArgsBuilder(),
"call_sequence": ["__init__", "load_data", "apply_dim_reduction", "save_results"]
},
# 图像分割
"segmentation": {
"info": ModuleInfo(
category="图像分割",
module_path="segment_method.threshold_Segment",
callable_name="ThresholdSegmenter",
config_class="ThresholdSegmentationConfig",
description="阈值图像分割"
),
"args_builder": SegmentationArgsBuilder(),
"call_sequence": ["__init__", "load_data", "apply_segmentation", "save_results"]
},
# 边缘检测
"edge-detection": {
"info": ModuleInfo(
category="边缘检测",
module_path="edge_detect_method.edge_detect",
callable_name="EdgeDetector",
config_class="EdgeDetectionConfig",
description="图像边缘检测"
),
"args_builder": EdgeDetectionArgsBuilder(),
"call_sequence": ["__init__", "load_data", "apply_edge_detection", "save_results"]
},
# 异常检测
"anomaly-detection": {
"info": ModuleInfo(
category="异常检测",
module_path="Anomaly_method.Covariance",
callable_name="CovarianceAnomalyDetector",
config_class="CovarianceAnomalyConfig",
description="异常检测分析"
),
"args_builder": AnomalyDetectionArgsBuilder(),
"call_sequence": ["__init__", "load_data", "run_analysis_from_config", "save_results"]
},
# 分类分析
"classification": {
"info": ModuleInfo(
category="分类分析",
module_path="classfication_method.classfication",
callable_name="ConfigurableHyperspectralClassifier",
config_class="ClassificationConfig",
description="高光谱图像分类",
requires_roi_file=True
),
"args_builder": ClassificationArgsBuilder(),
"call_sequence": ["__init__", "run_classification"]
},
# 聚类分析
"clustering": {
"info": ModuleInfo(
category="聚类分析",
module_path="cluster_method.cluster",
callable_name="ClusteringManager",
config_class="ClusteringConfig",
description="无监督聚类分析"
),
"args_builder": ClusteringArgsBuilder(),
"call_sequence": ["__init__", "fit_predict", "save_clustering_results"]
},
# 监督分类
"supervised-classification": {
"info": ModuleInfo(
category="分类分析",
module_path="supervize_cluster_method.supervize_cluster",
callable_name="run_hsi_classification",
config_class=None, # 这个任务直接使用函数调用
description="高光谱图像监督分类",
requires_roi_file=False, # 使用XML文件而不是ROI文件
requires_data_file=True
),
"args_builder": SupervisedClassificationArgsBuilder(),
"call_sequence": ["call_function"] # 直接调用函数
},
# 特征选择
"feature-selection": {
"info": ModuleInfo(
category="特征选择",
module_path="Feature_Selection_method.feture_select",
callable_name="SpectrumFeatureSelector",
config_class="FeatureSelectionConfig",
description="光谱特征选择",
requires_data_file=False
),
"args_builder": FeatureSelectionArgsBuilder(),
"call_sequence": ["__init__", "load_csv_data", "select_features"]
},
# 光谱指数
"spectral-index": {
"info": ModuleInfo(
category="光谱特征",
module_path="spectral_feature_method.spectral_index",
callable_name="HyperspectralIndexCalculator",
config_class="SpectralIndexConfig",
description="光谱指数计算"
),
"args_builder": SpectralIndexArgsBuilder(),
"call_sequence": ["__init__", "load_hyperspectral_data", "batch_calculate_and_visualize", "save_results"]
},
# 数据预处理
"preprocessing": {
"info": ModuleInfo(
category="预处理",
module_path="preprocessing_method.Preprocessing",
callable_name="HyperspectralPreprocessor",
config_class="PreprocessingConfig",
description="数据预处理"
),
"args_builder": PreprocessingArgsBuilder(),
"call_sequence": ["__init__", "load_data", "preprocess", "save_data"]
},
# 形状特征
"shape-features": {
"info": ModuleInfo(
category="空间特征",
module_path="spatial_features_method.shape_feature",
callable_name="analyze_shape_features",
description="形状特征分析",
requires_data_file=False
),
"args_builder": ShapeFeatureArgsBuilder(),
"call_sequence": ["call_function"]
},
# GLCM纹理特征提取
"glcm": {
"info": ModuleInfo(
category="空间特征",
module_path="spatial_features_method.glcm",
callable_name="run_glcm_analysis",
description="GLCM纹理特征提取",
requires_data_file=True
),
"args_builder": GLCMArgsBuilder(),
"call_sequence": ["call_function"]
},
# 颜色分析
"color-analysis": {
"info": ModuleInfo(
category="颜色分析",
module_path="color_method.DeltaE",
callable_name="DeltaEApp",
description="色差计算分析"
),
"args_builder": ColorAnalysisArgsBuilder(),
"call_sequence": ["run"]
},
# 图像滤波
"filtering": {
"info": ModuleInfo(
category="滤波",
module_path="fliter_method.Smooth_filter",
callable_name="HyperspectralImageFilter",
config_class="SmoothFilterConfig",
description="图像滤波处理"
),
"args_builder": FilterArgsBuilder(),
"call_sequence": ["process_image"]
},
# 回归分析
"regression": {
"info": ModuleInfo(
category="回归分析",
module_path="rgression_method.regression",
callable_name="RegressionAnalyzer",
config_class="RegressionConfig",
description="回归模型分析",
requires_data_file=False
),
"args_builder": RegressionArgsBuilder(),
"call_sequence": ["__init__", "run_analysis_from_config"]
},
# 回归预测
"regression-prediction": {
"info": ModuleInfo(
category="回归预测",
module_path="rgression_method.regression_predict",
callable_name="RegressionPredictor",
config_class="PredictionConfig",
description="使用训练好的回归模型进行高光谱图像预测",
requires_data_file=True
),
"args_builder": RegressionPredictionArgsBuilder(),
"call_sequence": ["__init__", "run_prediction"]
},
# 色差计算
"delta-e": {
"info": ModuleInfo(
category="颜色分析",
module_path="color_method.DeltaE",
callable_name="DeltaEApp",
description="色差计算分析 (CIE76, CIE94, CIEDE2000)",
requires_data_file=True
),
"args_builder": DeltaEArgsBuilder(),
"call_sequence": ["run"]
},
# 光谱到颜色转换
"spectral-to-color": {
"info": ModuleInfo(
category="颜色分析",
module_path="color_method.spectral2cie2",
callable_name="SpectralToColorConverter",
description="光谱数据转换为CIE颜色空间 (XYZ, xyY, Lab, LCH)",
requires_data_file=True
),
"args_builder": SpectralToColorArgsBuilder(),
"call_sequence": ["process_file"]
},
# XYZ到RGB转换
"xyz-to-rgb": {
"info": ModuleInfo(
category="颜色分析",
module_path="color_method.XYZ2RGB",
callable_name="XYZ2RGBApp",
description="XYZ颜色空间转换为RGB",
requires_data_file=True
),
"args_builder": XYZ2RGBArgsBuilder(),
"call_sequence": ["run"]
},
# PROSAIL模拟器GUI
"prosail-gui": {
"info": ModuleInfo(
category="植被模拟",
module_path="prosail_method.prosail_gui",
callable_name="PROSAILSimulator",
description="PROSAIL植被光谱模拟器GUI",
requires_data_file=False
),
"args_builder": PROSAILGUIArgsBuilder(),
"call_sequence": ["run_gui"]
}
}
def get_module_info(task_name: str) -> Optional[ModuleInfo]:
"""获取模块信息"""
task_info = REGISTRY.get(task_name)
return task_info["info"] if task_info else None
def get_args_builder(task_name: str) -> Optional[ArgsBuilder]:
"""获取参数构建器"""
task_info = REGISTRY.get(task_name)
return task_info["args_builder"] if task_info else None
def get_call_sequence(task_name: str) -> Optional[List[str]]:
"""获取调用序列"""
task_info = REGISTRY.get(task_name)
return task_info.get("call_sequence")
def create_module_parser(subparsers, task_name: str):
"""为指定任务创建子解析器"""
task_info = REGISTRY.get(task_name)
if not task_info:
return None
info = task_info["info"]
args_builder = task_info["args_builder"]
# 创建子解析器
subparser = subparsers.add_parser(
task_name,
help=f"{info.category}: {info.description}",
description=info.description
)
# 添加通用参数
args_builder.add_common_args(subparser)
# 添加模块特定参数
args_builder.add_module_specific_args(subparser)
return subparser
def dynamic_import_module(module_path: str):
"""动态导入模块"""
try:
return importlib.import_module(module_path)
except ImportError as e:
raise ImportError(f"无法导入模块 {module_path}: {e}")
def get_callable_from_module(module, callable_name: str):
"""从模块获取可调用对象"""
try:
return getattr(module, callable_name)
except AttributeError as e:
raise AttributeError(f"模块中没有找到可调用对象 {callable_name}: {e}")
def execute_task(task_name: str, args):
"""执行任务"""
task_info = REGISTRY.get(task_name)
if not task_info:
raise ValueError(f"未知任务: {task_name}")
info = task_info["info"]
args_builder = task_info["args_builder"]
call_sequence = task_info["call_sequence"]
# 构建配置
config = args_builder.build_config(args)
# 构建输出文件名
if hasattr(args, 'output_file') and args.output_file:
output_filename = args.output_file
else:
# 基于输出目录和前缀构建文件名
import os
output_prefix = getattr(args, 'output_prefix', 'result')
output_filename = os.path.join(args.output_dir, f"{output_prefix}.dat")
# 动态导入模块和可调用对象
if task_name == "anomaly-detection":
# 异常检测需要根据方法动态选择类
if args.method == 'covariance':
module = dynamic_import_module("Anomaly_method.Covariance")
callable_obj = get_callable_from_module(module, "CovarianceAnomalyDetector")
elif args.method == 'one-class-svm':
module = dynamic_import_module("Anomaly_method.One_Class_SVM")
callable_obj = get_callable_from_module(module, "OneClassSVMAnomalyDetector")
elif args.method == 'rx':
module = dynamic_import_module("Anomaly_method.RX")
callable_obj = get_callable_from_module(module, "RXAnomalyDetector")
elif args.method == 'squared-loss-probability':
module = dynamic_import_module("Anomaly_method.squared_loss_probability")
callable_obj = get_callable_from_module(module, "SquaredLossAnomalyDetector")
else:
raise ValueError(f"不支持的异常检测方法: {args.method}")
elif task_name == "clustering":
# 聚类任务直接调用run_clustering函数
from cluster_method.cluster import run_clustering
# 构建参数
method_params = getattr(config, 'method_params', {}) if hasattr(config, 'method_params') else {}
# 调用聚类函数
result = run_clustering(
input_file=config.input_path,
n_clusters=config.n_clusters,
methods=[config.method],
output_dir=config.output_dir,
method_params=method_params,
random_state=42
)
return result
elif task_name == "delta-e":
# 色差计算任务
module = dynamic_import_module("color_method.DeltaE")
callable_obj = get_callable_from_module(module, "DeltaEApp")
app_instance = callable_obj()
result = app_instance.run(config) # 使用build_config返回的对象
return result
elif task_name == "spectral-to-color":
# 光谱到颜色转换任务 - 直接调用main函数
import sys
from color_method.spectral2cie2 import main as spectral_main
# 保存原始sys.argv
original_argv = sys.argv
# 构造命令行参数
sys.argv = ['spectral2cie2.py', '--input', config.input, '-o', config.output,
'-c', getattr(config, 'color_space', 'Lab'),
'-i', getattr(config, 'illuminant', 'D65'),
'-b', getattr(config, 'observer', '')]
if hasattr(config, 'wavelength_col') and config.wavelength_col:
sys.argv.extend(['-w', config.wavelength_col])
if getattr(config, 'no_normalize_xyz', False):
sys.argv.append('--no_normalize_xyz')
if getattr(config, 'use_numpy', False):
sys.argv.append('--use_numpy')
try:
result = spectral_main()
finally:
# 恢复原始sys.argv
sys.argv = original_argv
return result
elif task_name == "xyz-to-rgb":
# XYZ到RGB转换任务
module = dynamic_import_module("color_method.XYZ2RGB")
callable_obj = get_callable_from_module(module, "XYZ2RGBApp")
app_instance = callable_obj()
result = app_instance.run(config) # 使用build_config返回的对象
return result
elif task_name == "preprocessing":
# 预处理任务 - 需要特殊处理load_data调用
from preprocessing_method.Preprocessing import HyperspectralPreprocessor
instance = HyperspectralPreprocessor(config)
# 确定输出文件路径(基于输入文件类型和方法名称)
import os
from pathlib import Path
input_path = Path(config.input_path)
method_name = config.method.lower() # 方法名称转为小写用于文件名
if input_path.suffix.lower() == '.csv':
# CSV输入输出CSV
output_path = os.path.join(config.output_dir, f"{method_name}_{input_path.stem}.csv")
else:
# ENVI输入或其他格式输出ENVI
output_path = os.path.join(config.output_dir, f"{method_name}_{input_path.stem}.dat")
# 调用load_data时传递正确的参数
instance.load_data(config.input_path, config.spectral_start_index)
# 执行预处理
result = instance.preprocess(config.method, output_path)
return result
elif task_name == "filtering":
# 滤波任务 - 根据滤波类型选择不同的处理方式
if hasattr(config, 'filter_type'):
# Smooth_filter类型
from fliter_method.Smooth_filter import HyperspectralImageFilter
filter_obj = HyperspectralImageFilter()
# 构建参数字典
kwargs = {
'kernel_size': config.kernel_size,
}
if config.filter_type == 'gaussian':
kwargs['sigma'] = config.sigma
elif config.filter_type == 'bilateral':
kwargs['sigma_color'] = config.sigma_color
kwargs['sigma_space'] = config.sigma_space
# 生成输出路径
output_path = os.path.join(config.output_dir, f"{config.filter_type}_filtered")
result = filter_obj.process_image(
input_path=config.input_path,
output_path=output_path,
filter_type=config.filter_type,
band_index=config.band_index,
**kwargs
)
return result
elif task_name == "prosail-gui":
# PROSAIL GUI 任务 - 启动GUI应用
from prosail_method.prosail_gui import PROSAILSimulator, QApplication
import sys
if getattr(config, 'no_gui', False):
# 命令行模式 - 这里可以添加命令行版本的PROSAIL功能
print("PROSAIL GUI 命令行模式暂未实现,请使用 GUI 模式:")
print("python main.py prosail-gui")
return None
else:
# GUI模式
app = QApplication(sys.argv)
app.setStyle('Fusion') # Modern style
window = PROSAILSimulator()
window.show()
result = app.exec_()
return result
elif task_name == "morphological_fliter":
# Morphological_filter类型
from fliter_method.morphological_fliter import HyperspectralMorphologyProcessor
processor = HyperspectralMorphologyProcessor()
# 加载图像
data, header = processor.load_hyperspectral_image(config.input_path, config.format_type)
# 提取波段
band_data = processor.extract_band(config.band_index)
# 应用形态学操作
result = processor.apply_morphology_operation(
band_data, config.operation, config.se_shape, config.se_size
)
# 保存结果
output_path = os.path.join(config.output_dir, f"{config.operation}_result")
processor.save_as_envi(result, output_path, f"{config.operation.capitalize()} 处理结果 - 波段 {config.band_index}")
return result
else:
module = dynamic_import_module(info.module_path)
callable_obj = get_callable_from_module(module, info.callable_name)
# 检查是否是批量处理(针对降维分析)
is_batch = (task_name == "dim-reduction" and
hasattr(args, 'batch_methods') and args.batch_methods is not None)
if is_batch:
# 批量处理
instance = callable_obj(config)
instance.load_data()
result = instance.batch_process(args.batch_methods, args.batch_components, args.output_dir)
return result
# 执行调用序列
result = None
instance = None
for step in call_sequence:
if step == "__init__":
instance = callable_obj(config)
elif step == "call_function":
# 直接调用函数传递config对象作为参数
result = callable_obj(config)
elif hasattr(instance, step):
method = getattr(instance, step)
if step in ["load_data", "read_hyperspectral_data", "load_hyperspectral_data", "load_csv_data"]:
method()
elif step in ["apply_dim_reduction", "apply_segmentation", "apply_edge_detection", "detect_anomalies", "select_features", "batch_calculate_and_visualize", "preprocess"]:
result = method()
elif step == "run":
# 对于某些应用类直接调用run方法并传递参数
result = method(config)
elif step == "save_results":
# 根据任务类型正确调用 save_results
if task_name == "dim-reduction":
# result = (reduced_data, band_names)
method(result[0], result[1], output_filename, args.method)
elif task_name == "segmentation":
# result = (segmented_data, threshold)
method(result[0], result[1], output_filename, args.method)
elif task_name == "edge-detection":
# result = (edge_data, info)
method(result[0], result[1], output_filename, args.method)
elif task_name == "spectral-index":
# result = (results, fig), save_results 期望 (results, output_prefix)
method(result[0], args.output_prefix)
elif task_name == "anomaly-detection":
# save_results 只接受可选的 output_path 参数
method(output_filename)
else:
# 默认调用方式
method(result, output_filename, task_name)
elif step == "run_analysis_from_config":
result = method()
elif step == "calculate_differences":
result = method()
elif step == "apply_filter":
result = method(args.filter_type if hasattr(args, 'filter_type') else 'mean')
elif step == "save_envi":
method(output_filename, result)
else:
# 其他方法直接调用
if step == "parse_roi_xml" and hasattr(args, 'roi_file'):
method(args.roi_file)
elif step == "extract_training_samples":
result = method()
elif step == "train_model":
method(result[0], result[1])
elif step == "predict_image":
method()
elif step == "run_classification":
method()
else:
method()
return result