""" 高光谱分析工具包统一注册表 提供所有功能模块的注册信息和参数构建器 """ import importlib from typing import Dict, Any, Callable, List, Optional, Union from dataclasses import dataclass import argparse @dataclass class ModuleInfo: """模块信息类""" category: str module_path: str callable_name: str config_class: Optional[str] = None description: str = "" requires_data_file: bool = True requires_roi_file: bool = False output_types: List[str] = None def __post_init__(self): if self.output_types is None: self.output_types = ["file", "summary"] class ArgsBuilder: """参数构建器基类""" def add_common_args(self, subparser: argparse.ArgumentParser): """添加通用参数""" subparser.add_argument('--input', '-i', required=True, help='输入文件路径 (ENVI格式: .hdr/.dat 或 CSV)') subparser.add_argument('--output-dir', '-o', default='./results', help='输出目录路径 (默认: ./results)') subparser.add_argument('--output-prefix', default='result', help='输出文件前缀 (默认: result)') subparser.add_argument('--output-file', '-f', help='直接指定输出文件名 (优先级高于 --output-dir 和 --output-prefix)') def add_module_specific_args(self, subparser: argparse.ArgumentParser): """子类实现:添加模块特定参数""" pass def build_config(self, args) -> Any: """子类实现:构建配置对象""" pass # 各模块的参数构建器 class DimensionalityReductionArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): # CSV文件参数 subparser.add_argument('--label-col', '-l', help='CSV文件的标签列名或索引') subparser.add_argument('--spectral-start', '-s', help='CSV文件的谱段起始列名或索引') subparser.add_argument('--spectral-end', '-e', help='CSV文件的谱段结束列名或索引') subparser.add_argument('--method', '-m', default='pca', choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'], help='降维方法') subparser.add_argument('--n-components', '-n', type=int, default=3, help='降维后的维度数') subparser.add_argument('--batch-methods', nargs='+', choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'], help='批量处理的降维方法列表') subparser.add_argument('--batch-components', nargs='+', type=int, help='对应批量方法的维度数列表') # 数据预处理参数 subparser.add_argument('--use-standardization', action='store_true', default=True, help='是否使用数据标准化') subparser.add_argument('--generate-plots', action='store_true', default=True, help='是否生成可视化图表') subparser.add_argument('--color-by-label', action='store_true', default=True, help='是否按标签着色散点图') # PCA专用参数 subparser.add_argument('--pca-whiten', action='store_true', help='是否白化数据 (PCA方法)') subparser.add_argument('--pca-svd-solver', choices=['auto', 'full', 'arpack', 'randomized'], default='auto', help='SVD求解器 (PCA方法)') # ICA专用参数 subparser.add_argument('--ica-algorithm', choices=['parallel', 'deflation'], default='parallel', help='ICA算法 (ICA方法)') subparser.add_argument('--ica-fun', choices=['logcosh', 'exp', 'cube'], default='logcosh', help='非线性函数 (ICA方法)') subparser.add_argument('--ica-max-iter', type=int, default=200, help='最大迭代次数 (ICA方法)') # FA专用参数 subparser.add_argument('--fa-max-iter', type=int, default=1000, help='最大迭代次数 (FA方法)') subparser.add_argument('--fa-tol', type=float, default=1e-2, help='收敛容差 (FA方法)') # MDS专用参数 subparser.add_argument('--mds-metric', action='store_true', default=True, help='是否使用度量MDS (MDS方法)') subparser.add_argument('--mds-max-iter', type=int, default=300, help='最大迭代次数 (MDS方法)') subparser.add_argument('--mds-n-init', type=int, default=4, help='初始化次数 (MDS方法)') # Isomap专用参数 subparser.add_argument('--isomap-n-neighbors', type=int, default=5, help='邻居数量 (Isomap方法)') subparser.add_argument('--isomap-eigen-solver', choices=['auto', 'arpack', 'dense'], default='auto', help='特征值求解器 (Isomap方法)') # LLE专用参数 subparser.add_argument('--lle-n-neighbors', type=int, default=5, help='邻居数量 (LLE方法)') subparser.add_argument('--lle-method', choices=['standard', 'hessian', 'modified', 'ltsa'], default='standard', help='LLE方法类型 (LLE方法)') # t-SNE专用参数 subparser.add_argument('--tsne-perplexity', type=float, default=30.0, help='困惑度 (t-SNE方法)') subparser.add_argument('--tsne-early-exaggeration', type=float, default=12.0, help='早期夸张系数 (t-SNE方法)') subparser.add_argument('--tsne-learning-rate', type=float, default=200.0, help='学习率 (t-SNE方法)') subparser.add_argument('--tsne-max-iter', type=int, default=1000, help='最大迭代次数 (t-SNE方法)') def build_config(self, args): from Dimensionality_Reduction_method.dimensionality_reduction import DimensionalityReductionConfig # 检查是否是批量处理 is_batch = hasattr(args, 'batch_methods') and args.batch_methods is not None # 构建方法参数 method_params = {} if not is_batch and hasattr(args, 'method'): if args.method == 'pca': if hasattr(args, 'pca_whiten'): method_params['whiten'] = args.pca_whiten if hasattr(args, 'pca_svd_solver'): method_params['svd_solver'] = args.pca_svd_solver elif args.method == 'ica': if hasattr(args, 'ica_algorithm'): method_params['algorithm'] = args.ica_algorithm if hasattr(args, 'ica_fun'): method_params['fun'] = args.ica_fun if hasattr(args, 'ica_max_iter'): method_params['max_iter'] = args.ica_max_iter elif args.method == 'fa': if hasattr(args, 'fa_max_iter'): method_params['max_iter'] = args.fa_max_iter if hasattr(args, 'fa_tol'): method_params['tol'] = args.fa_tol elif args.method == 'mds': if hasattr(args, 'mds_metric'): method_params['metric'] = args.mds_metric if hasattr(args, 'mds_max_iter'): method_params['max_iter'] = args.mds_max_iter if hasattr(args, 'mds_n_init'): method_params['n_init'] = args.mds_n_init elif args.method == 'isomap': if hasattr(args, 'isomap_n_neighbors'): method_params['n_neighbors'] = args.isomap_n_neighbors if hasattr(args, 'isomap_eigen_solver'): method_params['eigen_solver'] = args.isomap_eigen_solver elif args.method == 'lle': if hasattr(args, 'lle_n_neighbors'): method_params['n_neighbors'] = args.lle_n_neighbors if hasattr(args, 'lle_method'): method_params['method'] = args.lle_method elif args.method == 't-sne': if hasattr(args, 'tsne_perplexity'): method_params['perplexity'] = args.tsne_perplexity if hasattr(args, 'tsne_early_exaggeration'): method_params['early_exaggeration'] = args.tsne_early_exaggeration if hasattr(args, 'tsne_learning_rate'): method_params['learning_rate'] = args.tsne_learning_rate if hasattr(args, 'tsne_max_iter'): method_params['max_iter'] = args.tsne_max_iter # 构建配置参数 config_kwargs = { 'input_path': args.input, 'output_dir': args.output_dir, 'label_col': getattr(args, 'label_col', None), 'spectral_start': getattr(args, 'spectral_start', None), 'spectral_end': getattr(args, 'spectral_end', None), 'method_params': method_params, 'use_standardization': getattr(args, 'use_standardization', True), 'generate_plots': getattr(args, 'generate_plots', True), 'color_by_label': getattr(args, 'color_by_label', True), 'save_plots': getattr(args, 'save_plots', True) } if is_batch: # 批量处理 if not hasattr(args, 'batch_components') or not args.batch_components: batch_components = [3] * len(args.batch_methods) else: batch_components = args.batch_components if len(batch_components) != len(args.batch_methods): raise ValueError(f"批量方法数量 ({len(args.batch_methods)}) 与维度数量 ({len(batch_components)}) 不匹配") config_kwargs.update({ 'batch_methods': args.batch_methods, 'batch_components': batch_components }) else: # 单方法处理 config_kwargs.update({ 'method': args.method, 'n_components': args.n_components }) return DimensionalityReductionConfig(**config_kwargs) class SegmentationArgsBuilder(ArgsBuilder): def add_common_args(self, subparser): """重写通用参数,设置分割专用的默认输出目录""" subparser.add_argument('--input', '-i', required=True, help='输入文件路径 (ENVI格式: .hdr/.dat 或图像文件)') subparser.add_argument('--output-dir', '-o', default='./segmentation_results', help='输出目录路径 (默认: ./segmentation_results)') subparser.add_argument('--output-prefix', default='result', help='输出文件前缀 (默认: result)') def add_module_specific_args(self, subparser): # 基本分割参数 subparser.add_argument('--method', '-m', default='otsu', choices=['fixed', 'bimodal', 'iterative', 'otsu', 'isodata', 'adaptive'], help='分割方法 (默认: otsu)') subparser.add_argument('--band-index', '-b', type=int, default=0, help='使用的波段索引 (默认: 0)') # 阈值参数 subparser.add_argument('--threshold', '-t', type=float, help='固定阈值 (仅用于fixed方法)') # 批量处理参数 subparser.add_argument('--batch-methods', nargs='+', choices=['fixed', 'bimodal', 'iterative', 'otsu', 'isodata', 'adaptive'], help='批量处理的分割方法列表') subparser.add_argument('--batch-bands', type=int, nargs='+', help='批量处理的波段索引列表') # 自适应阈值参数 subparser.add_argument('--adaptive-block-size', type=int, default=11, help='自适应阈值块大小,必须为奇数 (默认: 11)') subparser.add_argument('--adaptive-c', type=float, default=2.0, help='自适应阈值常数 (默认: 2.0)') subparser.add_argument('--adaptive-method', choices=['gaussian', 'mean'], default='gaussian', help='自适应阈值方法 (默认: gaussian)') # 迭代法参数 subparser.add_argument('--max-iterations', type=int, default=100, help='迭代法最大迭代次数 (默认: 100)') subparser.add_argument('--convergence-threshold', type=float, default=0.01, help='迭代法收敛阈值 (默认: 0.01)') # 直方图双峰法参数 subparser.add_argument('--histogram-bins', type=int, default=256, help='直方图bin数量 (默认: 256)') subparser.add_argument('--peak-min-distance', type=int, default=10, help='峰值最小距离 (默认: 10)') # Otsu参数 subparser.add_argument('--otsu-normalize', action='store_true', default=True, help='Otsu方法是否归一化图像到0-255范围 (默认: True)') def build_config(self, args): from segment_method.threshold_Segment import ThresholdSegmentationConfig # 处理批量参数 batch_methods = getattr(args, 'batch_methods', None) batch_bands = getattr(args, 'batch_bands', None) if batch_methods is None: batch_methods = [args.method] if batch_bands is None: batch_bands = [args.band_index] return ThresholdSegmentationConfig( input_path=args.input, output_dir=args.output_dir, method=args.method, band_index=args.band_index, threshold_value=getattr(args, 'threshold', None), batch_methods=batch_methods, batch_bands=batch_bands, adaptive_block_size=getattr(args, 'adaptive_block_size', 11), adaptive_c=getattr(args, 'adaptive_c', 2.0), adaptive_method=getattr(args, 'adaptive_method', 'gaussian'), max_iterations=getattr(args, 'max_iterations', 100), convergence_threshold=getattr(args, 'convergence_threshold', 0.01), histogram_bins=getattr(args, 'histogram_bins', 256), peak_min_distance=getattr(args, 'peak_min_distance', 10), otsu_normalize=getattr(args, 'otsu_normalize', True) ) class EdgeDetectionArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): # 基本参数 subparser.add_argument('--method', '-m', default='canny', choices=['sobel', 'scharr', 'laplacian', 'log', 'canny'], help='边缘检测方法') subparser.add_argument('--edge-band-index', type=int, default=0, help='使用的波段索引') subparser.add_argument('--batch', action='store_true', help='批量运行多种方法') subparser.add_argument('--batch-bands', nargs='+', type=int, help='批量处理的波段索引列表') # Sobel和Scharr参数 subparser.add_argument('--sobel-dx', type=int, default=1, choices=[0, 1], help='Sobel x方向导数阶数,必须与sobel-dy之和等于1 (默认: 1)') subparser.add_argument('--sobel-dy', type=int, default=0, choices=[0, 1], help='Sobel y方向导数阶数,必须与sobel-dx之和等于1 (默认: 0)') subparser.add_argument('--sobel-ksize', type=int, default=-1, help='Sobel核大小,-1表示3x3 (默认: -1)') # Laplacian参数 subparser.add_argument('--laplacian-ksize', type=int, default=1, help='Laplacian核大小 (默认: 1)') subparser.add_argument('--laplacian-scale', type=float, default=1.0, help='Laplacian尺度参数 (默认: 1.0)') subparser.add_argument('--laplacian-delta', type=float, default=0.0, help='Laplacian delta参数 (默认: 0.0)') # LoG参数 subparser.add_argument('--log-sigma', type=float, default=1.0, help='LoG高斯标准差 (默认: 1.0)') subparser.add_argument('--log-threshold', type=float, default=0.0, help='LoG阈值,0表示自动阈值 (默认: 0.0)') subparser.add_argument('--log-mode', default='reflect', choices=['reflect', 'constant', 'nearest', 'mirror', 'wrap'], help='LoG边界模式 (默认: reflect)') # 梯度算法阈值参数 subparser.add_argument('--gradient-threshold', type=float, default=0.0, help='梯度阈值 (0-255),0表示不使用阈值 (默认: 0.0)') subparser.add_argument('--gradient-use-otsu', action='store_true', help='对梯度图使用Otsu自动阈值') # Canny参数 subparser.add_argument('--canny-min-threshold', type=int, default=100, help='Canny最小阈值 (默认: 100)') subparser.add_argument('--canny-max-threshold', type=int, default=200, help='Canny最大阈值 (默认: 200)') subparser.add_argument('--canny-aperture-size', type=int, default=3, choices=[3, 5, 7], help='Canny Sobel算子孔径大小 (默认: 3)') subparser.add_argument('--canny-l2gradient', action='store_true', help='Canny使用L2范数计算梯度') # 通用参数 subparser.add_argument('--normalize-output', action='store_true', help='将输出归一化到0-255范围') subparser.add_argument('--use-blur', action='store_true', help='预先进行高斯模糊') subparser.add_argument('--blur-kernel-size', type=int, default=3, help='模糊核大小,必须是奇数 (默认: 3)') subparser.add_argument('--blur-sigma', type=float, default=1.0, help='模糊标准差 (默认: 1.0)') def build_config(self, args): from edge_detect_method.edge_detect import EdgeDetectionConfig # 确定输出路径 if hasattr(args, 'output_file') and args.output_file: output_path = args.output_file output_dir = None else: output_path = None output_dir = args.output_dir # 构建批量处理参数 batch_methods = [args.method] batch_bands = [args.edge_band_index] if hasattr(args, 'batch') and args.batch: # 如果是批量模式,使用默认的多种方法 batch_methods = ['sobel', 'scharr', 'laplacian', 'log', 'canny'] if hasattr(args, 'batch_bands') and args.batch_bands: if len(args.batch_bands) == len(batch_methods): batch_bands = args.batch_bands else: batch_bands = args.batch_bands * (len(batch_methods) // len(args.batch_bands) + 1) batch_bands = batch_bands[:len(batch_methods)] else: batch_bands = [args.edge_band_index] * len(batch_methods) return EdgeDetectionConfig( input_path=args.input, method=args.method, band_index=args.edge_band_index, output_path=output_path, output_dir=output_dir, batch_methods=batch_methods, batch_bands=batch_bands, # Sobel参数 sobel_dx=getattr(args, 'sobel_dx', 1), sobel_dy=getattr(args, 'sobel_dy', 1), sobel_ksize=getattr(args, 'sobel_ksize', -1), # Laplacian参数 laplacian_ksize=getattr(args, 'laplacian_ksize', 1), laplacian_scale=getattr(args, 'laplacian_scale', 1.0), laplacian_delta=getattr(args, 'laplacian_delta', 0.0), # LoG参数 log_sigma=getattr(args, 'log_sigma', 1.0), log_threshold=getattr(args, 'log_threshold', 0.0), log_mode=getattr(args, 'log_mode', 'reflect'), # 梯度参数 gradient_threshold=getattr(args, 'gradient_threshold', 0.0), gradient_use_otsu=getattr(args, 'gradient_use_otsu', False), # Canny参数 canny_min_threshold=getattr(args, 'canny_min_threshold', 100), canny_max_threshold=getattr(args, 'canny_max_threshold', 200), canny_aperture_size=getattr(args, 'canny_aperture_size', 3), canny_l2gradient=getattr(args, 'canny_l2gradient', False), # 通用参数 normalize_output=getattr(args, 'normalize_output', False), use_blur=getattr(args, 'use_blur', False), blur_kernel_size=getattr(args, 'blur_kernel_size', 3), blur_sigma=getattr(args, 'blur_sigma', 1.0) ) class AnomalyDetectionArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): subparser.add_argument('--method', '-m', default='covariance', choices=['covariance', 'one-class-svm', 'rx', 'squared-loss-probability'], help='异常检测方法') # 通用参数(所有方法都可能用到) subparser.add_argument('--contamination', '-c', type=float, default=0.1, help='异常样本比例 (用于自动确定阈值)') # RX方法专用参数 subparser.add_argument('--background', help='背景数据路径 (RX方法custom模式时必需)') subparser.add_argument('--label-col', help='标签列名 (CSV文件)') subparser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)') subparser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)') subparser.add_argument('--background-model', choices=['global', 'local_row', 'local_square', 'custom'], default='global', help='背景模型类型 (RX方法,默认global)') subparser.add_argument('--window-size', type=int, default=5, help='局部窗口大小 (RX方法,默认5)') subparser.add_argument('--row-window', type=int, default=3, help='行窗口大小 (RX方法,默认3)') subparser.add_argument('--rx-threshold', type=float, help='异常检测阈值 (RX方法,不指定则只输出距离图)') subparser.add_argument('--robust-covariance', action='store_true', help='使用稳健协方差估计 (RX方法)') # One-Class SVM专用参数 subparser.add_argument('--ocsvm-kernel', choices=['linear', 'rbf', 'poly'], default='rbf', help='SVM核函数 (One-Class SVM方法,默认rbf)') subparser.add_argument('--ocsvm-nu', type=float, default=0.1, help='训练误差比例 (One-Class SVM方法,默认0.1)') subparser.add_argument('--ocsvm-gamma', type=float, help='核函数参数gamma (One-Class SVM方法,不指定则自动选择)') subparser.add_argument('--ocsvm-degree', type=int, default=3, help='多项式核的阶数 (One-Class SVM方法,默认3)') subparser.add_argument('--ocsvm-coef0', type=float, default=0.0, help='核函数常数项 (One-Class SVM方法,默认0.0)') subparser.add_argument('--ocsvm-tol', type=float, default=1e-3, help='停止准则公差 (One-Class SVM方法,默认1e-3)') subparser.add_argument('--ocsvm-shrinking', action='store_true', default=True, help='是否使用shrinking heuristic (One-Class SVM方法,默认True)') subparser.add_argument('--ocsvm-cache-size', type=float, default=200, help='核缓存大小(MB) (One-Class SVM方法,默认200)') subparser.add_argument('--ocsvm-max-iter', type=int, default=-1, help='最大迭代次数 (One-Class SVM方法,默认-1表示无限制)') subparser.add_argument('--ocsvm-use-scaling', action='store_true', default=True, help='是否使用标准化 (One-Class SVM方法,默认True)') subparser.add_argument('--ocsvm-use-pca', action='store_true', default=True, help='是否使用PCA降维 (One-Class SVM方法,默认True)') subparser.add_argument('--ocsvm-n-components', type=int, help='PCA降维维度 (One-Class SVM方法,默认自动选择)') subparser.add_argument('--ocsvm-use-grid-search', action='store_true', help='是否使用网格搜索调优参数 (One-Class SVM方法,默认False)') subparser.add_argument('--ocsvm-cv-folds', type=int, default=3, help='交叉验证折数 (One-Class SVM方法,默认3)') # Squared Loss Probability专用参数 subparser.add_argument('--reconstruction-method', choices=['linear', 'pca'], default='linear', help='重构方法 (Squared Loss方法,默认linear)') subparser.add_argument('--slp-n-components', type=int, help='PCA降维维度 (Squared Loss方法使用PCA时必需)') subparser.add_argument('--target-band', type=int, help='目标波段索引 (Squared Loss方法,默认使用最后一个波段)') subparser.add_argument('--slp-threshold', type=float, default=0.8, help='分类阈值 (Squared Loss方法,默认0.8)') subparser.add_argument('--probability-mode', action='store_true', default=True, help='使用概率模式 (Squared Loss方法,默认True)') subparser.add_argument('--slp-use-scaling', action='store_true', default=True, help='是否使用标准化 (Squared Loss方法,默认True)') subparser.add_argument('--scaler-type', choices=['standard', 'robust'], default='robust', help='标准化类型 (Squared Loss方法,默认robust)') subparser.add_argument('--fit-intercept', action='store_true', default=True, help='是否拟合截距 (Squared Loss方法,默认True)') def build_config(self, args): if args.method == 'covariance': from Anomaly_method.Covariance import CovarianceAnomalyConfig config_kwargs = { 'input_path': args.input, 'output_dir': args.output_dir, 'contamination': getattr(args, 'contamination', 0.1) } return CovarianceAnomalyConfig(**config_kwargs) elif args.method == 'one-class-svm': from Anomaly_method.One_Class_SVM import OneClassSVMConfig # 获取默认配置的值(不创建完整的配置对象以避免验证) config_kwargs = { 'input_path': args.input, 'output_dir': args.output_dir, 'kernel': getattr(args, 'ocsvm_kernel', 'rbf'), 'nu': getattr(args, 'ocsvm_nu', 0.1), 'gamma': getattr(args, 'ocsvm_gamma', None), 'degree': getattr(args, 'ocsvm_degree', 3), 'coef0': getattr(args, 'ocsvm_coef0', 0.0), 'tol': getattr(args, 'ocsvm_tol', 1e-3), 'shrinking': getattr(args, 'ocsvm_shrinking', True), 'cache_size': getattr(args, 'ocsvm_cache_size', 200.0), 'max_iter': getattr(args, 'ocsvm_max_iter', -1), 'use_scaling': getattr(args, 'ocsvm_use_scaling', True), 'use_pca': getattr(args, 'ocsvm_use_pca', True), 'n_components': getattr(args, 'ocsvm_n_components', None), 'use_grid_search': getattr(args, 'ocsvm_use_grid_search', False), 'cv_folds': getattr(args, 'ocsvm_cv_folds', 3) } return OneClassSVMConfig(**config_kwargs) elif args.method == 'rx': from Anomaly_method.RX import RXConfig config_kwargs = { 'input_path': args.input, 'background_path': getattr(args, 'background', None), 'label_col': getattr(args, 'label_col', None), 'spectral_start': getattr(args, 'spectral_start', None), 'spectral_end': getattr(args, 'spectral_end', None), 'output_dir': args.output_dir, 'background_model': getattr(args, 'background_model', 'global'), 'window_size': getattr(args, 'window_size', 5), 'row_window': getattr(args, 'row_window', 3), 'threshold': getattr(args, 'rx_threshold', None), 'contamination': getattr(args, 'contamination', 0.1), 'use_robust_covariance': getattr(args, 'robust_covariance', False) } return RXConfig(**config_kwargs) elif args.method == 'squared-loss-probability': from Anomaly_method.squared_loss_probability import SquaredLossAnomalyConfig config_kwargs = { 'input_path': args.input, 'output_dir': args.output_dir, 'reconstruction_method': getattr(args, 'reconstruction_method', 'linear'), 'n_components': getattr(args, 'slp_n_components', None), 'target_band': getattr(args, 'target_band', None), 'probability_flag': getattr(args, 'probability_mode', True), 'threshold': getattr(args, 'slp_threshold', 0.8), 'use_scaling': getattr(args, 'slp_use_scaling', True), 'scaler_type': getattr(args, 'scaler_type', 'robust'), 'fit_intercept': getattr(args, 'fit_intercept', True) } return SquaredLossAnomalyConfig(**config_kwargs) else: raise ValueError(f"不支持的异常检测方法: {args.method}") class ClassificationArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): subparser.add_argument('--method', '-m', default='svm', choices=['svm', 'random_forest', 'knn', 'logistic_regression', 'linear_discriminant', 'quadratic_discriminant', 'plsda', 'xgboost', 'lightgbm', 'catboost'], help='分类方法') subparser.add_argument('--roi-file', '-r', required=True, help='ROI标注文件路径 (.xml)') subparser.add_argument('--test-size', '-ts', type=float, default=0.3, help='测试集比例') # 数据预处理参数 subparser.add_argument('--use-standardization', action='store_true', default=True, help='是否使用标准化') subparser.add_argument('--batch-size', type=int, default=10000, help='批处理大小') # SVM专用参数 subparser.add_argument('--svm-c', type=float, default=1.0, help='SVM正则化参数C (SVM方法)') subparser.add_argument('--svm-kernel', choices=['linear', 'rbf', 'poly', 'sigmoid'], default='rbf', help='SVM核函数 (SVM方法)') subparser.add_argument('--svm-gamma', help='SVM核函数参数gamma: "scale", "auto" 或数值 (SVM方法,默认scale)') # 随机森林专用参数 subparser.add_argument('--rf-n-estimators', type=int, default=100, help='树的数量 (随机森林方法)') subparser.add_argument('--rf-max-depth', type=int, help='树的最大深度 (随机森林方法)') subparser.add_argument('--rf-min-samples-split', type=int, default=2, help='内部节点再划分所需最小样本数 (随机森林方法)') # KNN专用参数 subparser.add_argument('--knn-n-neighbors', type=int, default=5, help='邻居数量 (KNN方法)') subparser.add_argument('--knn-weights', choices=['uniform', 'distance'], default='uniform', help='权重函数 (KNN方法)') subparser.add_argument('--knn-algorithm', choices=['auto', 'ball_tree', 'kd_tree', 'brute'], default='auto', help='算法 (KNN方法)') # 逻辑回归专用参数 subparser.add_argument('--lr-c', type=float, default=1.0, help='正则化强度 (逻辑回归方法)') subparser.add_argument('--lr-max-iter', type=int, default=100, help='最大迭代次数 (逻辑回归方法)') # PLS-DA专用参数 subparser.add_argument('--pls-n-components', type=int, default=10, help='PLS成分数量 (PLS-DA方法)') # XGBoost专用参数 subparser.add_argument('--xgb-n-estimators', type=int, default=100, help='树的数量 (XGBoost方法)') subparser.add_argument('--xgb-max-depth', type=int, default=6, help='树的最大深度 (XGBoost方法)') subparser.add_argument('--xgb-learning-rate', type=float, default=0.3, help='学习率 (XGBoost方法)') # LightGBM专用参数 subparser.add_argument('--lgb-n-estimators', type=int, default=100, help='树的数量 (LightGBM方法)') subparser.add_argument('--lgb-max-depth', type=int, default=-1, help='树的最大深度 (LightGBM方法,-1表示无限制)') subparser.add_argument('--lgb-learning-rate', type=float, default=0.1, help='学习率 (LightGBM方法)') # CatBoost专用参数 subparser.add_argument('--cb-n-estimators', type=int, default=100, help='树的数量 (CatBoost方法)') subparser.add_argument('--cb-depth', type=int, default=6, help='树的最大深度 (CatBoost方法)') subparser.add_argument('--cb-learning-rate', type=float, default=0.03, help='学习率 (CatBoost方法)') def build_config(self, args): from classfication_method.classfication import ClassificationConfig model_params = {} if args.method == 'svm': model_params['C'] = getattr(args, 'svm_c', 1.0) model_params['kernel'] = getattr(args, 'svm_kernel', 'rbf') # 处理gamma参数:如果未指定,使用默认值'scale' if hasattr(args, 'svm_gamma') and args.svm_gamma is not None: # 如果提供了值,尝试转换为适当的类型 gamma_str = str(args.svm_gamma).lower() if gamma_str in ['scale', 'auto']: model_params['gamma'] = gamma_str else: try: model_params['gamma'] = float(args.svm_gamma) except ValueError: model_params['gamma'] = 'scale' # 如果转换失败,使用默认值 else: model_params['gamma'] = 'scale' # 默认值 elif args.method == 'rf': model_params['n_estimators'] = getattr(args, 'rf_n_estimators', 100) model_params['max_depth'] = getattr(args, 'rf_max_depth', None) model_params['min_samples_split'] = getattr(args, 'rf_min_samples_split', 2) elif args.method == 'knn': model_params['n_neighbors'] = getattr(args, 'knn_n_neighbors', 5) model_params['weights'] = getattr(args, 'knn_weights', 'uniform') model_params['algorithm'] = getattr(args, 'knn_algorithm', 'auto') elif args.method == 'lr': model_params['C'] = getattr(args, 'lr_c', 1.0) model_params['max_iter'] = getattr(args, 'lr_max_iter', 100) elif args.method == 'pls-da': model_params['n_components'] = getattr(args, 'pls_n_components', 10) elif args.method == 'xgboost': model_params['n_estimators'] = getattr(args, 'xgb_n_estimators', 100) model_params['max_depth'] = getattr(args, 'xgb_max_depth', 6) model_params['learning_rate'] = getattr(args, 'xgb_learning_rate', 0.3) elif args.method == 'lightgbm': model_params['n_estimators'] = getattr(args, 'lgb_n_estimators', 100) model_params['max_depth'] = getattr(args, 'lgb_max_depth', -1) model_params['learning_rate'] = getattr(args, 'lgb_learning_rate', 0.1) elif args.method == 'catboost': model_params['iterations'] = getattr(args, 'cb_n_estimators', 100) model_params['depth'] = getattr(args, 'cb_depth', 6) model_params['learning_rate'] = getattr(args, 'cb_learning_rate', 0.03) # 确定输出路径 if hasattr(args, 'output_file') and args.output_file: output_path = args.output_file else: # 如果没有指定output_file,基于output_dir和output_prefix生成 import os output_prefix = getattr(args, 'output_prefix', 'result') output_path = os.path.join(args.output_dir, f"{output_prefix}.dat") # 模型类型映射(从命令行参数映射到ClassificationConfig期望的名称) model_type_mapping = { 'svm': 'svm', 'random_forest': 'random_forest', 'knn': 'knn', 'logistic_regression': 'logistic_regression', 'linear_discriminant': 'linear_discriminant', 'quadratic_discriminant': 'quadratic_discriminant', 'plsda': 'plsda', 'xgboost': 'xgboost', 'lightgbm': 'lightgbm', 'catboost': 'catboost' } config_model_type = model_type_mapping.get(args.method, args.method) # 构建配置参数 config_kwargs = { 'hyperspectral_path': args.input, 'roi_path': args.roi_file, 'model_type': config_model_type, 'model_params': model_params, 'test_size': args.test_size, 'output_path': output_path, 'use_standardization': getattr(args, 'use_standardization', True), 'batch_size': getattr(args, 'batch_size', 10000) } return ClassificationConfig(**config_kwargs) class ClusteringArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): subparser.add_argument('--method', '-m', default='kmeans', choices=['kmeans', 'fuzzy-cmeans', 'gmm', 'hierarchical', 'dbscan', 'spectral', 'subspace', 'ensemble'], help='聚类方法') subparser.add_argument('--n-clusters', type=int, default=5, help='聚类数量') subparser.add_argument('--batch', action='store_true', help='运行所有聚类方法') # 数据预处理参数 subparser.add_argument('--use-scaling', action='store_true', default=True, help='是否使用数据标准化') subparser.add_argument('--scaler-type', choices=['standard', 'minmax'], default='standard', help='标准化类型') subparser.add_argument('--use-pca', action='store_true', help='是否使用PCA降维') subparser.add_argument('--pca-components', type=int, help='PCA降维后的维度数') # K-means专用参数 subparser.add_argument('--kmeans-init', choices=['k-means++', 'random'], default='k-means++', help='初始化方法 (K-means方法)') subparser.add_argument('--kmeans-n-init', type=int, default=10, help='运行次数,选择最佳结果 (K-means方法)') subparser.add_argument('--kmeans-max-iter', type=int, default=300, help='最大迭代次数 (K-means方法)') subparser.add_argument('--kmeans-batch-size', type=int, default=1000, help='批处理大小 (K-means方法)') # 模糊C均值专用参数 subparser.add_argument('--fcm-m', type=float, default=2.0, help='模糊度参数 (模糊C均值方法,默认2.0)') subparser.add_argument('--fcm-error', type=float, default=0.005, help='终止误差 (模糊C均值方法)') subparser.add_argument('--fcm-maxiter', type=int, default=1000, help='最大迭代次数 (模糊C均值方法)') # GMM专用参数 subparser.add_argument('--gmm-covariance-type', choices=['full', 'tied', 'diag', 'spherical'], default='full', help='协方差类型 (GMM方法)') subparser.add_argument('--gmm-max-iter', type=int, default=100, help='最大迭代次数 (GMM方法)') # 层次聚类专用参数 subparser.add_argument('--hier-linkage', choices=['ward', 'complete', 'average', 'single'], default='ward', help='连接方法 (层次聚类方法)') subparser.add_argument('--hier-affinity', choices=['euclidean', 'l1', 'l2', 'manhattan', 'cosine'], default='euclidean', help='距离度量 (层次聚类方法)') # DBSCAN专用参数 subparser.add_argument('--dbscan-eps', type=float, default=0.5, help='邻域半径 (DBSCAN方法)') subparser.add_argument('--dbscan-min-samples', type=int, default=5, help='核心点的最小样本数 (DBSCAN方法)') subparser.add_argument('--dbscan-algorithm', choices=['auto', 'ball_tree', 'kd_tree', 'brute'], default='auto', help='算法 (DBSCAN方法)') # 谱聚类专用参数 subparser.add_argument('--spectral-affinity', choices=['rbf', 'nearest_neighbors'], default='rbf', help='相似度矩阵构建方法 (谱聚类方法)') subparser.add_argument('--spectral-gamma', type=float, default=1.0, help='RBF核函数的gamma参数 (谱聚类方法)') # 子空间聚类专用参数 subparser.add_argument('--subspace-n-subspaces', type=int, help='子空间数量 (子空间聚类方法,默认等于聚类数)') # 集成聚类专用参数 subparser.add_argument('--ensemble-voting', choices=['hard', 'soft'], default='hard', help='投票方式 (集成聚类方法)') def build_config(self, args): from cluster_method.cluster import ClusteringConfig # 如果是批量模式,使用默认配置 if hasattr(args, 'batch') and args.batch: return ClusteringConfig( input_path=args.input, method='batch', # 特殊标记,表示批量处理 n_clusters=args.n_clusters, output_dir=args.output_dir ) # 构建基础配置 config_kwargs = { 'input_path': args.input, 'method': args.method, 'n_clusters': args.n_clusters, 'output_dir': args.output_dir, 'use_scaling': getattr(args, 'use_scaling', True), 'scaler_type': getattr(args, 'scaler_type', 'standard'), 'use_pca': getattr(args, 'use_pca', False), 'pca_components': getattr(args, 'pca_components', None) } # 构建方法特定参数 method_params = {} if args.method == 'kmeans': method_params['init'] = getattr(args, 'kmeans_init', 'k-means++') method_params['n_init'] = getattr(args, 'kmeans_n_init', 10) method_params['max_iter'] = getattr(args, 'kmeans_max_iter', 300) method_params['batch_size'] = getattr(args, 'kmeans_batch_size', 1000) elif args.method == 'fuzzy-cmeans': method_params['fuzziness'] = getattr(args, 'fcm_m', 2.0) method_params['error'] = getattr(args, 'fcm_error', 0.005) method_params['max_iter'] = getattr(args, 'fcm_maxiter', 1000) elif args.method == 'gmm': method_params['covariance_type'] = getattr(args, 'gmm_covariance_type', 'full') method_params['max_iter'] = getattr(args, 'gmm_max_iter', 200) method_params['n_init_attempts'] = getattr(args, 'gmm_n_init_attempts', 10) elif args.method == 'hierarchical': method_params['linkage'] = getattr(args, 'hier_linkage', 'ward') method_params['affinity'] = getattr(args, 'hier_affinity', 'euclidean') elif args.method == 'dbscan': method_params['eps_percentile'] = getattr(args, 'dbscan_eps_percentile', 50) method_params['min_samples_factor'] = getattr(args, 'dbscan_min_samples_factor', 0.1) method_params['n_neighbors'] = getattr(args, 'dbscan_n_neighbors', 20) elif args.method == 'spectral': method_params['affinity'] = getattr(args, 'spectral_affinity', 'nearest_neighbors') method_params['n_neighbors'] = getattr(args, 'spectral_n_neighbors', 10) method_params['large_dataset_threshold'] = getattr(args, 'spectral_large_dataset_threshold', 2000) elif args.method == 'subspace': method_params['n_components_factor'] = getattr(args, 'subspace_n_components_factor', 0.33) method_params['max_iter'] = getattr(args, 'subspace_max_iter', 300) elif args.method == 'ensemble': method_params['voting'] = getattr(args, 'ensemble_voting', 'hard') if method_params: config_kwargs['method_params'] = method_params return ClusteringConfig(**config_kwargs) class FeatureSelectionArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): subparser.add_argument('--method', '-m', default='Spa', choices=['Spa', 'Cars', 'Uve', 'GA', 'ReliefF', 'SiPLS', 'Lars', 'RandomFrog'], help='特征选择方法') subparser.add_argument('--label-column', '-l', required=True, help='标签列名') subparser.add_argument('--n-features', type=int, default=20, help='选择的特征数量') subparser.add_argument('--spectral-columns', help='光谱数据列范围,如 "1:10" 或 "auto" (不指定则自动检测)') # 输出和可视化参数 subparser.add_argument('--output-csv', action='store_true', default=True, help='是否输出CSV结果文件') subparser.add_argument('--save-plots', action='store_true', default=True, help='是否保存可视化图表') subparser.add_argument('--plot-prefix', default='feature_selection', help='图表文件名前缀') # SPA专用参数 subparser.add_argument('--spa-min-vars', type=int, default=2, help='SPA最小变量数') subparser.add_argument('--spa-max-vars', type=int, default=50, help='SPA最大变量数') subparser.add_argument('--spa-autoscaling', type=int, choices=[0, 1], default=1, help='SPA自动缩放 (0=关闭, 1=开启)') # CARS专用参数 subparser.add_argument('--cars-n', type=int, default=50, help='CARS蒙特卡洛采样次数') subparser.add_argument('--cars-cv', type=int, default=10, help='CARS交叉验证折数') # UVE专用参数 subparser.add_argument('--uve-ncomp', type=int, default=20, help='UVE PLS成分数') subparser.add_argument('--uve-cv', type=int, default=5, help='UVE交叉验证折数') # GA专用参数 subparser.add_argument('--ga-population-size', type=int, default=10, help='GA种群大小') subparser.add_argument('--ga-n-generations', type=int, default=50, help='GA进化代数') subparser.add_argument('--ga-crossover-rate', type=float, default=0.8, help='GA交叉概率') subparser.add_argument('--ga-mutation-rate', type=float, default=0.1, help='GA变异概率') # Relief-F专用参数 subparser.add_argument('--relief-k', type=int, default=10, help='Relief-F最近邻数k') subparser.add_argument('--relief-sample-size', type=int, help='Relief-F采样大小 (默认使用全部样本)') # SiPLS专用参数 subparser.add_argument('--sipls-interval-width', type=int, default=10, help='SiPLS区间宽度') subparser.add_argument('--sipls-cv', type=int, default=5, help='SiPLS交叉验证折数') # LAR专用参数 subparser.add_argument('--lar-cv', type=int, default=5, help='LAR交叉验证折数') subparser.add_argument('--lar-max-features', type=int, help='LAR最大特征数 (默认无限制)') # RandomFrog专用参数 subparser.add_argument('--random-frog-n-frogs', type=int, default=50, help='RandomFrog青蛙数量') subparser.add_argument('--random-frog-n-memeplexes', type=int, default=5, help='RandomFrog模因数量') subparser.add_argument('--random-frog-n-evolution-steps', type=int, default=10, help='RandomFrog进化步数') subparser.add_argument('--random-frog-n-shuffle-iterations', type=int, default=10, help='RandomFrog重排迭代次数') subparser.add_argument('--random-frog-cv', type=int, default=5, help='RandomFrog交叉验证折数') def _parse_spectral_columns(self, csv_file_path: str, spectral_columns_arg: str, label_column: str) -> list: """ 解析光谱列参数,返回列名列表 Args: csv_file_path: CSV文件路径 spectral_columns_arg: 光谱列参数字符串,如 "1:165" 或 "auto" label_column: 标签列名 Returns: 光谱列名列表 """ import pandas as pd # 读取CSV文件的列名 df = pd.read_csv(csv_file_path, nrows=0) all_columns = df.columns.tolist() # 找到标签列的索引 if label_column not in all_columns: raise ValueError(f"标签列 '{label_column}' 不存在于CSV文件中") label_index = all_columns.index(label_column) # 获取除标签列外的所有列 spectral_candidates = [col for i, col in enumerate(all_columns) if i != label_index] if spectral_columns_arg is None or spectral_columns_arg.lower() == 'auto': # 自动检测:使用除标签列外的所有列 return spectral_candidates # 解析范围字符串 try: columns = [] # 分割多个范围(用逗号分隔) for part in spectral_columns_arg.split(','): part = part.strip() if ':' in part: # 范围选择,如 "1:165" start, end = part.split(':') start = int(start.strip()) end = int(end.strip()) # 转换为从标签列开始的索引 # 注意:用户指定的索引是从1开始的(第一列光谱数据为1) if start < 1 or end < 1: raise ValueError(f"列范围 {start}:{end} 无效,必须从1开始") # 检查范围是否有效 if end > len(spectral_candidates): end = len(spectral_candidates) if start > len(spectral_candidates): raise ValueError(f"起始列 {start} 超出可用光谱列范围 [1, {len(spectral_candidates)}]") # 获取指定范围的列名 selected_columns = spectral_candidates[start-1:end] # Python索引从0开始 columns.extend(selected_columns) else: # 单个索引 idx = int(part.strip()) if idx < 1 or idx > len(spectral_candidates): raise ValueError(f"列索引 {idx} 超出范围 [1, {len(spectral_candidates)}]") columns.append(spectral_candidates[idx-1]) return list(dict.fromkeys(columns)) # 去重并保持顺序 except ValueError as e: raise ValueError(f"解析光谱列参数 '{spectral_columns_arg}' 时出错: {e}") def build_config(self, args): from Feature_Selection_method.feture_select import FeatureSelectionConfig import pandas as pd # 构建方法参数 method_params = {} if args.method == 'Spa': method_params['m_min'] = getattr(args, 'spa_min_vars', 2) method_params['m_max'] = getattr(args, 'spa_max_vars', 50) method_params['autoscaling'] = getattr(args, 'spa_autoscaling', 1) elif args.method == 'Cars': method_params['N'] = getattr(args, 'cars_n', 50) method_params['f'] = getattr(args, 'n_features', 20) method_params['cv'] = getattr(args, 'cars_cv', 10) elif args.method == 'Uve': method_params['ncomp'] = getattr(args, 'uve_ncomp', 20) method_params['cv'] = getattr(args, 'uve_cv', 5) elif args.method == 'GA': method_params['population_size'] = getattr(args, 'ga_population_size', 10) method_params['n_generations'] = getattr(args, 'ga_n_generations', 50) method_params['crossover_rate'] = getattr(args, 'ga_crossover_rate', 0.8) method_params['mutation_rate'] = getattr(args, 'ga_mutation_rate', 0.1) elif args.method == 'ReliefF': method_params['k'] = getattr(args, 'relief_k', 10) method_params['sample_size'] = getattr(args, 'relief_sample_size', None) elif args.method == 'SiPLS': method_params['interval_width'] = getattr(args, 'sipls_interval_width', 10) method_params['cv'] = getattr(args, 'sipls_cv', 5) elif args.method == 'Lars': method_params['cv'] = getattr(args, 'lar_cv', 5) method_params['max_features'] = getattr(args, 'lar_max_features', None) elif args.method == 'RandomFrog': # RandomFrog参数设置 method_params['n_frogs'] = getattr(args, 'random_frog_n_frogs', 50) method_params['n_memeplexes'] = getattr(args, 'random_frog_n_memeplexes', 5) method_params['n_evolution_steps'] = getattr(args, 'random_frog_n_evolution_steps', 10) method_params['n_shuffle_iterations'] = getattr(args, 'random_frog_n_shuffle_iterations', 10) method_params['cv'] = getattr(args, 'random_frog_cv', 5) # 解析光谱列 spectral_columns = self._parse_spectral_columns(args.input, getattr(args, 'spectral_columns', None), args.label_column) # 构建配置参数 config_kwargs = { 'csv_file_path': args.input, 'label_column': args.label_column, 'method': args.method, 'output_dir': args.output_dir, 'method_params': method_params, 'spectral_columns': spectral_columns, 'output_csv': getattr(args, 'output_csv', True), 'save_plots': getattr(args, 'save_plots', True), 'plot_name_prefix': getattr(args, 'plot_prefix', 'feature_selection') } return FeatureSelectionConfig(**config_kwargs) class SpectralIndexArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): subparser.add_argument('--indices', '-idx', nargs='+', help='要计算的光谱指数列表 (默认: 全部)') subparser.add_argument('--batch', action='store_true', help='计算所有指数') subparser.add_argument('--png', action='store_true', help='生成PNG可视化') # 数据格式参数 subparser.add_argument('--data-format', choices=['csv', 'envi', 'auto'], default='auto', help='数据文件格式') subparser.add_argument('--label-column', help='标签列名 (CSV格式时需要)') subparser.add_argument('--spectral-start', help='光谱数据起始列名或波段 (CSV格式时需要)') subparser.add_argument('--spectral-end', help='光谱数据结束列名或波段 (CSV格式时需要)') # 输出参数 subparser.add_argument('--spectral-output-format', choices=['csv', 'excel', 'both'], default='both', help='输出文件格式') subparser.add_argument('--spectral-save-individual', action='store_true', help='是否保存单个指数结果') subparser.add_argument('--spectral-output-prefix', default='spectral_indices', help='输出文件名前缀') # 可视化参数 subparser.add_argument('--plot-histogram', action='store_true', default=True, help='是否绘制直方图') subparser.add_argument('--plot-boxplot', action='store_true', default=True, help='是否绘制箱线图') subparser.add_argument('--plot-correlation', action='store_true', help='是否绘制相关性热力图') subparser.add_argument('--plot-wavelength', action='store_true', help='是否按波长绘制指数分布') # 统计参数 subparser.add_argument('--calculate-stats', action='store_true', default=True, help='是否计算统计信息') subparser.add_argument('--stats-percentiles', nargs='+', type=float, help='要计算的百分位数列表 (默认: 25, 50, 75)') # 指数计算参数 subparser.add_argument('--index-csv', help='自定义光谱指数定义CSV文件路径') subparser.add_argument('--formula-csv', help='自定义公式定义CSV文件路径') def build_config(self, args): from spectral_feature_method.spectral_index import SpectralIndexConfig # 创建基础配置 config = SpectralIndexConfig.create_quick_analysis( data_file_path=args.input, indices_to_calculate=args.indices ) # 更新数据配置 config.data.data_format = getattr(args, 'data_format', 'auto') config.data.label_column = getattr(args, 'label_column', None) config.data.spectral_start = getattr(args, 'spectral_start', None) config.data.spectral_end = getattr(args, 'spectral_end', None) # 更新输出配置 config.output.output_format = getattr(args, 'spectral_output_format', 'both') config.output.save_individual_indices = getattr(args, 'spectral_save_individual', False) config.output.output_prefix = getattr(args, 'spectral_output_prefix', 'spectral_indices') # 更新可视化配置 config.output.plot_histogram = getattr(args, 'plot_histogram', True) config.output.plot_boxplot = getattr(args, 'plot_boxplot', True) config.output.plot_correlation = getattr(args, 'plot_correlation', False) config.output.plot_wavelength = getattr(args, 'plot_wavelength', False) # 更新统计配置 config.output.calculate_statistics = getattr(args, 'calculate_stats', True) if hasattr(args, 'stats_percentiles'): config.output.statistics_percentiles = args.stats_percentiles # 更新指数配置 if hasattr(args, 'index_csv'): config.indices.spectral_index_csv = args.index_csv if hasattr(args, 'formula_csv'): config.indices.formula_csv = args.formula_csv return config class PreprocessingArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): subparser.add_argument('--method', '-m', default='SS', choices=['MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave'], help='预处理方法') subparser.add_argument('--spectral-start-index', type=int, default=1, help='CSV文件的谱段起始列索引 (从0开始,默认: 1)') subparser.add_argument('--handle-outliers', action='store_true', help='处理异常值') subparser.add_argument('--outlier-method', default='iqr', choices=['iqr', 'isolation-forest', 'lof'], help='异常值检测方法') # CSV文件验证参数 subparser.add_argument('--validate-csv', action='store_true', help='启用CSV文件验证') subparser.add_argument('--min-rows', type=int, default=1, help='CSV文件最小行数 (默认: 1)') subparser.add_argument('--min-cols', type=int, default=2, help='CSV文件最小列数 (默认: 2)') subparser.add_argument('--check-missing-values', action='store_true', help='检查缺失值') subparser.add_argument('--check-data-types', action='store_true', help='检查数据类型一致性') subparser.add_argument('--wavelength-column', help='波长列名 (用于验证波长数据)') # MA方法参数 subparser.add_argument('--ma-window', type=int, default=11, help='移动平均窗口大小 (默认: 11)') # SG方法参数 subparser.add_argument('--sg-window', type=int, default=15, help='Savitzky-Golay窗口大小 (默认: 15)') subparser.add_argument('--sg-poly', type=int, default=2, help='Savitzky-Golay多项式阶数 (默认: 2)') def build_config(self, args): from preprocessing_method.Preprocessing import PreprocessingConfig return PreprocessingConfig( input_path=args.input, method=args.method, spectral_start_index=getattr(args, 'spectral_start_index', 1), handle_outliers=args.handle_outliers, outlier_method=args.outlier_method, output_dir=args.output_dir, # CSV验证参数 validate_csv=getattr(args, 'validate_csv', False), min_rows=getattr(args, 'min_rows', 1), min_cols=getattr(args, 'min_cols', 2), check_missing_values=getattr(args, 'check_missing_values', False), check_data_types=getattr(args, 'check_data_types', False), wavelength_column=getattr(args, 'wavelength_column', None), # MA和SG方法参数 ma_window=getattr(args, 'ma_window', 11), sg_window=getattr(args, 'sg_window', 15), sg_poly=getattr(args, 'sg_poly', 2) ) class ShapeFeatureArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): # 输入数据参数 subparser.add_argument('--input-type', default='dat', choices=['dat', 'shp'], help='输入类型 (默认: dat)') subparser.add_argument('--hdr-file', help='对应的HDR头文件路径(仅在输入类型为dat时使用)') subparser.add_argument('--shape-band-index', type=int, default=0, help='波段索引 (默认: 0)') # 连通域处理参数 subparser.add_argument('--connectivity', type=int, choices=[4, 8], default=8, help='连通性:4或8 (默认: 8)') subparser.add_argument('--min-area', type=int, default=10, help='最小区域面积阈值 (默认: 10)') # 分水岭算法参数 subparser.add_argument('--use-watershed', action='store_true', help='使用分水岭算法分离相邻物体') subparser.add_argument('--watershed-min-distance', type=int, default=10, help='分水岭算法中局部最大值的最小距离 (默认: 10)') # 输出设置参数 subparser.add_argument('--output-csv', action='store_true', default=True, help='输出CSV文件 (默认: True)') subparser.add_argument('--save-labeled-image', action='store_true', help='保存标记后的图像') def build_config(self, args): from spatial_features_method.shape_feature import ShapeFeatureConfig config = ShapeFeatureConfig() config.input_type = args.input_type config.dat_file_path = args.input if args.input_type == 'dat' else None config.hdr_file_path = getattr(args, 'hdr_file', None) config.shp_file_path = args.input if args.input_type == 'shp' else None config.band_index = args.shape_band_index config.connectivity = getattr(args, 'connectivity', 8) config.min_area = args.min_area config.use_watershed = args.use_watershed config.watershed_min_distance = getattr(args, 'watershed_min_distance', 10) config.output_dir = args.output_dir config.output_csv = getattr(args, 'output_csv', True) config.save_labeled_image = getattr(args, 'save_labeled_image', False) return config class GLCMArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): # GLCM参数 subparser.add_argument('--nbit', type=int, default=64, help='灰度级数 (默认: 64)') subparser.add_argument('--min-gray', type=int, default=0, help='最小灰度值 (默认: 0)') subparser.add_argument('--max-gray', type=int, default=255, help='最大灰度值 (默认: 255)') subparser.add_argument('--slide-window', type=int, default=7, help='滑动窗口大小,必须为奇数 (默认: 7)') # GLCM计算参数 subparser.add_argument('--step', type=int, nargs='+', default=[2], help='步长距离列表 (默认: [2])') subparser.add_argument('--angle', type=float, nargs='+', default=[0], help='角度列表(弧度) (默认: [0])') # 数据处理参数 subparser.add_argument('--band-index', '-b', type=int, default=25, help='要处理的波段索引 (默认: 25)') # 输出设置 subparser.add_argument('--save-dat', action='store_true', default=True, help='保存为DAT文件 (默认: True)') subparser.add_argument('--save-png', action='store_true', help='保存为PNG文件') def build_config(self, args): from spatial_features_method.glcm import SpatialFeatureConfig config = SpatialFeatureConfig() config.nbit = args.nbit config.mi = args.min_gray config.ma = args.max_gray config.slide_window = args.slide_window config.step = args.step config.angle = args.angle config.band_index = args.band_index config.image_path = args.input # 使用命令行输入的文件路径 config.output_dir = args.output_dir config.save_dat = args.save_dat config.save_png = args.save_png return config class ColorAnalysisArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): subparser.add_argument('--method', '-m', default='CIEDE2000', choices=['CIE76', 'CIE94', 'CIEDE2000'], help='色差计算方法') subparser.add_argument('--standards-file', '-s', required=True, help='标准颜色文件路径 (.csv)') def build_config(self, args): # 为 DeltaEApp 创建一个模拟的 args 对象 class MockArgs: def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) return MockArgs( input=args.input, standards_file=args.standards_file, method=args.method, output_dir=args.output_dir, mode='image' # 默认使用图像模式 ) class FilterArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): subparser.add_argument('--filter-type', default='mean', choices=['mean', 'median', 'gaussian', 'bilateral', 'opening', 'closing', 'erosion', 'dilation', 'gradient', 'tophat', 'blackhat'], help='滤波类型') subparser.add_argument('--filter-band-index', type=int, default=50, help='波段索引') subparser.add_argument('--kernel-size', '-k', type=int, default=3, help='核大小') # Smooth_filter 特有参数 subparser.add_argument('--sigma', '-s', type=float, default=1.0, help='高斯标准差 (仅用于gaussian滤波)') subparser.add_argument('--sigma-color', type=float, default=75.0, help='颜色空间标准差 (仅用于bilateral滤波)') subparser.add_argument('--sigma-space', type=float, default=75.0, help='坐标空间标准差 (仅用于bilateral滤波)') # morphological_fliter 特有参数 subparser.add_argument('--se-shape', default='disk', choices=['disk', 'square', 'rectangle', 'diamond'], help='结构元素形状 (仅用于形态学操作)') subparser.add_argument('--se-size', type=int, default=3, help='结构元素大小 (仅用于形态学操作)') subparser.add_argument('--format-type', default='auto', choices=['auto', 'bil', 'bip', 'bsq', 'dat'], help='图像格式 (仅用于形态学操作)') def build_config(self, args): # 根据滤波类型选择合适的配置类 if args.filter_type in ['mean', 'median', 'gaussian', 'bilateral']: from fliter_method.Smooth_filter import SmoothFilterConfig return SmoothFilterConfig( input_path=args.input, filter_type=args.filter_type, band_index=args.filter_band_index, kernel_size=args.kernel_size, output_dir=args.output_dir, sigma=getattr(args, 'sigma', 1.0), sigma_color=getattr(args, 'sigma_color', 75.0), sigma_space=getattr(args, 'sigma_space', 75.0) ) else: from fliter_method.morphological_fliter import MorphologicalFilterConfig return MorphologicalFilterConfig( input_path=args.input, operation=args.filter_type, band_index=args.filter_band_index, kernel_size=args.kernel_size, output_dir=args.output_dir, se_shape=getattr(args, 'se_shape', 'disk'), se_size=getattr(args, 'se_size', 3), format_type=getattr(args, 'format_type', 'auto') ) class RegressionArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): # 数据配置参数 subparser.add_argument('--label-column', '-l', required=True, help='标签列名') subparser.add_argument('--test-size', type=float, default=0.2, help='测试集比例 (默认: 0.2)') subparser.add_argument('--random-state', type=int, default=42, help='随机种子 (默认: 42)') subparser.add_argument('--scale-method', choices=['standard', 'minmax'], default='standard', help='数据标准化方法 (默认: standard)') # 模型配置参数 subparser.add_argument('--models', '-m', nargs='+', help='回归模型列表 (默认: 全部)') subparser.add_argument('--tune-params', action='store_true', help='启用超参数调优') subparser.add_argument('--tuning-method', choices=['grid', 'random'], default='grid', help='调优方法 (默认: grid)') subparser.add_argument('--cv-folds', type=int, default=5, help='交叉验证折数 (默认: 5)') subparser.add_argument('--random-search-iter', type=int, default=20, help='随机搜索迭代次数 (默认: 20)') # 训练配置参数 subparser.add_argument('--epochs', type=int, default=100, help='训练轮数 (默认: 100)') subparser.add_argument('--batch-size', type=int, default=32, help='批大小 (默认: 32)') subparser.add_argument('--learning-rate', type=float, default=0.001, help='学习率 (默认: 0.001)') # 输出配置参数 subparser.add_argument('--save-models', action='store_true', default=True, help='保存训练好的模型') subparser.add_argument('--plot-results', action='store_true', default=True, help='生成结果可视化图表') subparser.add_argument('--save-dir', default='models', help='模型保存目录 (默认: models)') subparser.add_argument('--plot-dir', default='plots', help='图表保存目录 (默认: plots)') def build_config(self, args): from rgression_method.regression import RegressionConfig # 创建基础配置 config = RegressionConfig.create_default( csv_path=args.input, label_column=args.label_column ) # 设置模型名称 if getattr(args, 'models', None): config.models.model_names = args.models # 根据是否启用调参设置其他参数 if getattr(args, 'tune_params', False): # 启用调参的配置 config.models.tune_hyperparams = True config.models.tuning_method = getattr(args, 'tuning_method', 'grid') config.models.cv_folds = getattr(args, 'cv_folds', 5) config.models.random_search_iter = getattr(args, 'random_search_iter', 20) config.training.epochs = getattr(args, 'epochs', 100) config.training.batch_size = getattr(args, 'batch_size', 32) config.training.learning_rate = getattr(args, 'learning_rate', 0.001) else: # 快速分析配置 config.models.tune_hyperparams = False # 设置通用参数 config.data.test_size = getattr(args, 'test_size', 0.2) config.data.random_state = getattr(args, 'random_state', 42) config.data.scale_method = getattr(args, 'scale_method', 'standard') config.output.save_models = getattr(args, 'save_models', True) config.output.plot_results = getattr(args, 'plot_results', True) config.output.save_dir = getattr(args, 'save_dir', 'models') config.output.plot_dir = getattr(args, 'plot_dir', 'plots') return config class RegressionPredictionArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): # 必需参数 subparser.add_argument('--model-path', '-m', required=True, help='模型文件路径(单个文件或目录)') # 可选参数 subparser.add_argument('--mask-path', help='遮罩文件路径') subparser.add_argument('--use-mask', action='store_true', default=True, help='是否使用遮罩 (默认: True)') subparser.add_argument('--batch-mode', action='store_true', help='批量处理模式') subparser.add_argument('--colormap', default='viridis', choices=['viridis', 'plasma', 'inferno', 'magma', 'cividis', 'Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds', 'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu', 'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn'], help='颜色映射 (默认: viridis)') subparser.add_argument('--dpi', type=int, default=300, help='输出图像DPI (默认: 300)') subparser.add_argument('--save-individual', action='store_true', default=True, help='保存单个模型预测结果') def build_config(self, args): from rgression_method.regression_predict import PredictionConfig return PredictionConfig( image_path=args.input, mask_path=getattr(args, 'mask_path', None), model_path=args.model_path, output_dir=args.output_dir, use_mask=getattr(args, 'use_mask', True), batch_mode=getattr(args, 'batch_mode', False), colormap=getattr(args, 'colormap', 'viridis'), dpi=getattr(args, 'dpi', 300), save_individual=getattr(args, 'save_individual', True) ) class DeltaEArgsBuilder(ArgsBuilder): def add_common_args(self, subparser): """DeltaE使用自己的参数命名""" pass def add_module_specific_args(self, subparser): # DeltaE 特定的参数命名 subparser.add_argument('--mode', required=True, choices=['image', 'pairwise'], help='计算模式: image (图像vs标准色) 或 pairwise (两两比较)') subparser.add_argument('--input', required=True, help='输入文件路径 (ENVI格式LAB图像 或 CSV文件)') subparser.add_argument('--standards', required=True, help='标准颜色CSV文件路径') subparser.add_argument('--output-dir', default='./results', help='输出目录路径') subparser.add_argument('--output-file', help='输出文件名 (可选,如果不指定则使用默认名称)') subparser.add_argument('--create-heatmap', action='store_true', help='生成热图') subparser.add_argument('--create-histogram', action='store_true', help='生成直方图') subparser.add_argument('--method', default='CIEDE2000', choices=['CIE76', 'CIE94', 'CIEDE2000'], help='色差计算方法 (默认: CIEDE2000)') subparser.add_argument('--standards-range', help='标准色范围 (e.g.: "0,1,2" 或 "0-5" 或 "all")') subparser.add_argument('--reference-range', help='参考色范围 (pairwise模式)') subparser.add_argument('--target-range', help='目标色范围 (pairwise模式)') subparser.add_argument('--no-progress', action='store_true', help='禁用进度条显示') def build_config(self, args): # DeltaE直接使用args对象,不需要转换 return args class SpectralToColorArgsBuilder(ArgsBuilder): def add_common_args(self, subparser): """SpectralToColor添加通用参数(除了input,因为它使用自己的input定义)""" # 添加除了 --input 之外的通用参数 subparser.add_argument('--output-dir', default='./results', help='输出目录路径 (默认: ./results)') subparser.add_argument('--output-prefix', default='result', help='输出文件前缀 (默认: result)') subparser.add_argument('--output-file', help='直接指定输出文件名 (优先级高于 --output-dir 和 --output-prefix)') def add_module_specific_args(self, subparser): # SpectralToColor特定的参数(与原始spectral2cie2.py一致) subparser.add_argument('--input', required=True, help='输入文件路径 (.hdr高光谱图像 或 .csv文件)') subparser.add_argument('-c', '--color_space', default='Lab', choices=['XYZ', 'xyY', 'Lab', 'LCH'], help='输出颜色空间 (默认: Lab)') subparser.add_argument('-o', '--output', required=True, help='输出文件路径') subparser.add_argument('-f', '--format', default='dat', choices=['csv', 'dat'], help='输出格式 (默认: dat)') subparser.add_argument('-i', '--illuminant', default='D65', choices=['D65', 'D50', 'A', 'F2', 'F7', 'F11'], help='标准光源 (默认: D65)') subparser.add_argument('-b', '--observer', default='2°', choices=['2°', '10°'], help='观察者 (默认: 2°)') subparser.add_argument('-w', '--wavelength_col', help='CSV文件的波长列名 (默认: wavelength)') subparser.add_argument('--no_normalize_xyz', action='store_true', help='不归一化XYZ值') subparser.add_argument('--use_numpy', action='store_true', help='使用NumPy实现而不是colour库') def build_config(self, args): # SpectralToColor直接使用args对象,但需要添加output_dir属性以保持兼容性 # 从输出路径中提取目录部分作为output_dir if hasattr(args, 'output') and args.output: import os args.output_dir = os.path.dirname(args.output) or './results' else: args.output_dir = './results' return args class XYZ2RGBArgsBuilder(ArgsBuilder): def add_common_args(self, subparser): """XYZ2RGB添加通用参数(除了input,因为它使用自己的input定义)""" # 添加除了 --input 之外的通用参数 subparser.add_argument('--output-dir', default='./results', help='输出目录路径 (默认: ./results)') subparser.add_argument('--output-prefix', default='result', help='输出文件前缀 (默认: result)') subparser.add_argument('--output-file', help='直接指定输出文件名 (优先级高于 --output-dir 和 --output-prefix)') def add_module_specific_args(self, subparser): # XYZ2RGB特定的参数(与原始XYZ2RGB.py一致) subparser.add_argument('--input', required=True, help='输入XYZ图像文件路径') subparser.add_argument('--output', required=True, help='输出RGB图像文件路径') subparser.add_argument('--rgb-space', default='sRGB', choices=['sRGB', 'Adobe RGB (1998)', 'DCI-P3', 'ITU-R BT.709', 'ITU-R BT.2020', 'ACES2065-1', 'ACEScg', 'ProPhoto RGB', 'Apple RGB', 'PAL/SECAM', 'NTSC (1953)'], help='RGB颜色空间 (默认: sRGB)') subparser.add_argument('--gamma', default='sRGB', choices=['none', 'sRGB', 'BT.709', 'gamma_2.2', 'gamma_1.8', 'gamma_2.4', 'L*', 'BT.1886', 'ST 2084', 'HLG', 'log'], help='Gamma校正方法 (默认: sRGB)') subparser.add_argument('--output-dtype', default='uint8', choices=['uint8', 'uint10', 'uint12', 'uint16', 'int8', 'int16', 'float16', 'float32', 'float64'], help='输出数据类型 (默认: uint8)') def build_config(self, args): # XYZ2RGB直接使用args对象,但需要添加output_dir属性以保持兼容性 # 从输出路径中提取目录部分作为output_dir if hasattr(args, 'output') and args.output: import os args.output_dir = os.path.dirname(args.output) or './results' else: args.output_dir = './results' return args class SupervisedClassificationArgsBuilder(ArgsBuilder): def add_module_specific_args(self, subparser): subparser.add_argument('--xml_file', required=True, help='ENVI ROI XML文件路径') subparser.add_argument('--method', '-m', default='all', choices=['all', 'euclidean', 'cosine', 'correlation', 'information_divergence', 'jm_distance', 'sid_sa'], help='分类距离度量方法 (默认: all,使用所有方法)') subparser.add_argument('--distance-params', '-p', type=str, default=None, help='距离度量方法的超参数,JSON格式字符串,例如: {"jm_distance": {"alpha": 0.7}}') subparser.add_argument('--visualize', '-v', action='store_true', help='是否生成可视化结果') def build_config(self, args): import json # 解析距离参数 distance_params = None if getattr(args, 'distance_params', None): try: distance_params = json.loads(args.distance_params) except json.JSONDecodeError: print(f"警告: 无法解析距离参数 '{args.distance_params}',使用默认参数") distance_params = None # 返回配置字典,包含所有必要参数 config = { 'input_file': args.input, 'xml_file': args.xml_file, # 直接访问属性,不使用 getattr 'output_dir': args.output_dir, 'method': getattr(args, 'method', 'all'), 'distance_params': distance_params, 'visualize': getattr(args, 'visualize', False) } return config class PROSAILGUIArgsBuilder(ArgsBuilder): def add_common_args(self, subparser): """PROSAIL GUI 不需要通用参数,因为它是交互式的""" pass def add_module_specific_args(self, subparser): # PROSAIL GUI 通常不需要额外的参数,因为它是交互式的 # 但可以添加一些选项 subparser.add_argument('--no-gui', action='store_true', help='以命令行模式运行 (不启动GUI)') subparser.add_argument('--save-default', action='store_true', help='保存默认参数到文件') def build_config(self, args): # PROSAIL GUI 不需要复杂的配置,但需要添加一些属性以保持兼容性 # 给args对象添加output_dir属性,避免后续访问错误 if not hasattr(args, 'output_dir'): args.output_dir = './results' return args # 注册表 REGISTRY: Dict[str, Dict[str, Any]] = { # 降维分析 "dim-reduction": { "info": ModuleInfo( category="降维分析", module_path="Dimensionality_Reduction_method.dimensionality_reduction", callable_name="HyperspectralDimReduction", config_class="DimensionalityReductionConfig", description="高光谱数据降维分析" ), "args_builder": DimensionalityReductionArgsBuilder(), "call_sequence": ["__init__", "load_data", "apply_dim_reduction", "save_results"] }, # 图像分割 "segmentation": { "info": ModuleInfo( category="图像分割", module_path="segment_method.threshold_Segment", callable_name="ThresholdSegmenter", config_class="ThresholdSegmentationConfig", description="阈值图像分割" ), "args_builder": SegmentationArgsBuilder(), "call_sequence": ["__init__", "load_data", "apply_segmentation", "save_results"] }, # 边缘检测 "edge-detection": { "info": ModuleInfo( category="边缘检测", module_path="edge_detect_method.edge_detect", callable_name="EdgeDetector", config_class="EdgeDetectionConfig", description="图像边缘检测" ), "args_builder": EdgeDetectionArgsBuilder(), "call_sequence": ["__init__", "load_data", "apply_edge_detection", "save_results"] }, # 异常检测 "anomaly-detection": { "info": ModuleInfo( category="异常检测", module_path="Anomaly_method.Covariance", callable_name="CovarianceAnomalyDetector", config_class="CovarianceAnomalyConfig", description="异常检测分析" ), "args_builder": AnomalyDetectionArgsBuilder(), "call_sequence": ["__init__", "load_data", "run_analysis_from_config", "save_results"] }, # 分类分析 "classification": { "info": ModuleInfo( category="分类分析", module_path="classfication_method.classfication", callable_name="ConfigurableHyperspectralClassifier", config_class="ClassificationConfig", description="高光谱图像分类", requires_roi_file=True ), "args_builder": ClassificationArgsBuilder(), "call_sequence": ["__init__", "run_classification"] }, # 聚类分析 "clustering": { "info": ModuleInfo( category="聚类分析", module_path="cluster_method.cluster", callable_name="ClusteringManager", config_class="ClusteringConfig", description="无监督聚类分析" ), "args_builder": ClusteringArgsBuilder(), "call_sequence": ["__init__", "fit_predict", "save_clustering_results"] }, # 监督分类 "supervised-classification": { "info": ModuleInfo( category="分类分析", module_path="supervize_cluster_method.supervize_cluster", callable_name="run_hsi_classification", config_class=None, # 这个任务直接使用函数调用 description="高光谱图像监督分类", requires_roi_file=False, # 使用XML文件而不是ROI文件 requires_data_file=True ), "args_builder": SupervisedClassificationArgsBuilder(), "call_sequence": ["call_function"] # 直接调用函数 }, # 特征选择 "feature-selection": { "info": ModuleInfo( category="特征选择", module_path="Feature_Selection_method.feture_select", callable_name="SpectrumFeatureSelector", config_class="FeatureSelectionConfig", description="光谱特征选择", requires_data_file=False ), "args_builder": FeatureSelectionArgsBuilder(), "call_sequence": ["__init__", "load_csv_data", "select_features"] }, # 光谱指数 "spectral-index": { "info": ModuleInfo( category="光谱特征", module_path="spectral_feature_method.spectral_index", callable_name="HyperspectralIndexCalculator", config_class="SpectralIndexConfig", description="光谱指数计算" ), "args_builder": SpectralIndexArgsBuilder(), "call_sequence": ["__init__", "load_hyperspectral_data", "batch_calculate_and_visualize", "save_results"] }, # 数据预处理 "preprocessing": { "info": ModuleInfo( category="预处理", module_path="preprocessing_method.Preprocessing", callable_name="HyperspectralPreprocessor", config_class="PreprocessingConfig", description="数据预处理" ), "args_builder": PreprocessingArgsBuilder(), "call_sequence": ["__init__", "load_data", "preprocess", "save_data"] }, # 形状特征 "shape-features": { "info": ModuleInfo( category="空间特征", module_path="spatial_features_method.shape_feature", callable_name="analyze_shape_features", description="形状特征分析", requires_data_file=False ), "args_builder": ShapeFeatureArgsBuilder(), "call_sequence": ["call_function"] }, # GLCM纹理特征提取 "glcm": { "info": ModuleInfo( category="空间特征", module_path="spatial_features_method.glcm", callable_name="run_glcm_analysis", description="GLCM纹理特征提取", requires_data_file=True ), "args_builder": GLCMArgsBuilder(), "call_sequence": ["call_function"] }, # 颜色分析 "color-analysis": { "info": ModuleInfo( category="颜色分析", module_path="color_method.DeltaE", callable_name="DeltaEApp", description="色差计算分析" ), "args_builder": ColorAnalysisArgsBuilder(), "call_sequence": ["run"] }, # 图像滤波 "filtering": { "info": ModuleInfo( category="滤波", module_path="fliter_method.Smooth_filter", callable_name="HyperspectralImageFilter", config_class="SmoothFilterConfig", description="图像滤波处理" ), "args_builder": FilterArgsBuilder(), "call_sequence": ["process_image"] }, # 回归分析 "regression": { "info": ModuleInfo( category="回归分析", module_path="rgression_method.regression", callable_name="RegressionAnalyzer", config_class="RegressionConfig", description="回归模型分析", requires_data_file=False ), "args_builder": RegressionArgsBuilder(), "call_sequence": ["__init__", "run_analysis_from_config"] }, # 回归预测 "regression-prediction": { "info": ModuleInfo( category="回归预测", module_path="rgression_method.regression_predict", callable_name="RegressionPredictor", config_class="PredictionConfig", description="使用训练好的回归模型进行高光谱图像预测", requires_data_file=True ), "args_builder": RegressionPredictionArgsBuilder(), "call_sequence": ["__init__", "run_prediction"] }, # 色差计算 "delta-e": { "info": ModuleInfo( category="颜色分析", module_path="color_method.DeltaE", callable_name="DeltaEApp", description="色差计算分析 (CIE76, CIE94, CIEDE2000)", requires_data_file=True ), "args_builder": DeltaEArgsBuilder(), "call_sequence": ["run"] }, # 光谱到颜色转换 "spectral-to-color": { "info": ModuleInfo( category="颜色分析", module_path="color_method.spectral2cie2", callable_name="SpectralToColorConverter", description="光谱数据转换为CIE颜色空间 (XYZ, xyY, Lab, LCH)", requires_data_file=True ), "args_builder": SpectralToColorArgsBuilder(), "call_sequence": ["process_file"] }, # XYZ到RGB转换 "xyz-to-rgb": { "info": ModuleInfo( category="颜色分析", module_path="color_method.XYZ2RGB", callable_name="XYZ2RGBApp", description="XYZ颜色空间转换为RGB", requires_data_file=True ), "args_builder": XYZ2RGBArgsBuilder(), "call_sequence": ["run"] }, # PROSAIL模拟器GUI "prosail-gui": { "info": ModuleInfo( category="植被模拟", module_path="prosail_method.prosail_gui", callable_name="PROSAILSimulator", description="PROSAIL植被光谱模拟器GUI", requires_data_file=False ), "args_builder": PROSAILGUIArgsBuilder(), "call_sequence": ["run_gui"] } } def get_module_info(task_name: str) -> Optional[ModuleInfo]: """获取模块信息""" task_info = REGISTRY.get(task_name) return task_info["info"] if task_info else None def get_args_builder(task_name: str) -> Optional[ArgsBuilder]: """获取参数构建器""" task_info = REGISTRY.get(task_name) return task_info["args_builder"] if task_info else None def get_call_sequence(task_name: str) -> Optional[List[str]]: """获取调用序列""" task_info = REGISTRY.get(task_name) return task_info.get("call_sequence") def create_module_parser(subparsers, task_name: str): """为指定任务创建子解析器""" task_info = REGISTRY.get(task_name) if not task_info: return None info = task_info["info"] args_builder = task_info["args_builder"] # 创建子解析器 subparser = subparsers.add_parser( task_name, help=f"{info.category}: {info.description}", description=info.description ) # 添加通用参数 args_builder.add_common_args(subparser) # 添加模块特定参数 args_builder.add_module_specific_args(subparser) return subparser def dynamic_import_module(module_path: str): """动态导入模块""" try: return importlib.import_module(module_path) except ImportError as e: raise ImportError(f"无法导入模块 {module_path}: {e}") def get_callable_from_module(module, callable_name: str): """从模块获取可调用对象""" try: return getattr(module, callable_name) except AttributeError as e: raise AttributeError(f"模块中没有找到可调用对象 {callable_name}: {e}") def execute_task(task_name: str, args): """执行任务""" task_info = REGISTRY.get(task_name) if not task_info: raise ValueError(f"未知任务: {task_name}") info = task_info["info"] args_builder = task_info["args_builder"] call_sequence = task_info["call_sequence"] # 构建配置 config = args_builder.build_config(args) # 构建输出文件名 if hasattr(args, 'output_file') and args.output_file: output_filename = args.output_file else: # 基于输出目录和前缀构建文件名 import os output_prefix = getattr(args, 'output_prefix', 'result') output_filename = os.path.join(args.output_dir, f"{output_prefix}.dat") # 动态导入模块和可调用对象 if task_name == "anomaly-detection": # 异常检测需要根据方法动态选择类 if args.method == 'covariance': module = dynamic_import_module("Anomaly_method.Covariance") callable_obj = get_callable_from_module(module, "CovarianceAnomalyDetector") elif args.method == 'one-class-svm': module = dynamic_import_module("Anomaly_method.One_Class_SVM") callable_obj = get_callable_from_module(module, "OneClassSVMAnomalyDetector") elif args.method == 'rx': module = dynamic_import_module("Anomaly_method.RX") callable_obj = get_callable_from_module(module, "RXAnomalyDetector") elif args.method == 'squared-loss-probability': module = dynamic_import_module("Anomaly_method.squared_loss_probability") callable_obj = get_callable_from_module(module, "SquaredLossAnomalyDetector") else: raise ValueError(f"不支持的异常检测方法: {args.method}") elif task_name == "clustering": # 聚类任务直接调用run_clustering函数 from cluster_method.cluster import run_clustering # 构建参数 method_params = getattr(config, 'method_params', {}) if hasattr(config, 'method_params') else {} # 调用聚类函数 result = run_clustering( input_file=config.input_path, n_clusters=config.n_clusters, methods=[config.method], output_dir=config.output_dir, method_params=method_params, random_state=42 ) return result elif task_name == "delta-e": # 色差计算任务 module = dynamic_import_module("color_method.DeltaE") callable_obj = get_callable_from_module(module, "DeltaEApp") app_instance = callable_obj() result = app_instance.run(config) # 使用build_config返回的对象 return result elif task_name == "spectral-to-color": # 光谱到颜色转换任务 - 直接调用main函数 import sys from color_method.spectral2cie2 import main as spectral_main # 保存原始sys.argv original_argv = sys.argv # 构造命令行参数 sys.argv = ['spectral2cie2.py', '--input', config.input, '-o', config.output, '-c', getattr(config, 'color_space', 'Lab'), '-i', getattr(config, 'illuminant', 'D65'), '-b', getattr(config, 'observer', '2°')] if hasattr(config, 'wavelength_col') and config.wavelength_col: sys.argv.extend(['-w', config.wavelength_col]) if getattr(config, 'no_normalize_xyz', False): sys.argv.append('--no_normalize_xyz') if getattr(config, 'use_numpy', False): sys.argv.append('--use_numpy') try: result = spectral_main() finally: # 恢复原始sys.argv sys.argv = original_argv return result elif task_name == "xyz-to-rgb": # XYZ到RGB转换任务 module = dynamic_import_module("color_method.XYZ2RGB") callable_obj = get_callable_from_module(module, "XYZ2RGBApp") app_instance = callable_obj() result = app_instance.run(config) # 使用build_config返回的对象 return result elif task_name == "preprocessing": # 预处理任务 - 需要特殊处理load_data调用 from preprocessing_method.Preprocessing import HyperspectralPreprocessor instance = HyperspectralPreprocessor(config) # 确定输出文件路径(基于输入文件类型和方法名称) import os from pathlib import Path input_path = Path(config.input_path) method_name = config.method.lower() # 方法名称转为小写用于文件名 if input_path.suffix.lower() == '.csv': # CSV输入,输出CSV output_path = os.path.join(config.output_dir, f"{method_name}_{input_path.stem}.csv") else: # ENVI输入或其他格式,输出ENVI output_path = os.path.join(config.output_dir, f"{method_name}_{input_path.stem}.dat") # 调用load_data时传递正确的参数 instance.load_data(config.input_path, config.spectral_start_index) # 执行预处理 result = instance.preprocess(config.method, output_path) return result elif task_name == "filtering": # 滤波任务 - 根据滤波类型选择不同的处理方式 if hasattr(config, 'filter_type'): # Smooth_filter类型 from fliter_method.Smooth_filter import HyperspectralImageFilter filter_obj = HyperspectralImageFilter() # 构建参数字典 kwargs = { 'kernel_size': config.kernel_size, } if config.filter_type == 'gaussian': kwargs['sigma'] = config.sigma elif config.filter_type == 'bilateral': kwargs['sigma_color'] = config.sigma_color kwargs['sigma_space'] = config.sigma_space # 生成输出路径 output_path = os.path.join(config.output_dir, f"{config.filter_type}_filtered") result = filter_obj.process_image( input_path=config.input_path, output_path=output_path, filter_type=config.filter_type, band_index=config.band_index, **kwargs ) return result elif task_name == "prosail-gui": # PROSAIL GUI 任务 - 启动GUI应用 from prosail_method.prosail_gui import PROSAILSimulator, QApplication import sys if getattr(config, 'no_gui', False): # 命令行模式 - 这里可以添加命令行版本的PROSAIL功能 print("PROSAIL GUI 命令行模式暂未实现,请使用 GUI 模式:") print("python main.py prosail-gui") return None else: # GUI模式 app = QApplication(sys.argv) app.setStyle('Fusion') # Modern style window = PROSAILSimulator() window.show() result = app.exec_() return result elif task_name == "morphological_fliter": # Morphological_filter类型 from fliter_method.morphological_fliter import HyperspectralMorphologyProcessor processor = HyperspectralMorphologyProcessor() # 加载图像 data, header = processor.load_hyperspectral_image(config.input_path, config.format_type) # 提取波段 band_data = processor.extract_band(config.band_index) # 应用形态学操作 result = processor.apply_morphology_operation( band_data, config.operation, config.se_shape, config.se_size ) # 保存结果 output_path = os.path.join(config.output_dir, f"{config.operation}_result") processor.save_as_envi(result, output_path, f"{config.operation.capitalize()} 处理结果 - 波段 {config.band_index}") return result else: module = dynamic_import_module(info.module_path) callable_obj = get_callable_from_module(module, info.callable_name) # 检查是否是批量处理(针对降维分析) is_batch = (task_name == "dim-reduction" and hasattr(args, 'batch_methods') and args.batch_methods is not None) if is_batch: # 批量处理 instance = callable_obj(config) instance.load_data() result = instance.batch_process(args.batch_methods, args.batch_components, args.output_dir) return result # 执行调用序列 result = None instance = None for step in call_sequence: if step == "__init__": instance = callable_obj(config) elif step == "call_function": # 直接调用函数,传递config对象作为参数 result = callable_obj(config) elif hasattr(instance, step): method = getattr(instance, step) if step in ["load_data", "read_hyperspectral_data", "load_hyperspectral_data", "load_csv_data"]: method() elif step in ["apply_dim_reduction", "apply_segmentation", "apply_edge_detection", "detect_anomalies", "select_features", "batch_calculate_and_visualize", "preprocess"]: result = method() elif step == "run": # 对于某些应用类,直接调用run方法并传递参数 result = method(config) elif step == "save_results": # 根据任务类型正确调用 save_results if task_name == "dim-reduction": # result = (reduced_data, band_names) method(result[0], result[1], output_filename, args.method) elif task_name == "segmentation": # result = (segmented_data, threshold) method(result[0], result[1], output_filename, args.method) elif task_name == "edge-detection": # result = (edge_data, info) method(result[0], result[1], output_filename, args.method) elif task_name == "spectral-index": # result = (results, fig), save_results 期望 (results, output_prefix) method(result[0], args.output_prefix) elif task_name == "anomaly-detection": # save_results 只接受可选的 output_path 参数 method(output_filename) else: # 默认调用方式 method(result, output_filename, task_name) elif step == "run_analysis_from_config": result = method() elif step == "calculate_differences": result = method() elif step == "apply_filter": result = method(args.filter_type if hasattr(args, 'filter_type') else 'mean') elif step == "save_envi": method(output_filename, result) else: # 其他方法直接调用 if step == "parse_roi_xml" and hasattr(args, 'roi_file'): method(args.roi_file) elif step == "extract_training_samples": result = method() elif step == "train_model": method(result[0], result[1]) elif step == "predict_image": method() elif step == "run_classification": method() else: method() return result