2318 lines
110 KiB
Python
2318 lines
110 KiB
Python
"""
|
||
高光谱分析工具包统一注册表
|
||
提供所有功能模块的注册信息和参数构建器
|
||
"""
|
||
|
||
import importlib
|
||
from typing import Dict, Any, Callable, List, Optional, Union
|
||
from dataclasses import dataclass
|
||
import argparse
|
||
|
||
|
||
@dataclass
class ModuleInfo:
    """Registry metadata describing one analysis module.

    Records where the module lives (``module_path`` / ``callable_name``),
    how it is configured, and which inputs/outputs it expects.
    """

    category: str                    # functional category key of the module
    module_path: str                 # dotted import path of the implementing module
    callable_name: str               # entry-point callable name inside that module
    config_class: Optional[str] = None   # name of the module's config class, if any
    description: str = ""            # human-readable summary
    requires_data_file: bool = True  # module needs a hyperspectral data file
    requires_roi_file: bool = False  # module needs an ROI annotation file
    # Fixed annotation: the field is legitimately None until __post_init__
    # normalizes it, so the type must be Optional[List[str]] — the original
    # bare ``List[str] = None`` was a typing error.
    output_types: Optional[List[str]] = None

    def __post_init__(self):
        # Normalize the default here: dataclasses forbid a mutable list
        # default, and this also guarantees each instance gets its own list.
        if self.output_types is None:
            self.output_types = ["file", "summary"]
|
||
|
||
|
||
class ArgsBuilder:
    """Base class for per-module CLI argument builders."""

    def add_common_args(self, subparser: argparse.ArgumentParser):
        """Register the I/O options shared by every module sub-command."""
        shared_options = (
            (('--input', '-i'),
             dict(required=True, help='输入文件路径 (ENVI格式: .hdr/.dat 或 CSV)')),
            (('--output-dir', '-o'),
             dict(default='./results', help='输出目录路径 (默认: ./results)')),
            (('--output-prefix',),
             dict(default='result', help='输出文件前缀 (默认: result)')),
            (('--output-file', '-f'),
             dict(help='直接指定输出文件名 (优先级高于 --output-dir 和 --output-prefix)')),
        )
        for flags, options in shared_options:
            subparser.add_argument(*flags, **options)

    def add_module_specific_args(self, subparser: argparse.ArgumentParser):
        """Hook for subclasses: register module-specific options. No-op here."""

    def build_config(self, args) -> Any:
        """Hook for subclasses: turn parsed args into a config object."""
|
||
|
||
|
||
# 各模块的参数构建器
|
||
class DimensionalityReductionArgsBuilder(ArgsBuilder):
    """Builds CLI options and config for the dimensionality-reduction module."""

    # Maps each method name to (argparse attribute, config parameter) pairs
    # used to collect that method's hyper-parameters in build_config().
    _METHOD_PARAM_MAP = {
        'pca': [('pca_whiten', 'whiten'), ('pca_svd_solver', 'svd_solver')],
        'ica': [('ica_algorithm', 'algorithm'), ('ica_fun', 'fun'),
                ('ica_max_iter', 'max_iter')],
        'fa': [('fa_max_iter', 'max_iter'), ('fa_tol', 'tol')],
        'mds': [('mds_metric', 'metric'), ('mds_max_iter', 'max_iter'),
                ('mds_n_init', 'n_init')],
        'isomap': [('isomap_n_neighbors', 'n_neighbors'),
                   ('isomap_eigen_solver', 'eigen_solver')],
        'lle': [('lle_n_neighbors', 'n_neighbors'), ('lle_method', 'method')],
        't-sne': [('tsne_perplexity', 'perplexity'),
                  ('tsne_early_exaggeration', 'early_exaggeration'),
                  ('tsne_learning_rate', 'learning_rate'),
                  ('tsne_max_iter', 'max_iter')],
    }

    def add_module_specific_args(self, subparser):
        """Register dimensionality-reduction options on the sub-command parser."""
        # CSV layout options
        subparser.add_argument('--label-col', '-l', help='CSV文件的标签列名或索引')
        subparser.add_argument('--spectral-start', '-s', help='CSV文件的谱段起始列名或索引')
        subparser.add_argument('--spectral-end', '-e', help='CSV文件的谱段结束列名或索引')

        subparser.add_argument('--method', '-m', default='pca',
                               choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'],
                               help='降维方法')
        subparser.add_argument('--n-components', '-n', type=int, default=3,
                               help='降维后的维度数')
        subparser.add_argument('--batch-methods', nargs='+',
                               choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'],
                               help='批量处理的降维方法列表')
        subparser.add_argument('--batch-components', nargs='+', type=int,
                               help='对应批量方法的维度数列表')

        # Preprocessing / output toggles.
        # BUG FIX: ``action='store_true'`` combined with ``default=True``
        # made these switches impossible to disable from the command line;
        # paired --no-* switches now turn them off while the original flags
        # keep working unchanged.
        subparser.add_argument('--use-standardization', dest='use_standardization',
                               action='store_true', default=True,
                               help='是否使用数据标准化')
        subparser.add_argument('--no-use-standardization', dest='use_standardization',
                               action='store_false',
                               help='禁用数据标准化')
        subparser.add_argument('--generate-plots', dest='generate_plots',
                               action='store_true', default=True,
                               help='是否生成可视化图表')
        subparser.add_argument('--no-generate-plots', dest='generate_plots',
                               action='store_false',
                               help='禁用可视化图表')
        subparser.add_argument('--color-by-label', dest='color_by_label',
                               action='store_true', default=True,
                               help='是否按标签着色散点图')
        subparser.add_argument('--no-color-by-label', dest='color_by_label',
                               action='store_false',
                               help='禁用按标签着色散点图')

        # PCA options
        subparser.add_argument('--pca-whiten', action='store_true',
                               help='是否白化数据 (PCA方法)')
        subparser.add_argument('--pca-svd-solver', choices=['auto', 'full', 'arpack', 'randomized'], default='auto',
                               help='SVD求解器 (PCA方法)')

        # ICA options
        subparser.add_argument('--ica-algorithm', choices=['parallel', 'deflation'], default='parallel',
                               help='ICA算法 (ICA方法)')
        subparser.add_argument('--ica-fun', choices=['logcosh', 'exp', 'cube'], default='logcosh',
                               help='非线性函数 (ICA方法)')
        subparser.add_argument('--ica-max-iter', type=int, default=200,
                               help='最大迭代次数 (ICA方法)')

        # FA options
        subparser.add_argument('--fa-max-iter', type=int, default=1000,
                               help='最大迭代次数 (FA方法)')
        subparser.add_argument('--fa-tol', type=float, default=1e-2,
                               help='收敛容差 (FA方法)')

        # MDS options (same BUG FIX applied to --mds-metric).
        subparser.add_argument('--mds-metric', dest='mds_metric',
                               action='store_true', default=True,
                               help='是否使用度量MDS (MDS方法)')
        subparser.add_argument('--no-mds-metric', dest='mds_metric',
                               action='store_false',
                               help='使用非度量MDS (MDS方法)')
        subparser.add_argument('--mds-max-iter', type=int, default=300,
                               help='最大迭代次数 (MDS方法)')
        subparser.add_argument('--mds-n-init', type=int, default=4,
                               help='初始化次数 (MDS方法)')

        # Isomap options
        subparser.add_argument('--isomap-n-neighbors', type=int, default=5,
                               help='邻居数量 (Isomap方法)')
        subparser.add_argument('--isomap-eigen-solver', choices=['auto', 'arpack', 'dense'], default='auto',
                               help='特征值求解器 (Isomap方法)')

        # LLE options
        subparser.add_argument('--lle-n-neighbors', type=int, default=5,
                               help='邻居数量 (LLE方法)')
        subparser.add_argument('--lle-method', choices=['standard', 'hessian', 'modified', 'ltsa'], default='standard',
                               help='LLE方法类型 (LLE方法)')

        # t-SNE options
        subparser.add_argument('--tsne-perplexity', type=float, default=30.0,
                               help='困惑度 (t-SNE方法)')
        subparser.add_argument('--tsne-early-exaggeration', type=float, default=12.0,
                               help='早期夸张系数 (t-SNE方法)')
        subparser.add_argument('--tsne-learning-rate', type=float, default=200.0,
                               help='学习率 (t-SNE方法)')
        subparser.add_argument('--tsne-max-iter', type=int, default=1000,
                               help='最大迭代次数 (t-SNE方法)')

    def build_config(self, args):
        """Translate parsed args into a DimensionalityReductionConfig.

        Raises:
            ValueError: if --batch-components is given with a length that
                does not match --batch-methods.
        """
        from Dimensionality_Reduction_method.dimensionality_reduction import DimensionalityReductionConfig

        # Batch mode is selected by the presence of --batch-methods.
        is_batch = getattr(args, 'batch_methods', None) is not None

        # Collect method-specific hyper-parameters for single-method runs.
        # (The table replaces the original repetitive if/elif chain; each
        # hasattr check mirrors the original per-attribute guards.)
        method_params = {}
        if not is_batch and hasattr(args, 'method'):
            for attr, param in self._METHOD_PARAM_MAP.get(args.method, []):
                if hasattr(args, attr):
                    method_params[param] = getattr(args, attr)

        config_kwargs = {
            'input_path': args.input,
            'output_dir': args.output_dir,
            'label_col': getattr(args, 'label_col', None),
            'spectral_start': getattr(args, 'spectral_start', None),
            'spectral_end': getattr(args, 'spectral_end', None),
            'method_params': method_params,
            'use_standardization': getattr(args, 'use_standardization', True),
            'generate_plots': getattr(args, 'generate_plots', True),
            'color_by_label': getattr(args, 'color_by_label', True),
            # NOTE(review): no --save-plots option is registered above, so
            # this is always True unless a caller sets args.save_plots by
            # hand — confirm whether a CLI flag was intended.
            'save_plots': getattr(args, 'save_plots', True)
        }

        if is_batch:
            # Batch mode: default every method to 3 components when no
            # component list is given; otherwise the lengths must match.
            if not getattr(args, 'batch_components', None):
                batch_components = [3] * len(args.batch_methods)
            else:
                batch_components = args.batch_components
                if len(batch_components) != len(args.batch_methods):
                    raise ValueError(f"批量方法数量 ({len(args.batch_methods)}) 与维度数量 ({len(batch_components)}) 不匹配")

            config_kwargs.update({
                'batch_methods': args.batch_methods,
                'batch_components': batch_components
            })
        else:
            # Single-method run
            config_kwargs.update({
                'method': args.method,
                'n_components': args.n_components
            })

        return DimensionalityReductionConfig(**config_kwargs)
|
||
|
||
|
||
class SegmentationArgsBuilder(ArgsBuilder):
    """Builds CLI options and config for threshold-based segmentation."""

    def add_common_args(self, subparser):
        """Override the shared I/O options: segmentation uses its own default
        output directory and defines no --output-file option."""
        subparser.add_argument('--input', '-i', required=True,
                               help='输入文件路径 (ENVI格式: .hdr/.dat 或图像文件)')
        subparser.add_argument('--output-dir', '-o', default='./segmentation_results',
                               help='输出目录路径 (默认: ./segmentation_results)')
        subparser.add_argument('--output-prefix', default='result',
                               help='输出文件前缀 (默认: result)')

    def add_module_specific_args(self, subparser):
        """Register segmentation-method options on the sub-command parser."""
        # Core segmentation options
        subparser.add_argument('--method', '-m', default='otsu',
                               choices=['fixed', 'bimodal', 'iterative', 'otsu', 'isodata', 'adaptive'],
                               help='分割方法 (默认: otsu)')
        subparser.add_argument('--band-index', '-b', type=int, default=0,
                               help='使用的波段索引 (默认: 0)')

        # Threshold used only by the 'fixed' method
        subparser.add_argument('--threshold', '-t', type=float,
                               help='固定阈值 (仅用于fixed方法)')

        # Batch-processing options
        subparser.add_argument('--batch-methods', nargs='+',
                               choices=['fixed', 'bimodal', 'iterative', 'otsu', 'isodata', 'adaptive'],
                               help='批量处理的分割方法列表')
        subparser.add_argument('--batch-bands', type=int, nargs='+',
                               help='批量处理的波段索引列表')

        # Adaptive-threshold options
        subparser.add_argument('--adaptive-block-size', type=int, default=11,
                               help='自适应阈值块大小,必须为奇数 (默认: 11)')
        subparser.add_argument('--adaptive-c', type=float, default=2.0,
                               help='自适应阈值常数 (默认: 2.0)')
        subparser.add_argument('--adaptive-method', choices=['gaussian', 'mean'], default='gaussian',
                               help='自适应阈值方法 (默认: gaussian)')

        # Iterative-method options
        subparser.add_argument('--max-iterations', type=int, default=100,
                               help='迭代法最大迭代次数 (默认: 100)')
        subparser.add_argument('--convergence-threshold', type=float, default=0.01,
                               help='迭代法收敛阈值 (默认: 0.01)')

        # Bimodal-histogram options
        subparser.add_argument('--histogram-bins', type=int, default=256,
                               help='直方图bin数量 (默认: 256)')
        subparser.add_argument('--peak-min-distance', type=int, default=10,
                               help='峰值最小距离 (默认: 10)')

        # Otsu options.
        # BUG FIX: ``action='store_true'`` with ``default=True`` could never
        # be switched off; a paired --no-otsu-normalize switch now disables
        # it while the original flag keeps working unchanged.
        subparser.add_argument('--otsu-normalize', dest='otsu_normalize',
                               action='store_true', default=True,
                               help='Otsu方法是否归一化图像到0-255范围 (默认: True)')
        subparser.add_argument('--no-otsu-normalize', dest='otsu_normalize',
                               action='store_false',
                               help='关闭Otsu方法的0-255归一化')

    def build_config(self, args):
        """Translate parsed args into a ThresholdSegmentationConfig."""
        from segment_method.threshold_Segment import ThresholdSegmentationConfig

        # A single-method run is expressed as one-element batch lists.
        batch_methods = getattr(args, 'batch_methods', None)
        batch_bands = getattr(args, 'batch_bands', None)

        if batch_methods is None:
            batch_methods = [args.method]
        if batch_bands is None:
            batch_bands = [args.band_index]

        return ThresholdSegmentationConfig(
            input_path=args.input,
            output_dir=args.output_dir,
            method=args.method,
            band_index=args.band_index,
            threshold_value=getattr(args, 'threshold', None),
            batch_methods=batch_methods,
            batch_bands=batch_bands,
            adaptive_block_size=getattr(args, 'adaptive_block_size', 11),
            adaptive_c=getattr(args, 'adaptive_c', 2.0),
            adaptive_method=getattr(args, 'adaptive_method', 'gaussian'),
            max_iterations=getattr(args, 'max_iterations', 100),
            convergence_threshold=getattr(args, 'convergence_threshold', 0.01),
            histogram_bins=getattr(args, 'histogram_bins', 256),
            peak_min_distance=getattr(args, 'peak_min_distance', 10),
            otsu_normalize=getattr(args, 'otsu_normalize', True)
        )
|
||
|
||
|
||
class EdgeDetectionArgsBuilder(ArgsBuilder):
    """Builds CLI options and config for the edge-detection module."""

    def add_module_specific_args(self, subparser):
        """Register edge-detection options on the sub-command parser."""
        # Core options
        subparser.add_argument('--method', '-m', default='canny',
                               choices=['sobel', 'scharr', 'laplacian', 'log', 'canny'],
                               help='边缘检测方法')
        subparser.add_argument('--edge-band-index', type=int, default=0,
                               help='使用的波段索引')
        subparser.add_argument('--batch', action='store_true',
                               help='批量运行多种方法')
        subparser.add_argument('--batch-bands', nargs='+', type=int,
                               help='批量处理的波段索引列表')

        # Sobel / Scharr options
        subparser.add_argument('--sobel-dx', type=int, default=1, choices=[0, 1],
                               help='Sobel x方向导数阶数,必须与sobel-dy之和等于1 (默认: 1)')
        subparser.add_argument('--sobel-dy', type=int, default=0, choices=[0, 1],
                               help='Sobel y方向导数阶数,必须与sobel-dx之和等于1 (默认: 0)')
        subparser.add_argument('--sobel-ksize', type=int, default=-1,
                               help='Sobel核大小,-1表示3x3 (默认: -1)')

        # Laplacian options
        subparser.add_argument('--laplacian-ksize', type=int, default=1,
                               help='Laplacian核大小 (默认: 1)')
        subparser.add_argument('--laplacian-scale', type=float, default=1.0,
                               help='Laplacian尺度参数 (默认: 1.0)')
        subparser.add_argument('--laplacian-delta', type=float, default=0.0,
                               help='Laplacian delta参数 (默认: 0.0)')

        # LoG (Laplacian of Gaussian) options
        subparser.add_argument('--log-sigma', type=float, default=1.0,
                               help='LoG高斯标准差 (默认: 1.0)')
        subparser.add_argument('--log-threshold', type=float, default=0.0,
                               help='LoG阈值,0表示自动阈值 (默认: 0.0)')
        subparser.add_argument('--log-mode', default='reflect',
                               choices=['reflect', 'constant', 'nearest', 'mirror', 'wrap'],
                               help='LoG边界模式 (默认: reflect)')

        # Gradient-thresholding options (for gradient-based methods)
        subparser.add_argument('--gradient-threshold', type=float, default=0.0,
                               help='梯度阈值 (0-255),0表示不使用阈值 (默认: 0.0)')
        subparser.add_argument('--gradient-use-otsu', action='store_true',
                               help='对梯度图使用Otsu自动阈值')

        # Canny options
        subparser.add_argument('--canny-min-threshold', type=int, default=100,
                               help='Canny最小阈值 (默认: 100)')
        subparser.add_argument('--canny-max-threshold', type=int, default=200,
                               help='Canny最大阈值 (默认: 200)')
        subparser.add_argument('--canny-aperture-size', type=int, default=3, choices=[3, 5, 7],
                               help='Canny Sobel算子孔径大小 (默认: 3)')
        subparser.add_argument('--canny-l2gradient', action='store_true',
                               help='Canny使用L2范数计算梯度')

        # Shared pre/post-processing options
        subparser.add_argument('--normalize-output', action='store_true',
                               help='将输出归一化到0-255范围')
        subparser.add_argument('--use-blur', action='store_true',
                               help='预先进行高斯模糊')
        subparser.add_argument('--blur-kernel-size', type=int, default=3,
                               help='模糊核大小,必须是奇数 (默认: 3)')
        subparser.add_argument('--blur-sigma', type=float, default=1.0,
                               help='模糊标准差 (默认: 1.0)')

    def build_config(self, args):
        """Translate parsed args into an EdgeDetectionConfig."""
        from edge_detect_method.edge_detect import EdgeDetectionConfig

        # Pick the output target: an explicit --output-file wins over the
        # directory-based default (exactly one of the two is set).
        if hasattr(args, 'output_file') and args.output_file:
            output_path = args.output_file
            output_dir = None
        else:
            output_path = None
            output_dir = args.output_dir

        # A single-method run is expressed as one-element batch lists.
        batch_methods = [args.method]
        batch_bands = [args.edge_band_index]

        if hasattr(args, 'batch') and args.batch:
            # Batch mode always runs the full set of methods.
            batch_methods = ['sobel', 'scharr', 'laplacian', 'log', 'canny']
            if hasattr(args, 'batch_bands') and args.batch_bands:
                if len(args.batch_bands) == len(batch_methods):
                    batch_bands = args.batch_bands
                else:
                    # Repeat the supplied band list until it covers every
                    # method, then trim to the same length.
                    batch_bands = args.batch_bands * (len(batch_methods) // len(args.batch_bands) + 1)
                    batch_bands = batch_bands[:len(batch_methods)]
            else:
                batch_bands = [args.edge_band_index] * len(batch_methods)

        return EdgeDetectionConfig(
            input_path=args.input,
            method=args.method,
            band_index=args.edge_band_index,
            output_path=output_path,
            output_dir=output_dir,
            batch_methods=batch_methods,
            batch_bands=batch_bands,
            # Sobel parameters
            sobel_dx=getattr(args, 'sobel_dx', 1),
            # NOTE(review): fallback default 1 here disagrees with the
            # argparse default of 0 for --sobel-dy; harmless after normal
            # parsing, but inconsistent for hand-built namespaces — confirm.
            sobel_dy=getattr(args, 'sobel_dy', 1),
            sobel_ksize=getattr(args, 'sobel_ksize', -1),
            # Laplacian parameters
            laplacian_ksize=getattr(args, 'laplacian_ksize', 1),
            laplacian_scale=getattr(args, 'laplacian_scale', 1.0),
            laplacian_delta=getattr(args, 'laplacian_delta', 0.0),
            # LoG parameters
            log_sigma=getattr(args, 'log_sigma', 1.0),
            log_threshold=getattr(args, 'log_threshold', 0.0),
            log_mode=getattr(args, 'log_mode', 'reflect'),
            # Gradient parameters
            gradient_threshold=getattr(args, 'gradient_threshold', 0.0),
            gradient_use_otsu=getattr(args, 'gradient_use_otsu', False),
            # Canny parameters
            canny_min_threshold=getattr(args, 'canny_min_threshold', 100),
            canny_max_threshold=getattr(args, 'canny_max_threshold', 200),
            canny_aperture_size=getattr(args, 'canny_aperture_size', 3),
            canny_l2gradient=getattr(args, 'canny_l2gradient', False),
            # Shared parameters
            normalize_output=getattr(args, 'normalize_output', False),
            use_blur=getattr(args, 'use_blur', False),
            blur_kernel_size=getattr(args, 'blur_kernel_size', 3),
            blur_sigma=getattr(args, 'blur_sigma', 1.0)
        )
|
||
|
||
|
||
class AnomalyDetectionArgsBuilder(ArgsBuilder):
    """Builds CLI options and config for the anomaly-detection module."""

    def add_module_specific_args(self, subparser):
        """Register anomaly-detection options on the sub-command parser."""
        subparser.add_argument('--method', '-m', default='covariance',
                               choices=['covariance', 'one-class-svm', 'rx', 'squared-loss-probability'],
                               help='异常检测方法')

        # Shared option (consumed by the covariance and RX configs below)
        subparser.add_argument('--contamination', '-c', type=float, default=0.1,
                               help='异常样本比例 (用于自动确定阈值)')

        # RX-specific options
        subparser.add_argument('--background', help='背景数据路径 (RX方法custom模式时必需)')
        subparser.add_argument('--label-col', help='标签列名 (CSV文件)')
        subparser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
        subparser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')
        subparser.add_argument('--background-model', choices=['global', 'local_row', 'local_square', 'custom'],
                               default='global', help='背景模型类型 (RX方法,默认global)')
        subparser.add_argument('--window-size', type=int, default=5,
                               help='局部窗口大小 (RX方法,默认5)')
        subparser.add_argument('--row-window', type=int, default=3,
                               help='行窗口大小 (RX方法,默认3)')
        subparser.add_argument('--rx-threshold', type=float,
                               help='异常检测阈值 (RX方法,不指定则只输出距离图)')
        subparser.add_argument('--robust-covariance', action='store_true',
                               help='使用稳健协方差估计 (RX方法)')

        # One-Class SVM options
        subparser.add_argument('--ocsvm-kernel', choices=['linear', 'rbf', 'poly'], default='rbf',
                               help='SVM核函数 (One-Class SVM方法,默认rbf)')
        subparser.add_argument('--ocsvm-nu', type=float, default=0.1,
                               help='训练误差比例 (One-Class SVM方法,默认0.1)')
        subparser.add_argument('--ocsvm-gamma', type=float,
                               help='核函数参数gamma (One-Class SVM方法,不指定则自动选择)')
        subparser.add_argument('--ocsvm-degree', type=int, default=3,
                               help='多项式核的阶数 (One-Class SVM方法,默认3)')
        subparser.add_argument('--ocsvm-coef0', type=float, default=0.0,
                               help='核函数常数项 (One-Class SVM方法,默认0.0)')
        subparser.add_argument('--ocsvm-tol', type=float, default=1e-3,
                               help='停止准则公差 (One-Class SVM方法,默认1e-3)')
        # NOTE(review): 'store_true' combined with default=True (here and on
        # several flags below) means these switches can never be turned OFF
        # from the command line — confirm whether --no-* counterparts are
        # wanted, as in other parts of this registry.
        subparser.add_argument('--ocsvm-shrinking', action='store_true', default=True,
                               help='是否使用shrinking heuristic (One-Class SVM方法,默认True)')
        subparser.add_argument('--ocsvm-cache-size', type=float, default=200,
                               help='核缓存大小(MB) (One-Class SVM方法,默认200)')
        subparser.add_argument('--ocsvm-max-iter', type=int, default=-1,
                               help='最大迭代次数 (One-Class SVM方法,默认-1表示无限制)')
        subparser.add_argument('--ocsvm-use-scaling', action='store_true', default=True,
                               help='是否使用标准化 (One-Class SVM方法,默认True)')
        subparser.add_argument('--ocsvm-use-pca', action='store_true', default=True,
                               help='是否使用PCA降维 (One-Class SVM方法,默认True)')
        subparser.add_argument('--ocsvm-n-components', type=int,
                               help='PCA降维维度 (One-Class SVM方法,默认自动选择)')
        subparser.add_argument('--ocsvm-use-grid-search', action='store_true',
                               help='是否使用网格搜索调优参数 (One-Class SVM方法,默认False)')
        subparser.add_argument('--ocsvm-cv-folds', type=int, default=3,
                               help='交叉验证折数 (One-Class SVM方法,默认3)')

        # Squared-loss-probability options
        subparser.add_argument('--reconstruction-method', choices=['linear', 'pca'], default='linear',
                               help='重构方法 (Squared Loss方法,默认linear)')
        subparser.add_argument('--slp-n-components', type=int,
                               help='PCA降维维度 (Squared Loss方法使用PCA时必需)')
        subparser.add_argument('--target-band', type=int,
                               help='目标波段索引 (Squared Loss方法,默认使用最后一个波段)')
        subparser.add_argument('--slp-threshold', type=float, default=0.8,
                               help='分类阈值 (Squared Loss方法,默认0.8)')
        subparser.add_argument('--probability-mode', action='store_true', default=True,
                               help='使用概率模式 (Squared Loss方法,默认True)')
        subparser.add_argument('--slp-use-scaling', action='store_true', default=True,
                               help='是否使用标准化 (Squared Loss方法,默认True)')
        subparser.add_argument('--scaler-type', choices=['standard', 'robust'], default='robust',
                               help='标准化类型 (Squared Loss方法,默认robust)')
        subparser.add_argument('--fit-intercept', action='store_true', default=True,
                               help='是否拟合截距 (Squared Loss方法,默认True)')

    def build_config(self, args):
        """Build the config object matching the selected detection method.

        Imports are deferred so that only the chosen method's module is
        loaded.

        Raises:
            ValueError: if args.method is not one of the supported names
                (unreachable after argparse validation, but guards
                hand-built namespaces).
        """
        if args.method == 'covariance':
            from Anomaly_method.Covariance import CovarianceAnomalyConfig

            config_kwargs = {
                'input_path': args.input,
                'output_dir': args.output_dir,
                'contamination': getattr(args, 'contamination', 0.1)
            }
            return CovarianceAnomalyConfig(**config_kwargs)

        elif args.method == 'one-class-svm':
            from Anomaly_method.One_Class_SVM import OneClassSVMConfig

            # Collect kwargs directly (no intermediate default config object,
            # to avoid triggering its validation).
            config_kwargs = {
                'input_path': args.input,
                'output_dir': args.output_dir,
                'kernel': getattr(args, 'ocsvm_kernel', 'rbf'),
                'nu': getattr(args, 'ocsvm_nu', 0.1),
                'gamma': getattr(args, 'ocsvm_gamma', None),
                'degree': getattr(args, 'ocsvm_degree', 3),
                'coef0': getattr(args, 'ocsvm_coef0', 0.0),
                'tol': getattr(args, 'ocsvm_tol', 1e-3),
                'shrinking': getattr(args, 'ocsvm_shrinking', True),
                'cache_size': getattr(args, 'ocsvm_cache_size', 200.0),
                'max_iter': getattr(args, 'ocsvm_max_iter', -1),
                'use_scaling': getattr(args, 'ocsvm_use_scaling', True),
                'use_pca': getattr(args, 'ocsvm_use_pca', True),
                'n_components': getattr(args, 'ocsvm_n_components', None),
                'use_grid_search': getattr(args, 'ocsvm_use_grid_search', False),
                'cv_folds': getattr(args, 'ocsvm_cv_folds', 3)
            }

            return OneClassSVMConfig(**config_kwargs)

        elif args.method == 'rx':
            from Anomaly_method.RX import RXConfig

            config_kwargs = {
                'input_path': args.input,
                'background_path': getattr(args, 'background', None),
                'label_col': getattr(args, 'label_col', None),
                'spectral_start': getattr(args, 'spectral_start', None),
                'spectral_end': getattr(args, 'spectral_end', None),
                'output_dir': args.output_dir,
                'background_model': getattr(args, 'background_model', 'global'),
                'window_size': getattr(args, 'window_size', 5),
                'row_window': getattr(args, 'row_window', 3),
                'threshold': getattr(args, 'rx_threshold', None),
                'contamination': getattr(args, 'contamination', 0.1),
                'use_robust_covariance': getattr(args, 'robust_covariance', False)
            }

            return RXConfig(**config_kwargs)

        elif args.method == 'squared-loss-probability':
            from Anomaly_method.squared_loss_probability import SquaredLossAnomalyConfig

            config_kwargs = {
                'input_path': args.input,
                'output_dir': args.output_dir,
                'reconstruction_method': getattr(args, 'reconstruction_method', 'linear'),
                'n_components': getattr(args, 'slp_n_components', None),
                'target_band': getattr(args, 'target_band', None),
                'probability_flag': getattr(args, 'probability_mode', True),
                'threshold': getattr(args, 'slp_threshold', 0.8),
                'use_scaling': getattr(args, 'slp_use_scaling', True),
                'scaler_type': getattr(args, 'scaler_type', 'robust'),
                'fit_intercept': getattr(args, 'fit_intercept', True)
            }

            return SquaredLossAnomalyConfig(**config_kwargs)

        else:
            raise ValueError(f"不支持的异常检测方法: {args.method}")
|
||
|
||
|
||
class ClassificationArgsBuilder(ArgsBuilder):
    """Builds CLI options and config for supervised classification."""

    def add_module_specific_args(self, subparser):
        """Register classifier selection, ROI, and per-model hyper-parameters."""
        subparser.add_argument('--method', '-m', default='svm',
                               choices=['svm', 'random_forest', 'knn', 'logistic_regression', 'linear_discriminant', 'quadratic_discriminant', 'plsda', 'xgboost', 'lightgbm', 'catboost'],
                               help='分类方法')
        subparser.add_argument('--roi-file', '-r', required=True,
                               help='ROI标注文件路径 (.xml)')
        subparser.add_argument('--test-size', '-ts', type=float, default=0.3,
                               help='测试集比例')

        # Preprocessing options.
        # BUG FIX: ``action='store_true'`` with ``default=True`` could never
        # be switched off; a paired --no-use-standardization switch now
        # disables it while the original flag keeps working unchanged.
        subparser.add_argument('--use-standardization', dest='use_standardization',
                               action='store_true', default=True,
                               help='是否使用标准化')
        subparser.add_argument('--no-use-standardization', dest='use_standardization',
                               action='store_false',
                               help='禁用标准化')
        subparser.add_argument('--batch-size', type=int, default=10000,
                               help='批处理大小')

        # SVM hyper-parameters
        subparser.add_argument('--svm-c', type=float, default=1.0,
                               help='SVM正则化参数C (SVM方法)')
        subparser.add_argument('--svm-kernel', choices=['linear', 'rbf', 'poly', 'sigmoid'], default='rbf',
                               help='SVM核函数 (SVM方法)')
        subparser.add_argument('--svm-gamma',
                               help='SVM核函数参数gamma: "scale", "auto" 或数值 (SVM方法,默认scale)')

        # Random-forest hyper-parameters
        subparser.add_argument('--rf-n-estimators', type=int, default=100,
                               help='树的数量 (随机森林方法)')
        subparser.add_argument('--rf-max-depth', type=int,
                               help='树的最大深度 (随机森林方法)')
        subparser.add_argument('--rf-min-samples-split', type=int, default=2,
                               help='内部节点再划分所需最小样本数 (随机森林方法)')

        # KNN hyper-parameters
        subparser.add_argument('--knn-n-neighbors', type=int, default=5,
                               help='邻居数量 (KNN方法)')
        subparser.add_argument('--knn-weights', choices=['uniform', 'distance'], default='uniform',
                               help='权重函数 (KNN方法)')
        subparser.add_argument('--knn-algorithm', choices=['auto', 'ball_tree', 'kd_tree', 'brute'], default='auto',
                               help='算法 (KNN方法)')

        # Logistic-regression hyper-parameters
        subparser.add_argument('--lr-c', type=float, default=1.0,
                               help='正则化强度 (逻辑回归方法)')
        subparser.add_argument('--lr-max-iter', type=int, default=100,
                               help='最大迭代次数 (逻辑回归方法)')

        # PLS-DA hyper-parameters
        subparser.add_argument('--pls-n-components', type=int, default=10,
                               help='PLS成分数量 (PLS-DA方法)')

        # XGBoost hyper-parameters
        subparser.add_argument('--xgb-n-estimators', type=int, default=100,
                               help='树的数量 (XGBoost方法)')
        subparser.add_argument('--xgb-max-depth', type=int, default=6,
                               help='树的最大深度 (XGBoost方法)')
        subparser.add_argument('--xgb-learning-rate', type=float, default=0.3,
                               help='学习率 (XGBoost方法)')

        # LightGBM hyper-parameters
        subparser.add_argument('--lgb-n-estimators', type=int, default=100,
                               help='树的数量 (LightGBM方法)')
        subparser.add_argument('--lgb-max-depth', type=int, default=-1,
                               help='树的最大深度 (LightGBM方法,-1表示无限制)')
        subparser.add_argument('--lgb-learning-rate', type=float, default=0.1,
                               help='学习率 (LightGBM方法)')

        # CatBoost hyper-parameters
        subparser.add_argument('--cb-n-estimators', type=int, default=100,
                               help='树的数量 (CatBoost方法)')
        subparser.add_argument('--cb-depth', type=int, default=6,
                               help='树的最大深度 (CatBoost方法)')
        subparser.add_argument('--cb-learning-rate', type=float, default=0.03,
                               help='学习率 (CatBoost方法)')

    def build_config(self, args):
        """Translate parsed args into a ClassificationConfig.

        Collects only the hyper-parameters belonging to the selected
        method, then resolves the output path.
        """
        from classfication_method.classfication import ClassificationConfig

        model_params = {}

        if args.method == 'svm':
            model_params['C'] = getattr(args, 'svm_c', 1.0)
            model_params['kernel'] = getattr(args, 'svm_kernel', 'rbf')
            # gamma may be 'scale', 'auto' or a float; fall back to 'scale'
            # when missing or unparseable.
            if hasattr(args, 'svm_gamma') and args.svm_gamma is not None:
                gamma_str = str(args.svm_gamma).lower()
                if gamma_str in ['scale', 'auto']:
                    model_params['gamma'] = gamma_str
                else:
                    try:
                        model_params['gamma'] = float(args.svm_gamma)
                    except ValueError:
                        model_params['gamma'] = 'scale'
            else:
                model_params['gamma'] = 'scale'

        # BUG FIX: these branches previously tested 'rf', 'lr' and 'pls-da',
        # none of which are valid --method choices ('random_forest',
        # 'logistic_regression', 'plsda'), so the corresponding user-supplied
        # hyper-parameters were silently dropped.
        elif args.method == 'random_forest':
            model_params['n_estimators'] = getattr(args, 'rf_n_estimators', 100)
            model_params['max_depth'] = getattr(args, 'rf_max_depth', None)
            model_params['min_samples_split'] = getattr(args, 'rf_min_samples_split', 2)

        elif args.method == 'knn':
            model_params['n_neighbors'] = getattr(args, 'knn_n_neighbors', 5)
            model_params['weights'] = getattr(args, 'knn_weights', 'uniform')
            model_params['algorithm'] = getattr(args, 'knn_algorithm', 'auto')

        elif args.method == 'logistic_regression':
            model_params['C'] = getattr(args, 'lr_c', 1.0)
            model_params['max_iter'] = getattr(args, 'lr_max_iter', 100)

        elif args.method == 'plsda':
            model_params['n_components'] = getattr(args, 'pls_n_components', 10)

        elif args.method == 'xgboost':
            model_params['n_estimators'] = getattr(args, 'xgb_n_estimators', 100)
            model_params['max_depth'] = getattr(args, 'xgb_max_depth', 6)
            model_params['learning_rate'] = getattr(args, 'xgb_learning_rate', 0.3)

        elif args.method == 'lightgbm':
            model_params['n_estimators'] = getattr(args, 'lgb_n_estimators', 100)
            model_params['max_depth'] = getattr(args, 'lgb_max_depth', -1)
            model_params['learning_rate'] = getattr(args, 'lgb_learning_rate', 0.1)

        elif args.method == 'catboost':
            model_params['iterations'] = getattr(args, 'cb_n_estimators', 100)
            model_params['depth'] = getattr(args, 'cb_depth', 6)
            model_params['learning_rate'] = getattr(args, 'cb_learning_rate', 0.03)

        # Resolve the output path: explicit --output-file wins, otherwise
        # derive it from --output-dir and --output-prefix.
        if hasattr(args, 'output_file') and args.output_file:
            output_path = args.output_file
        else:
            import os
            output_prefix = getattr(args, 'output_prefix', 'result')
            output_path = os.path.join(args.output_dir, f"{output_prefix}.dat")

        # Identity mapping today; kept as the extension point should CLI
        # method names ever diverge from ClassificationConfig model types.
        model_type_mapping = {
            'svm': 'svm',
            'random_forest': 'random_forest',
            'knn': 'knn',
            'logistic_regression': 'logistic_regression',
            'linear_discriminant': 'linear_discriminant',
            'quadratic_discriminant': 'quadratic_discriminant',
            'plsda': 'plsda',
            'xgboost': 'xgboost',
            'lightgbm': 'lightgbm',
            'catboost': 'catboost'
        }

        config_model_type = model_type_mapping.get(args.method, args.method)

        config_kwargs = {
            'hyperspectral_path': args.input,
            'roi_path': args.roi_file,
            'model_type': config_model_type,
            'model_params': model_params,
            'test_size': args.test_size,
            'output_path': output_path,
            'use_standardization': getattr(args, 'use_standardization', True),
            'batch_size': getattr(args, 'batch_size', 10000)
        }

        return ClassificationConfig(**config_kwargs)
|
||
|
||
|
||
class ClusteringArgsBuilder(ArgsBuilder):
    """Builds CLI arguments and the ClusteringConfig for the clustering module.

    Supports eight clustering methods plus a batch mode that runs all of them.
    """

    def add_module_specific_args(self, subparser):
        """Register clustering-specific command-line options on *subparser*."""
        subparser.add_argument('--method', '-m', default='kmeans',
                              choices=['kmeans', 'fuzzy-cmeans', 'gmm', 'hierarchical', 'dbscan', 'spectral', 'subspace', 'ensemble'],
                              help='聚类方法')
        subparser.add_argument('--n-clusters', type=int, default=5,
                              help='聚类数量')
        subparser.add_argument('--batch', action='store_true',
                              help='运行所有聚类方法')

        # Data preprocessing options.
        # NOTE(review): action='store_true' with default=True means this flag can
        # never be turned off from the CLI — looks like it was meant to be a
        # true on/off switch (e.g. BooleanOptionalAction); confirm intent.
        subparser.add_argument('--use-scaling', action='store_true', default=True,
                              help='是否使用数据标准化')
        subparser.add_argument('--scaler-type', choices=['standard', 'minmax'], default='standard',
                              help='标准化类型')
        subparser.add_argument('--use-pca', action='store_true',
                              help='是否使用PCA降维')
        subparser.add_argument('--pca-components', type=int,
                              help='PCA降维后的维度数')

        # K-means specific options.
        subparser.add_argument('--kmeans-init', choices=['k-means++', 'random'], default='k-means++',
                              help='初始化方法 (K-means方法)')
        subparser.add_argument('--kmeans-n-init', type=int, default=10,
                              help='运行次数,选择最佳结果 (K-means方法)')
        subparser.add_argument('--kmeans-max-iter', type=int, default=300,
                              help='最大迭代次数 (K-means方法)')
        subparser.add_argument('--kmeans-batch-size', type=int, default=1000,
                              help='批处理大小 (K-means方法)')

        # Fuzzy C-means specific options.
        subparser.add_argument('--fcm-m', type=float, default=2.0,
                              help='模糊度参数 (模糊C均值方法,默认2.0)')
        subparser.add_argument('--fcm-error', type=float, default=0.005,
                              help='终止误差 (模糊C均值方法)')
        subparser.add_argument('--fcm-maxiter', type=int, default=1000,
                              help='最大迭代次数 (模糊C均值方法)')

        # GMM specific options.
        subparser.add_argument('--gmm-covariance-type', choices=['full', 'tied', 'diag', 'spherical'], default='full',
                              help='协方差类型 (GMM方法)')
        subparser.add_argument('--gmm-max-iter', type=int, default=100,
                              help='最大迭代次数 (GMM方法)')

        # Hierarchical clustering specific options.
        subparser.add_argument('--hier-linkage', choices=['ward', 'complete', 'average', 'single'], default='ward',
                              help='连接方法 (层次聚类方法)')
        subparser.add_argument('--hier-affinity', choices=['euclidean', 'l1', 'l2', 'manhattan', 'cosine'], default='euclidean',
                              help='距离度量 (层次聚类方法)')

        # DBSCAN specific options.
        # NOTE(review): these three options are never read by build_config below,
        # which instead reads dbscan_eps_percentile / dbscan_min_samples_factor /
        # dbscan_n_neighbors (never registered here). User-supplied DBSCAN flags
        # are therefore silently ignored — confirm which parameter set
        # ClusteringConfig actually consumes.
        subparser.add_argument('--dbscan-eps', type=float, default=0.5,
                              help='邻域半径 (DBSCAN方法)')
        subparser.add_argument('--dbscan-min-samples', type=int, default=5,
                              help='核心点的最小样本数 (DBSCAN方法)')
        subparser.add_argument('--dbscan-algorithm', choices=['auto', 'ball_tree', 'kd_tree', 'brute'], default='auto',
                              help='算法 (DBSCAN方法)')

        # Spectral clustering specific options.
        # NOTE(review): --spectral-gamma is registered but never used by
        # build_config; likewise build_config reads spectral_n_neighbors and
        # spectral_large_dataset_threshold, which are not registered here.
        subparser.add_argument('--spectral-affinity', choices=['rbf', 'nearest_neighbors'], default='rbf',
                              help='相似度矩阵构建方法 (谱聚类方法)')
        subparser.add_argument('--spectral-gamma', type=float, default=1.0,
                              help='RBF核函数的gamma参数 (谱聚类方法)')

        # Subspace clustering specific options.
        # NOTE(review): registered but unused by build_config (which reads
        # subspace_n_components_factor / subspace_max_iter instead).
        subparser.add_argument('--subspace-n-subspaces', type=int,
                              help='子空间数量 (子空间聚类方法,默认等于聚类数)')

        # Ensemble clustering specific options.
        subparser.add_argument('--ensemble-voting', choices=['hard', 'soft'], default='hard',
                              help='投票方式 (集成聚类方法)')

    def build_config(self, args):
        """Translate parsed CLI args into a ClusteringConfig instance."""
        from cluster_method.cluster import ClusteringConfig

        # Batch mode uses a minimal default configuration.
        if hasattr(args, 'batch') and args.batch:
            return ClusteringConfig(
                input_path=args.input,
                method='batch',  # sentinel value meaning "run all methods"
                n_clusters=args.n_clusters,
                output_dir=args.output_dir
            )

        # Base configuration shared by all methods.
        config_kwargs = {
            'input_path': args.input,
            'method': args.method,
            'n_clusters': args.n_clusters,
            'output_dir': args.output_dir,
            'use_scaling': getattr(args, 'use_scaling', True),
            'scaler_type': getattr(args, 'scaler_type', 'standard'),
            'use_pca': getattr(args, 'use_pca', False),
            'pca_components': getattr(args, 'pca_components', None)
        }

        # Method-specific parameters. The getattr fallbacks only fire when the
        # attribute is absent from the namespace; argparse always sets
        # registered options, so for registered flags the argparse default wins.
        method_params = {}
        if args.method == 'kmeans':
            method_params['init'] = getattr(args, 'kmeans_init', 'k-means++')
            method_params['n_init'] = getattr(args, 'kmeans_n_init', 10)
            method_params['max_iter'] = getattr(args, 'kmeans_max_iter', 300)
            method_params['batch_size'] = getattr(args, 'kmeans_batch_size', 1000)

        elif args.method == 'fuzzy-cmeans':
            method_params['fuzziness'] = getattr(args, 'fcm_m', 2.0)
            method_params['error'] = getattr(args, 'fcm_error', 0.005)
            method_params['max_iter'] = getattr(args, 'fcm_maxiter', 1000)

        elif args.method == 'gmm':
            method_params['covariance_type'] = getattr(args, 'gmm_covariance_type', 'full')
            # NOTE(review): fallback 200 disagrees with the CLI default of 100;
            # harmless today (argparse always sets the attr) but confusing.
            method_params['max_iter'] = getattr(args, 'gmm_max_iter', 200)
            # NOTE(review): gmm_n_init_attempts is never registered as a CLI
            # option, so this is always 10 — confirm whether a flag is missing.
            method_params['n_init_attempts'] = getattr(args, 'gmm_n_init_attempts', 10)

        elif args.method == 'hierarchical':
            method_params['linkage'] = getattr(args, 'hier_linkage', 'ward')
            method_params['affinity'] = getattr(args, 'hier_affinity', 'euclidean')

        elif args.method == 'dbscan':
            # NOTE(review): none of these attrs are registered as CLI options,
            # so the fallbacks always apply and --dbscan-eps/--dbscan-min-samples/
            # --dbscan-algorithm are ignored. Confirm against ClusteringConfig.
            method_params['eps_percentile'] = getattr(args, 'dbscan_eps_percentile', 50)
            method_params['min_samples_factor'] = getattr(args, 'dbscan_min_samples_factor', 0.1)
            method_params['n_neighbors'] = getattr(args, 'dbscan_n_neighbors', 20)

        elif args.method == 'spectral':
            # NOTE(review): fallback 'nearest_neighbors' disagrees with the CLI
            # default 'rbf'; the two extra attrs below are not CLI-registered.
            method_params['affinity'] = getattr(args, 'spectral_affinity', 'nearest_neighbors')
            method_params['n_neighbors'] = getattr(args, 'spectral_n_neighbors', 10)
            method_params['large_dataset_threshold'] = getattr(args, 'spectral_large_dataset_threshold', 2000)

        elif args.method == 'subspace':
            # NOTE(review): not CLI-registered; --subspace-n-subspaces is unused.
            method_params['n_components_factor'] = getattr(args, 'subspace_n_components_factor', 0.33)
            method_params['max_iter'] = getattr(args, 'subspace_max_iter', 300)

        elif args.method == 'ensemble':
            method_params['voting'] = getattr(args, 'ensemble_voting', 'hard')

        if method_params:
            config_kwargs['method_params'] = method_params

        return ClusteringConfig(**config_kwargs)
||
|
||
|
||
class FeatureSelectionArgsBuilder(ArgsBuilder):
    """Builds CLI arguments and the FeatureSelectionConfig for feature selection.

    Eight selection algorithms are supported; each has its own CLI option
    group, mapped into ``method_params`` by :meth:`build_config`.
    """

    def add_module_specific_args(self, subparser):
        """Register feature-selection-specific options on *subparser*."""
        subparser.add_argument('--method', '-m', default='Spa',
                              choices=['Spa', 'Cars', 'Uve', 'GA', 'ReliefF', 'SiPLS', 'Lars', 'RandomFrog'],
                              help='特征选择方法')
        subparser.add_argument('--label-column', '-l', required=True,
                              help='标签列名')
        subparser.add_argument('--n-features', type=int, default=20,
                              help='选择的特征数量')
        subparser.add_argument('--spectral-columns',
                              help='光谱数据列范围,如 "1:10" 或 "auto" (不指定则自动检测)')

        # Output and visualization options.
        # NOTE(review): store_true with default=True — these flags cannot be
        # switched off from the CLI; confirm whether that is intended.
        subparser.add_argument('--output-csv', action='store_true', default=True,
                              help='是否输出CSV结果文件')
        subparser.add_argument('--save-plots', action='store_true', default=True,
                              help='是否保存可视化图表')
        subparser.add_argument('--plot-prefix', default='feature_selection',
                              help='图表文件名前缀')

        # SPA-specific options.
        subparser.add_argument('--spa-min-vars', type=int, default=2,
                              help='SPA最小变量数')
        subparser.add_argument('--spa-max-vars', type=int, default=50,
                              help='SPA最大变量数')
        subparser.add_argument('--spa-autoscaling', type=int, choices=[0, 1], default=1,
                              help='SPA自动缩放 (0=关闭, 1=开启)')

        # CARS-specific options.
        subparser.add_argument('--cars-n', type=int, default=50,
                              help='CARS蒙特卡洛采样次数')
        subparser.add_argument('--cars-cv', type=int, default=10,
                              help='CARS交叉验证折数')

        # UVE-specific options.
        subparser.add_argument('--uve-ncomp', type=int, default=20,
                              help='UVE PLS成分数')
        subparser.add_argument('--uve-cv', type=int, default=5,
                              help='UVE交叉验证折数')

        # GA-specific options.
        subparser.add_argument('--ga-population-size', type=int, default=10,
                              help='GA种群大小')
        subparser.add_argument('--ga-n-generations', type=int, default=50,
                              help='GA进化代数')
        subparser.add_argument('--ga-crossover-rate', type=float, default=0.8,
                              help='GA交叉概率')
        subparser.add_argument('--ga-mutation-rate', type=float, default=0.1,
                              help='GA变异概率')

        # Relief-F-specific options.
        subparser.add_argument('--relief-k', type=int, default=10,
                              help='Relief-F最近邻数k')
        subparser.add_argument('--relief-sample-size', type=int,
                              help='Relief-F采样大小 (默认使用全部样本)')

        # SiPLS-specific options.
        subparser.add_argument('--sipls-interval-width', type=int, default=10,
                              help='SiPLS区间宽度')
        subparser.add_argument('--sipls-cv', type=int, default=5,
                              help='SiPLS交叉验证折数')

        # LAR-specific options.
        subparser.add_argument('--lar-cv', type=int, default=5,
                              help='LAR交叉验证折数')
        subparser.add_argument('--lar-max-features', type=int,
                              help='LAR最大特征数 (默认无限制)')

        # RandomFrog-specific options.
        subparser.add_argument('--random-frog-n-frogs', type=int, default=50,
                              help='RandomFrog青蛙数量')
        subparser.add_argument('--random-frog-n-memeplexes', type=int, default=5,
                              help='RandomFrog模因数量')
        subparser.add_argument('--random-frog-n-evolution-steps', type=int, default=10,
                              help='RandomFrog进化步数')
        subparser.add_argument('--random-frog-n-shuffle-iterations', type=int, default=10,
                              help='RandomFrog重排迭代次数')
        subparser.add_argument('--random-frog-cv', type=int, default=5,
                              help='RandomFrog交叉验证折数')

    def _parse_spectral_columns(self, csv_file_path: str, spectral_columns_arg: str, label_column: str) -> list:
        """
        Parse the spectral-column argument and return a list of column names.

        Args:
            csv_file_path: Path to the CSV file.
            spectral_columns_arg: Column-range string such as "1:165" or "auto".
                Ranges use 1-based indices counted over the non-label columns;
                multiple comma-separated ranges/indices are allowed.
            label_column: Name of the label column (excluded from candidates).

        Returns:
            List of spectral column names (deduplicated, order preserved).

        Raises:
            ValueError: If the label column is missing or the range is invalid.
        """
        import pandas as pd

        # Read only the header row to obtain the column names cheaply.
        df = pd.read_csv(csv_file_path, nrows=0)
        all_columns = df.columns.tolist()

        # Locate the label column; it is excluded from the spectral candidates.
        if label_column not in all_columns:
            raise ValueError(f"标签列 '{label_column}' 不存在于CSV文件中")

        label_index = all_columns.index(label_column)

        # Every column except the label column is a spectral candidate.
        spectral_candidates = [col for i, col in enumerate(all_columns) if i != label_index]

        if spectral_columns_arg is None or spectral_columns_arg.lower() == 'auto':
            # Auto-detect: use all non-label columns.
            return spectral_candidates

        # Parse the range string.
        # NOTE(review): the ValueErrors raised inside this try are caught and
        # re-wrapped by the except below (and the chain is not preserved with
        # "from e") — consider raising them outside the try or using `from e`.
        try:
            columns = []
            # Multiple ranges may be supplied, separated by commas.
            for part in spectral_columns_arg.split(','):
                part = part.strip()
                if ':' in part:
                    # Range selection, e.g. "1:165".
                    start, end = part.split(':')
                    start = int(start.strip())
                    end = int(end.strip())

                    # User-facing indices are 1-based (first spectral column = 1).
                    if start < 1 or end < 1:
                        raise ValueError(f"列范围 {start}:{end} 无效,必须从1开始")

                    # Clamp the end of the range to the available columns.
                    if end > len(spectral_candidates):
                        end = len(spectral_candidates)

                    if start > len(spectral_candidates):
                        raise ValueError(f"起始列 {start} 超出可用光谱列范围 [1, {len(spectral_candidates)}]")

                    # Slice out the requested columns (Python indices are 0-based).
                    selected_columns = spectral_candidates[start-1:end]
                    columns.extend(selected_columns)
                else:
                    # Single 1-based index.
                    idx = int(part.strip())
                    if idx < 1 or idx > len(spectral_candidates):
                        raise ValueError(f"列索引 {idx} 超出范围 [1, {len(spectral_candidates)}]")
                    columns.append(spectral_candidates[idx-1])

            return list(dict.fromkeys(columns))  # dedupe while preserving order

        except ValueError as e:
            raise ValueError(f"解析光谱列参数 '{spectral_columns_arg}' 时出错: {e}")

    def build_config(self, args):
        """Translate parsed CLI args into a FeatureSelectionConfig instance."""
        from Feature_Selection_method.feture_select import FeatureSelectionConfig
        import pandas as pd  # NOTE(review): unused here; _parse_spectral_columns imports its own

        # Collect method-specific parameters for the selected algorithm.
        method_params = {}

        if args.method == 'Spa':
            method_params['m_min'] = getattr(args, 'spa_min_vars', 2)
            method_params['m_max'] = getattr(args, 'spa_max_vars', 50)
            method_params['autoscaling'] = getattr(args, 'spa_autoscaling', 1)

        elif args.method == 'Cars':
            method_params['N'] = getattr(args, 'cars_n', 50)
            method_params['f'] = getattr(args, 'n_features', 20)
            method_params['cv'] = getattr(args, 'cars_cv', 10)

        elif args.method == 'Uve':
            method_params['ncomp'] = getattr(args, 'uve_ncomp', 20)
            method_params['cv'] = getattr(args, 'uve_cv', 5)

        elif args.method == 'GA':
            method_params['population_size'] = getattr(args, 'ga_population_size', 10)
            method_params['n_generations'] = getattr(args, 'ga_n_generations', 50)
            method_params['crossover_rate'] = getattr(args, 'ga_crossover_rate', 0.8)
            method_params['mutation_rate'] = getattr(args, 'ga_mutation_rate', 0.1)

        elif args.method == 'ReliefF':
            method_params['k'] = getattr(args, 'relief_k', 10)
            method_params['sample_size'] = getattr(args, 'relief_sample_size', None)

        elif args.method == 'SiPLS':
            method_params['interval_width'] = getattr(args, 'sipls_interval_width', 10)
            method_params['cv'] = getattr(args, 'sipls_cv', 5)

        elif args.method == 'Lars':
            method_params['cv'] = getattr(args, 'lar_cv', 5)
            method_params['max_features'] = getattr(args, 'lar_max_features', None)

        elif args.method == 'RandomFrog':
            # RandomFrog parameter set.
            method_params['n_frogs'] = getattr(args, 'random_frog_n_frogs', 50)
            method_params['n_memeplexes'] = getattr(args, 'random_frog_n_memeplexes', 5)
            method_params['n_evolution_steps'] = getattr(args, 'random_frog_n_evolution_steps', 10)
            method_params['n_shuffle_iterations'] = getattr(args, 'random_frog_n_shuffle_iterations', 10)
            method_params['cv'] = getattr(args, 'random_frog_cv', 5)

        # Resolve the spectral columns from the CSV header.
        spectral_columns = self._parse_spectral_columns(args.input, getattr(args, 'spectral_columns', None), args.label_column)

        # Assemble the configuration.
        config_kwargs = {
            'csv_file_path': args.input,
            'label_column': args.label_column,
            'method': args.method,
            'output_dir': args.output_dir,
            'method_params': method_params,
            'spectral_columns': spectral_columns,
            'output_csv': getattr(args, 'output_csv', True),
            'save_plots': getattr(args, 'save_plots', True),
            'plot_name_prefix': getattr(args, 'plot_prefix', 'feature_selection')
        }

        return FeatureSelectionConfig(**config_kwargs)
||
|
||
|
||
class SpectralIndexArgsBuilder(ArgsBuilder):
    """Builds CLI arguments and the SpectralIndexConfig for spectral-index analysis.

    ``build_config`` starts from ``SpectralIndexConfig.create_quick_analysis``
    and overlays only the options the user actually supplied.
    """

    def add_module_specific_args(self, subparser):
        """Register spectral-index-specific options on *subparser*."""
        subparser.add_argument('--indices', '-idx', nargs='+',
                              help='要计算的光谱指数列表 (默认: 全部)')
        subparser.add_argument('--batch', action='store_true',
                              help='计算所有指数')
        subparser.add_argument('--png', action='store_true',
                              help='生成PNG可视化')

        # Data-format options.
        subparser.add_argument('--data-format', choices=['csv', 'envi', 'auto'], default='auto',
                              help='数据文件格式')
        subparser.add_argument('--label-column',
                              help='标签列名 (CSV格式时需要)')
        subparser.add_argument('--spectral-start',
                              help='光谱数据起始列名或波段 (CSV格式时需要)')
        subparser.add_argument('--spectral-end',
                              help='光谱数据结束列名或波段 (CSV格式时需要)')

        # Output options.
        subparser.add_argument('--spectral-output-format', choices=['csv', 'excel', 'both'], default='both',
                              help='输出文件格式')
        subparser.add_argument('--spectral-save-individual', action='store_true',
                              help='是否保存单个指数结果')
        subparser.add_argument('--spectral-output-prefix', default='spectral_indices',
                              help='输出文件名前缀')

        # Visualization options.
        subparser.add_argument('--plot-histogram', action='store_true', default=True,
                              help='是否绘制直方图')
        subparser.add_argument('--plot-boxplot', action='store_true', default=True,
                              help='是否绘制箱线图')
        subparser.add_argument('--plot-correlation', action='store_true',
                              help='是否绘制相关性热力图')
        subparser.add_argument('--plot-wavelength', action='store_true',
                              help='是否按波长绘制指数分布')

        # Statistics options.
        subparser.add_argument('--calculate-stats', action='store_true', default=True,
                              help='是否计算统计信息')
        subparser.add_argument('--stats-percentiles', nargs='+', type=float,
                              help='要计算的百分位数列表 (默认: 25, 50, 75)')

        # Index-definition options.
        subparser.add_argument('--index-csv', help='自定义光谱指数定义CSV文件路径')
        subparser.add_argument('--formula-csv', help='自定义公式定义CSV文件路径')

    def build_config(self, args):
        """Translate parsed CLI args into a SpectralIndexConfig instance."""
        from spectral_feature_method.spectral_index import SpectralIndexConfig

        # Start from the quick-analysis preset.
        config = SpectralIndexConfig.create_quick_analysis(
            data_file_path=args.input,
            indices_to_calculate=args.indices
        )

        # Data configuration.
        config.data.data_format = getattr(args, 'data_format', 'auto')
        config.data.label_column = getattr(args, 'label_column', None)
        config.data.spectral_start = getattr(args, 'spectral_start', None)
        config.data.spectral_end = getattr(args, 'spectral_end', None)

        # Output configuration.
        config.output.output_format = getattr(args, 'spectral_output_format', 'both')
        config.output.save_individual_indices = getattr(args, 'spectral_save_individual', False)
        config.output.output_prefix = getattr(args, 'spectral_output_prefix', 'spectral_indices')

        # Visualization configuration.
        config.output.plot_histogram = getattr(args, 'plot_histogram', True)
        config.output.plot_boxplot = getattr(args, 'plot_boxplot', True)
        config.output.plot_correlation = getattr(args, 'plot_correlation', False)
        config.output.plot_wavelength = getattr(args, 'plot_wavelength', False)

        # Statistics configuration.
        config.output.calculate_statistics = getattr(args, 'calculate_stats', True)
        # BUG FIX: this used to check `hasattr(args, ...)`, which is always true
        # for options registered with argparse (unset options are present with
        # value None). That clobbered the preset defaults with None whenever the
        # user did not pass the option. Only overlay values that were supplied.
        if getattr(args, 'stats_percentiles', None) is not None:
            config.output.statistics_percentiles = args.stats_percentiles

        # Index-definition configuration (same hasattr fix as above).
        if getattr(args, 'index_csv', None) is not None:
            config.indices.spectral_index_csv = args.index_csv
        if getattr(args, 'formula_csv', None) is not None:
            config.indices.formula_csv = args.formula_csv

        return config
||
|
||
|
||
class PreprocessingArgsBuilder(ArgsBuilder):
    """Builds CLI arguments and the PreprocessingConfig for spectral preprocessing.

    Covers scaling/transform methods (MMS, SS, SNV, MSC, ...), smoothing
    (MA, SG), derivatives (D1, D2), plus optional outlier handling and CSV
    validation.
    """

    def add_module_specific_args(self, subparser):
        """Register preprocessing-specific options on *subparser*."""
        subparser.add_argument('--method', '-m', default='SS',
                              choices=['MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave'],
                              help='预处理方法')
        subparser.add_argument('--spectral-start-index', type=int, default=1,
                              help='CSV文件的谱段起始列索引 (从0开始,默认: 1)')
        subparser.add_argument('--handle-outliers', action='store_true',
                              help='处理异常值')
        subparser.add_argument('--outlier-method', default='iqr',
                              choices=['iqr', 'isolation-forest', 'lof'],
                              help='异常值检测方法')

        # CSV validation options.
        subparser.add_argument('--validate-csv', action='store_true',
                              help='启用CSV文件验证')
        subparser.add_argument('--min-rows', type=int, default=1,
                              help='CSV文件最小行数 (默认: 1)')
        subparser.add_argument('--min-cols', type=int, default=2,
                              help='CSV文件最小列数 (默认: 2)')
        subparser.add_argument('--check-missing-values', action='store_true',
                              help='检查缺失值')
        subparser.add_argument('--check-data-types', action='store_true',
                              help='检查数据类型一致性')
        subparser.add_argument('--wavelength-column',
                              help='波长列名 (用于验证波长数据)')

        # Moving-average (MA) options.
        subparser.add_argument('--ma-window', type=int, default=11,
                              help='移动平均窗口大小 (默认: 11)')
        # Savitzky-Golay (SG) options.
        subparser.add_argument('--sg-window', type=int, default=15,
                              help='Savitzky-Golay窗口大小 (默认: 15)')
        subparser.add_argument('--sg-poly', type=int, default=2,
                              help='Savitzky-Golay多项式阶数 (默认: 2)')

    def build_config(self, args):
        """Translate parsed CLI args into a PreprocessingConfig instance."""
        from preprocessing_method.Preprocessing import PreprocessingConfig
        return PreprocessingConfig(
            input_path=args.input,
            method=args.method,
            spectral_start_index=getattr(args, 'spectral_start_index', 1),
            handle_outliers=args.handle_outliers,
            outlier_method=args.outlier_method,
            output_dir=args.output_dir,
            # CSV validation options.
            validate_csv=getattr(args, 'validate_csv', False),
            min_rows=getattr(args, 'min_rows', 1),
            min_cols=getattr(args, 'min_cols', 2),
            check_missing_values=getattr(args, 'check_missing_values', False),
            check_data_types=getattr(args, 'check_data_types', False),
            wavelength_column=getattr(args, 'wavelength_column', None),
            # MA and SG smoothing parameters.
            ma_window=getattr(args, 'ma_window', 11),
            sg_window=getattr(args, 'sg_window', 15),
            sg_poly=getattr(args, 'sg_poly', 2)
        )
||
|
||
|
||
class ShapeFeatureArgsBuilder(ArgsBuilder):
    """Builds CLI arguments and the ShapeFeatureConfig for shape-feature extraction.

    The generic ``--input`` path is interpreted as either an ENVI .dat file or a
    shapefile, depending on ``--input-type``.
    """

    def add_module_specific_args(self, subparser):
        """Register shape-feature-specific options on *subparser*."""
        # Input data options.
        subparser.add_argument('--input-type', default='dat',
                              choices=['dat', 'shp'],
                              help='输入类型 (默认: dat)')
        subparser.add_argument('--hdr-file',
                              help='对应的HDR头文件路径(仅在输入类型为dat时使用)')
        subparser.add_argument('--shape-band-index', type=int, default=0,
                              help='波段索引 (默认: 0)')

        # Connected-component processing options.
        subparser.add_argument('--connectivity', type=int, choices=[4, 8], default=8,
                              help='连通性:4或8 (默认: 8)')
        subparser.add_argument('--min-area', type=int, default=10,
                              help='最小区域面积阈值 (默认: 10)')

        # Watershed options (for separating touching objects).
        subparser.add_argument('--use-watershed', action='store_true',
                              help='使用分水岭算法分离相邻物体')
        subparser.add_argument('--watershed-min-distance', type=int, default=10,
                              help='分水岭算法中局部最大值的最小距离 (默认: 10)')

        # Output options.
        # NOTE(review): store_true with default=True cannot be disabled from the CLI.
        subparser.add_argument('--output-csv', action='store_true', default=True,
                              help='输出CSV文件 (默认: True)')
        subparser.add_argument('--save-labeled-image', action='store_true',
                              help='保存标记后的图像')

    def build_config(self, args):
        """Translate parsed CLI args into a ShapeFeatureConfig instance."""
        from spatial_features_method.shape_feature import ShapeFeatureConfig
        config = ShapeFeatureConfig()
        config.input_type = args.input_type
        # --input is routed to the dat or shp path field based on --input-type.
        config.dat_file_path = args.input if args.input_type == 'dat' else None
        config.hdr_file_path = getattr(args, 'hdr_file', None)
        config.shp_file_path = args.input if args.input_type == 'shp' else None
        config.band_index = args.shape_band_index
        config.connectivity = getattr(args, 'connectivity', 8)
        config.min_area = args.min_area
        config.use_watershed = args.use_watershed
        config.watershed_min_distance = getattr(args, 'watershed_min_distance', 10)
        config.output_dir = args.output_dir
        config.output_csv = getattr(args, 'output_csv', True)
        config.save_labeled_image = getattr(args, 'save_labeled_image', False)
        return config
||
|
||
|
||
class GLCMArgsBuilder(ArgsBuilder):
    """Builds CLI arguments and the SpatialFeatureConfig for GLCM texture features."""

    def add_module_specific_args(self, subparser):
        """Register GLCM-specific options on *subparser*."""
        # GLCM quantization parameters.
        subparser.add_argument('--nbit', type=int, default=64,
                              help='灰度级数 (默认: 64)')
        subparser.add_argument('--min-gray', type=int, default=0,
                              help='最小灰度值 (默认: 0)')
        subparser.add_argument('--max-gray', type=int, default=255,
                              help='最大灰度值 (默认: 255)')
        subparser.add_argument('--slide-window', type=int, default=7,
                              help='滑动窗口大小,必须为奇数 (默认: 7)')

        # GLCM computation parameters (multiple offsets/angles allowed).
        subparser.add_argument('--step', type=int, nargs='+', default=[2],
                              help='步长距离列表 (默认: [2])')
        subparser.add_argument('--angle', type=float, nargs='+', default=[0],
                              help='角度列表(弧度) (默认: [0])')

        # Data handling parameters.
        subparser.add_argument('--band-index', '-b', type=int, default=25,
                              help='要处理的波段索引 (默认: 25)')

        # Output settings.
        # NOTE(review): store_true with default=True cannot be disabled from the CLI.
        subparser.add_argument('--save-dat', action='store_true', default=True,
                              help='保存为DAT文件 (默认: True)')
        subparser.add_argument('--save-png', action='store_true',
                              help='保存为PNG文件')

    def build_config(self, args):
        """Translate parsed CLI args into a SpatialFeatureConfig instance."""
        from spatial_features_method.glcm import SpatialFeatureConfig

        config = SpatialFeatureConfig()
        config.nbit = args.nbit
        # mi/ma are the config's names for the min/max gray levels.
        config.mi = args.min_gray
        config.ma = args.max_gray
        config.slide_window = args.slide_window
        config.step = args.step
        config.angle = args.angle
        config.band_index = args.band_index
        config.image_path = args.input  # input file path from the common CLI args
        config.output_dir = args.output_dir
        config.save_dat = args.save_dat
        config.save_png = args.save_png
        return config
||
|
||
|
||
class ColorAnalysisArgsBuilder(ArgsBuilder):
    """Builds CLI arguments and an args-like object for color-difference analysis."""

    def add_module_specific_args(self, subparser):
        """Register color-analysis-specific options on *subparser*."""
        # Color-difference formula to apply.
        subparser.add_argument('--method', '-m', default='CIEDE2000',
                              choices=['CIE76', 'CIE94', 'CIEDE2000'],
                              help='色差计算方法')
        # Reference colors come from an external CSV file.
        subparser.add_argument('--standards-file', '-s', required=True,
                              help='标准颜色文件路径 (.csv)')

    def build_config(self, args):
        """Return a lightweight namespace mimicking the args object DeltaEApp expects."""

        class MockArgs:
            """Plain attribute bag standing in for an argparse namespace."""

            def __init__(self, **kwargs):
                self.__dict__.update(kwargs)

        fields = {
            'input': args.input,
            'standards_file': args.standards_file,
            'method': args.method,
            'output_dir': args.output_dir,
            'mode': 'image',  # image mode is the default for this entry point
        }
        return MockArgs(**fields)
||
|
||
|
||
class FilterArgsBuilder(ArgsBuilder):
    """Builds CLI arguments and a filter config for smoothing/morphological filters.

    ``build_config`` dispatches on ``--filter-type``: smoothing types produce a
    SmoothFilterConfig, every other type a MorphologicalFilterConfig.
    """

    def add_module_specific_args(self, subparser):
        """Register filter-specific options on *subparser*."""
        subparser.add_argument('--filter-type', default='mean',
                              choices=['mean', 'median', 'gaussian', 'bilateral', 'opening', 'closing', 'erosion', 'dilation', 'gradient', 'tophat', 'blackhat'],
                              help='滤波类型')
        subparser.add_argument('--filter-band-index', type=int, default=50,
                              help='波段索引')
        subparser.add_argument('--kernel-size', '-k', type=int, default=3,
                              help='核大小')

        # Options specific to the smoothing filters (Smooth_filter module).
        subparser.add_argument('--sigma', '-s', type=float, default=1.0,
                              help='高斯标准差 (仅用于gaussian滤波)')
        subparser.add_argument('--sigma-color', type=float, default=75.0,
                              help='颜色空间标准差 (仅用于bilateral滤波)')
        subparser.add_argument('--sigma-space', type=float, default=75.0,
                              help='坐标空间标准差 (仅用于bilateral滤波)')

        # Options specific to morphological filters (morphological_fliter module).
        subparser.add_argument('--se-shape', default='disk',
                              choices=['disk', 'square', 'rectangle', 'diamond'],
                              help='结构元素形状 (仅用于形态学操作)')
        subparser.add_argument('--se-size', type=int, default=3,
                              help='结构元素大小 (仅用于形态学操作)')
        subparser.add_argument('--format-type', default='auto',
                              choices=['auto', 'bil', 'bip', 'bsq', 'dat'],
                              help='图像格式 (仅用于形态学操作)')

    def build_config(self, args):
        """Return a SmoothFilterConfig or MorphologicalFilterConfig based on filter type."""
        # Smoothing filters use one config class; morphological ops use another.
        if args.filter_type in ['mean', 'median', 'gaussian', 'bilateral']:
            from fliter_method.Smooth_filter import SmoothFilterConfig
            return SmoothFilterConfig(
                input_path=args.input,
                filter_type=args.filter_type,
                band_index=args.filter_band_index,
                kernel_size=args.kernel_size,
                output_dir=args.output_dir,
                sigma=getattr(args, 'sigma', 1.0),
                sigma_color=getattr(args, 'sigma_color', 75.0),
                sigma_space=getattr(args, 'sigma_space', 75.0)
            )
        else:
            from fliter_method.morphological_fliter import MorphologicalFilterConfig
            return MorphologicalFilterConfig(
                input_path=args.input,
                operation=args.filter_type,  # morphological config calls it "operation"
                band_index=args.filter_band_index,
                kernel_size=args.kernel_size,
                output_dir=args.output_dir,
                se_shape=getattr(args, 'se_shape', 'disk'),
                se_size=getattr(args, 'se_size', 3),
                format_type=getattr(args, 'format_type', 'auto')
            )
||
|
||
|
||
class RegressionArgsBuilder(ArgsBuilder):
    """Builds CLI arguments and the RegressionConfig for regression model training."""

    def add_module_specific_args(self, subparser):
        """Register regression-specific options on *subparser*."""
        # Data configuration options.
        subparser.add_argument('--label-column', '-l', required=True,
                              help='标签列名')
        subparser.add_argument('--test-size', type=float, default=0.2,
                              help='测试集比例 (默认: 0.2)')
        subparser.add_argument('--random-state', type=int, default=42,
                              help='随机种子 (默认: 42)')
        subparser.add_argument('--scale-method', choices=['standard', 'minmax'], default='standard',
                              help='数据标准化方法 (默认: standard)')

        # Model configuration options.
        subparser.add_argument('--models', '-m', nargs='+',
                              help='回归模型列表 (默认: 全部)')
        subparser.add_argument('--tune-params', action='store_true',
                              help='启用超参数调优')
        subparser.add_argument('--tuning-method', choices=['grid', 'random'], default='grid',
                              help='调优方法 (默认: grid)')
        subparser.add_argument('--cv-folds', type=int, default=5,
                              help='交叉验证折数 (默认: 5)')
        subparser.add_argument('--random-search-iter', type=int, default=20,
                              help='随机搜索迭代次数 (默认: 20)')

        # Training configuration options.
        subparser.add_argument('--epochs', type=int, default=100,
                              help='训练轮数 (默认: 100)')
        subparser.add_argument('--batch-size', type=int, default=32,
                              help='批大小 (默认: 32)')
        subparser.add_argument('--learning-rate', type=float, default=0.001,
                              help='学习率 (默认: 0.001)')

        # Output configuration options.
        # NOTE(review): store_true with default=True — these two flags cannot be
        # switched off from the CLI.
        subparser.add_argument('--save-models', action='store_true', default=True,
                              help='保存训练好的模型')
        subparser.add_argument('--plot-results', action='store_true', default=True,
                              help='生成结果可视化图表')
        subparser.add_argument('--save-dir', default='models',
                              help='模型保存目录 (默认: models)')
        subparser.add_argument('--plot-dir', default='plots',
                              help='图表保存目录 (默认: plots)')

    def build_config(self, args):
        """Translate parsed CLI args into a RegressionConfig instance."""
        from rgression_method.regression import RegressionConfig

        # Start from the module's default configuration.
        config = RegressionConfig.create_default(
            csv_path=args.input,
            label_column=args.label_column
        )

        # Restrict the model list if the user supplied one.
        if getattr(args, 'models', None):
            config.models.model_names = args.models

        # Hyperparameter tuning branch.
        # NOTE(review): epochs/batch_size/learning_rate are only applied when
        # --tune-params is set; without it, the --epochs etc. flags are silently
        # ignored — confirm whether that is intended.
        if getattr(args, 'tune_params', False):
            # Tuning-enabled configuration.
            config.models.tune_hyperparams = True
            config.models.tuning_method = getattr(args, 'tuning_method', 'grid')
            config.models.cv_folds = getattr(args, 'cv_folds', 5)
            config.models.random_search_iter = getattr(args, 'random_search_iter', 20)

            config.training.epochs = getattr(args, 'epochs', 100)
            config.training.batch_size = getattr(args, 'batch_size', 32)
            config.training.learning_rate = getattr(args, 'learning_rate', 0.001)
        else:
            # Quick-analysis configuration.
            config.models.tune_hyperparams = False

        # Options applied in both branches.
        config.data.test_size = getattr(args, 'test_size', 0.2)
        config.data.random_state = getattr(args, 'random_state', 42)
        config.data.scale_method = getattr(args, 'scale_method', 'standard')

        config.output.save_models = getattr(args, 'save_models', True)
        config.output.plot_results = getattr(args, 'plot_results', True)
        config.output.save_dir = getattr(args, 'save_dir', 'models')
        config.output.plot_dir = getattr(args, 'plot_dir', 'plots')

        return config
||
|
||
|
||
class RegressionPredictionArgsBuilder(ArgsBuilder):
    """Builds CLI arguments and the PredictionConfig for regression inference."""

    def add_module_specific_args(self, subparser):
        """Register prediction-specific options on *subparser*."""
        # Required options.
        subparser.add_argument('--model-path', '-m', required=True,
                              help='模型文件路径(单个文件或目录)')

        # Optional options.
        subparser.add_argument('--mask-path',
                              help='遮罩文件路径')
        # NOTE(review): store_true with default=True means --use-mask can never
        # be disabled from the CLI; a --no-use-mask form seems to be missing.
        subparser.add_argument('--use-mask', action='store_true', default=True,
                              help='是否使用遮罩 (默认: True)')
        subparser.add_argument('--batch-mode', action='store_true',
                              help='批量处理模式')
        subparser.add_argument('--colormap', default='viridis',
                              choices=['viridis', 'plasma', 'inferno', 'magma', 'cividis',
                                      'Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds',
                                      'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
                                      'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn'],
                              help='颜色映射 (默认: viridis)')
        subparser.add_argument('--dpi', type=int, default=300,
                              help='输出图像DPI (默认: 300)')
        # NOTE(review): same store_true/default=True no-op pattern as --use-mask.
        subparser.add_argument('--save-individual', action='store_true', default=True,
                              help='保存单个模型预测结果')

    def build_config(self, args):
        """Translate parsed CLI args into a PredictionConfig instance."""
        from rgression_method.regression_predict import PredictionConfig

        return PredictionConfig(
            image_path=args.input,
            mask_path=getattr(args, 'mask_path', None),
            model_path=args.model_path,
            output_dir=args.output_dir,
            use_mask=getattr(args, 'use_mask', True),
            batch_mode=getattr(args, 'batch_mode', False),
            colormap=getattr(args, 'colormap', 'viridis'),
            dpi=getattr(args, 'dpi', 300),
            save_individual=getattr(args, 'save_individual', True)
        )
||
|
||
|
||
class DeltaEArgsBuilder(ArgsBuilder):
    """Argument builder for the Delta-E (color difference) CLI.

    This module defines its own complete argument set (including --input and
    --output-dir), so the shared common arguments are intentionally skipped.
    """

    def add_common_args(self, subparser):
        """DeltaE uses its own argument naming; common args are a deliberate no-op."""
        pass

    def add_module_specific_args(self, subparser):
        """Register the full DeltaE argument set on *subparser*."""
        # DeltaE-specific argument naming (mirrors the standalone tool's CLI).
        subparser.add_argument('--mode', required=True, choices=['image', 'pairwise'],
                              help='计算模式: image (图像vs标准色) 或 pairwise (两两比较)')
        subparser.add_argument('--input', required=True, help='输入文件路径 (ENVI格式LAB图像 或 CSV文件)')
        subparser.add_argument('--standards', required=True, help='标准颜色CSV文件路径')
        subparser.add_argument('--output-dir', default='./results', help='输出目录路径')
        subparser.add_argument('--output-file', help='输出文件名 (可选,如果不指定则使用默认名称)')
        subparser.add_argument('--create-heatmap', action='store_true', help='生成热图')
        subparser.add_argument('--create-histogram', action='store_true', help='生成直方图')
        subparser.add_argument('--method', default='CIEDE2000',
                              choices=['CIE76', 'CIE94', 'CIEDE2000'],
                              help='色差计算方法 (默认: CIEDE2000)')
        subparser.add_argument('--standards-range', help='标准色范围 (e.g.: "0,1,2" 或 "0-5" 或 "all")')
        subparser.add_argument('--reference-range', help='参考色范围 (pairwise模式)')
        subparser.add_argument('--target-range', help='目标色范围 (pairwise模式)')
        subparser.add_argument('--no-progress', action='store_true', help='禁用进度条显示')

    def build_config(self, args):
        """Return the argparse namespace unchanged; DeltaE consumes it directly."""
        return args
||
|
||
|
||
class SpectralToColorArgsBuilder(ArgsBuilder):
    """Argument builder for converting spectra to CIE color spaces."""

    def add_common_args(self, subparser):
        """Add the shared output options only.

        --input is declared by add_module_specific_args so the sub-command
        keeps the original spectral2cie2.py CLI (including its short flags).
        """
        add = subparser.add_argument
        add('--output-dir', default='./results',
            help='输出目录路径 (默认: ./results)')
        add('--output-prefix', default='result',
            help='输出文件前缀 (默认: result)')
        add('--output-file',
            help='直接指定输出文件名 (优先级高于 --output-dir 和 --output-prefix)')

    def add_module_specific_args(self, subparser):
        # Mirrors the CLI of the original spectral2cie2.py script.
        add = subparser.add_argument
        add('--input', required=True, help='输入文件路径 (.hdr高光谱图像 或 .csv文件)')
        add('-c', '--color_space', default='Lab',
            choices=['XYZ', 'xyY', 'Lab', 'LCH'],
            help='输出颜色空间 (默认: Lab)')
        add('-o', '--output', required=True,
            help='输出文件路径')
        add('-f', '--format', default='dat', choices=['csv', 'dat'],
            help='输出格式 (默认: dat)')
        add('-i', '--illuminant', default='D65',
            choices=['D65', 'D50', 'A', 'F2', 'F7', 'F11'],
            help='标准光源 (默认: D65)')
        add('-b', '--observer', default='2°',
            choices=['2°', '10°'],
            help='观察者 (默认: 2°)')
        add('-w', '--wavelength_col',
            help='CSV文件的波长列名 (默认: wavelength)')
        add('--no_normalize_xyz', action='store_true',
            help='不归一化XYZ值')
        add('--use_numpy', action='store_true',
            help='使用NumPy实现而不是colour库')

    def build_config(self, args):
        # The converter consumes the namespace as-is, but downstream code
        # expects an output_dir attribute, so derive one from --output.
        import os
        target = getattr(args, 'output', None)
        if target:
            args.output_dir = os.path.dirname(target) or './results'
        else:
            args.output_dir = './results'
        return args


class XYZ2RGBArgsBuilder(ArgsBuilder):
    """Argument builder for converting XYZ images to RGB."""

    def add_common_args(self, subparser):
        """Add the shared output options only.

        --input is declared by add_module_specific_args so the sub-command
        keeps the original XYZ2RGB.py CLI.
        """
        add = subparser.add_argument
        add('--output-dir', default='./results',
            help='输出目录路径 (默认: ./results)')
        add('--output-prefix', default='result',
            help='输出文件前缀 (默认: result)')
        add('--output-file',
            help='直接指定输出文件名 (优先级高于 --output-dir 和 --output-prefix)')

    def add_module_specific_args(self, subparser):
        # Mirrors the CLI of the original XYZ2RGB.py script.
        add = subparser.add_argument
        add('--input', required=True,
            help='输入XYZ图像文件路径')
        add('--output', required=True,
            help='输出RGB图像文件路径')
        add('--rgb-space', default='sRGB',
            choices=['sRGB', 'Adobe RGB (1998)', 'DCI-P3', 'ITU-R BT.709',
                     'ITU-R BT.2020', 'ACES2065-1', 'ACEScg', 'ProPhoto RGB',
                     'Apple RGB', 'PAL/SECAM', 'NTSC (1953)'],
            help='RGB颜色空间 (默认: sRGB)')
        add('--gamma', default='sRGB',
            choices=['none', 'sRGB', 'BT.709', 'gamma_2.2', 'gamma_1.8',
                     'gamma_2.4', 'L*', 'BT.1886', 'ST 2084', 'HLG', 'log'],
            help='Gamma校正方法 (默认: sRGB)')
        add('--output-dtype', default='uint8',
            choices=['uint8', 'uint10', 'uint12', 'uint16', 'int8',
                     'int16', 'float16', 'float32', 'float64'],
            help='输出数据类型 (默认: uint8)')

    def build_config(self, args):
        # The app consumes the namespace as-is, but downstream code expects
        # an output_dir attribute, so derive one from --output.
        import os
        target = getattr(args, 'output', None)
        if target:
            args.output_dir = os.path.dirname(target) or './results'
        else:
            args.output_dir = './results'
        return args


class SupervisedClassificationArgsBuilder(ArgsBuilder):
    """Argument builder for supervised hyperspectral classification."""

    def add_module_specific_args(self, subparser):
        add = subparser.add_argument
        add('--xml_file', required=True,
            help='ENVI ROI XML文件路径')
        add('--method', '-m', default='all',
            choices=['all', 'euclidean', 'cosine', 'correlation',
                     'information_divergence', 'jm_distance', 'sid_sa'],
            help='分类距离度量方法 (默认: all,使用所有方法)')
        add('--distance-params', '-p', type=str, default=None,
            help='距离度量方法的超参数,JSON格式字符串,例如: {"jm_distance": {"alpha": 0.7}}')
        add('--visualize', '-v', action='store_true',
            help='是否生成可视化结果')

    def build_config(self, args):
        """Build a plain-dict config for run_hsi_classification.

        A malformed --distance-params JSON string is reported on stdout and
        replaced by None so the defaults apply.
        """
        import json

        distance_params = None
        raw = getattr(args, 'distance_params', None)
        if raw:
            try:
                distance_params = json.loads(raw)
            except json.JSONDecodeError:
                print(f"警告: 无法解析距离参数 '{args.distance_params}',使用默认参数")
                distance_params = None

        # xml_file is a required option, so direct attribute access is safe.
        return {
            'input_file': args.input,
            'xml_file': args.xml_file,
            'output_dir': args.output_dir,
            'method': getattr(args, 'method', 'all'),
            'distance_params': distance_params,
            'visualize': getattr(args, 'visualize', False),
        }


class PROSAILGUIArgsBuilder(ArgsBuilder):
    """Argument builder for the interactive PROSAIL simulator GUI."""

    def add_common_args(self, subparser):
        """Intentionally empty: the GUI task is interactive and needs no
        input/output file arguments."""
        pass

    def add_module_specific_args(self, subparser):
        # The GUI itself takes no arguments; only a couple of switches.
        subparser.add_argument('--no-gui', action='store_true',
                               help='以命令行模式运行 (不启动GUI)')
        subparser.add_argument('--save-default', action='store_true',
                               help='保存默认参数到文件')

    def build_config(self, args):
        # No real config needed; just make sure output_dir exists on the
        # namespace so later generic code never hits AttributeError.
        args.output_dir = getattr(args, 'output_dir', './results')
        return args


# Task registry
# -------------
# Maps each CLI task name to:
#   "info"          -> ModuleInfo metadata (target module, callable, config class)
#   "args_builder"  -> the ArgsBuilder instance that defines its sub-parser
#   "call_sequence" -> ordered method names walked by execute_task()
#                      ("call_function" means: call the callable directly
#                      with the built config instead of instantiating it)
REGISTRY: Dict[str, Dict[str, Any]] = {
    # Dimensionality reduction
    "dim-reduction": {
        "info": ModuleInfo(
            category="降维分析",
            module_path="Dimensionality_Reduction_method.dimensionality_reduction",
            callable_name="HyperspectralDimReduction",
            config_class="DimensionalityReductionConfig",
            description="高光谱数据降维分析"
        ),
        "args_builder": DimensionalityReductionArgsBuilder(),
        "call_sequence": ["__init__", "load_data", "apply_dim_reduction", "save_results"]
    },

    # Image segmentation
    "segmentation": {
        "info": ModuleInfo(
            category="图像分割",
            module_path="segment_method.threshold_Segment",
            callable_name="ThresholdSegmenter",
            config_class="ThresholdSegmentationConfig",
            description="阈值图像分割"
        ),
        "args_builder": SegmentationArgsBuilder(),
        "call_sequence": ["__init__", "load_data", "apply_segmentation", "save_results"]
    },

    # Edge detection
    "edge-detection": {
        "info": ModuleInfo(
            category="边缘检测",
            module_path="edge_detect_method.edge_detect",
            callable_name="EdgeDetector",
            config_class="EdgeDetectionConfig",
            description="图像边缘检测"
        ),
        "args_builder": EdgeDetectionArgsBuilder(),
        "call_sequence": ["__init__", "load_data", "apply_edge_detection", "save_results"]
    },

    # Anomaly detection (execute_task picks the concrete detector class
    # from --method at run time; this entry holds the default)
    "anomaly-detection": {
        "info": ModuleInfo(
            category="异常检测",
            module_path="Anomaly_method.Covariance",
            callable_name="CovarianceAnomalyDetector",
            config_class="CovarianceAnomalyConfig",
            description="异常检测分析"
        ),
        "args_builder": AnomalyDetectionArgsBuilder(),
        "call_sequence": ["__init__", "load_data", "run_analysis_from_config", "save_results"]
    },

    # Classification
    "classification": {
        "info": ModuleInfo(
            category="分类分析",
            module_path="classfication_method.classfication",
            callable_name="ConfigurableHyperspectralClassifier",
            config_class="ClassificationConfig",
            description="高光谱图像分类",
            requires_roi_file=True
        ),
        "args_builder": ClassificationArgsBuilder(),
        "call_sequence": ["__init__", "run_classification"]
    },

    # Clustering (execute_task short-circuits this task and calls
    # run_clustering() directly instead of walking the sequence)
    "clustering": {
        "info": ModuleInfo(
            category="聚类分析",
            module_path="cluster_method.cluster",
            callable_name="ClusteringManager",
            config_class="ClusteringConfig",
            description="无监督聚类分析"
        ),
        "args_builder": ClusteringArgsBuilder(),
        "call_sequence": ["__init__", "fit_predict", "save_clustering_results"]
    },

    # Supervised classification
    "supervised-classification": {
        "info": ModuleInfo(
            category="分类分析",
            module_path="supervize_cluster_method.supervize_cluster",
            callable_name="run_hsi_classification",
            config_class=None,  # this task is invoked as a plain function
            description="高光谱图像监督分类",
            requires_roi_file=False,  # uses an XML file instead of an ROI file
            requires_data_file=True
        ),
        "args_builder": SupervisedClassificationArgsBuilder(),
        "call_sequence": ["call_function"]  # direct function call
    },

    # Feature selection
    "feature-selection": {
        "info": ModuleInfo(
            category="特征选择",
            module_path="Feature_Selection_method.feture_select",
            callable_name="SpectrumFeatureSelector",
            config_class="FeatureSelectionConfig",
            description="光谱特征选择",
            requires_data_file=False
        ),
        "args_builder": FeatureSelectionArgsBuilder(),
        "call_sequence": ["__init__", "load_csv_data", "select_features"]
    },

    # Spectral indices
    "spectral-index": {
        "info": ModuleInfo(
            category="光谱特征",
            module_path="spectral_feature_method.spectral_index",
            callable_name="HyperspectralIndexCalculator",
            config_class="SpectralIndexConfig",
            description="光谱指数计算"
        ),
        "args_builder": SpectralIndexArgsBuilder(),
        "call_sequence": ["__init__", "load_hyperspectral_data", "batch_calculate_and_visualize", "save_results"]
    },

    # Data preprocessing
    "preprocessing": {
        "info": ModuleInfo(
            category="预处理",
            module_path="preprocessing_method.Preprocessing",
            callable_name="HyperspectralPreprocessor",
            config_class="PreprocessingConfig",
            description="数据预处理"
        ),
        "args_builder": PreprocessingArgsBuilder(),
        "call_sequence": ["__init__", "load_data", "preprocess", "save_data"]
    },

    # Shape features
    "shape-features": {
        "info": ModuleInfo(
            category="空间特征",
            module_path="spatial_features_method.shape_feature",
            callable_name="analyze_shape_features",
            description="形状特征分析",
            requires_data_file=False
        ),
        "args_builder": ShapeFeatureArgsBuilder(),
        "call_sequence": ["call_function"]
    },
    # GLCM texture feature extraction
    "glcm": {
        "info": ModuleInfo(
            category="空间特征",
            module_path="spatial_features_method.glcm",
            callable_name="run_glcm_analysis",
            description="GLCM纹理特征提取",
            requires_data_file=True
        ),
        "args_builder": GLCMArgsBuilder(),
        "call_sequence": ["call_function"]
    },

    # Color analysis
    "color-analysis": {
        "info": ModuleInfo(
            category="颜色分析",
            module_path="color_method.DeltaE",
            callable_name="DeltaEApp",
            description="色差计算分析"
        ),
        "args_builder": ColorAnalysisArgsBuilder(),
        "call_sequence": ["run"]
    },

    # Image filtering
    "filtering": {
        "info": ModuleInfo(
            category="滤波",
            module_path="fliter_method.Smooth_filter",
            callable_name="HyperspectralImageFilter",
            config_class="SmoothFilterConfig",
            description="图像滤波处理"
        ),
        "args_builder": FilterArgsBuilder(),
        "call_sequence": ["process_image"]
    },

    # Regression analysis
    "regression": {
        "info": ModuleInfo(
            category="回归分析",
            module_path="rgression_method.regression",
            callable_name="RegressionAnalyzer",
            config_class="RegressionConfig",
            description="回归模型分析",
            requires_data_file=False
        ),
        "args_builder": RegressionArgsBuilder(),
        "call_sequence": ["__init__", "run_analysis_from_config"]
    },

    # Regression prediction
    "regression-prediction": {
        "info": ModuleInfo(
            category="回归预测",
            module_path="rgression_method.regression_predict",
            callable_name="RegressionPredictor",
            config_class="PredictionConfig",
            description="使用训练好的回归模型进行高光谱图像预测",
            requires_data_file=True
        ),
        "args_builder": RegressionPredictionArgsBuilder(),
        "call_sequence": ["__init__", "run_prediction"]
    },

    # Color difference (Delta E)
    "delta-e": {
        "info": ModuleInfo(
            category="颜色分析",
            module_path="color_method.DeltaE",
            callable_name="DeltaEApp",
            description="色差计算分析 (CIE76, CIE94, CIEDE2000)",
            requires_data_file=True
        ),
        "args_builder": DeltaEArgsBuilder(),
        "call_sequence": ["run"]
    },

    # Spectral-to-color conversion
    "spectral-to-color": {
        "info": ModuleInfo(
            category="颜色分析",
            module_path="color_method.spectral2cie2",
            callable_name="SpectralToColorConverter",
            description="光谱数据转换为CIE颜色空间 (XYZ, xyY, Lab, LCH)",
            requires_data_file=True
        ),
        "args_builder": SpectralToColorArgsBuilder(),
        "call_sequence": ["process_file"]
    },

    # XYZ-to-RGB conversion
    "xyz-to-rgb": {
        "info": ModuleInfo(
            category="颜色分析",
            module_path="color_method.XYZ2RGB",
            callable_name="XYZ2RGBApp",
            description="XYZ颜色空间转换为RGB",
            requires_data_file=True
        ),
        "args_builder": XYZ2RGBArgsBuilder(),
        "call_sequence": ["run"]
    },

    # PROSAIL simulator GUI
    "prosail-gui": {
        "info": ModuleInfo(
            category="植被模拟",
            module_path="prosail_method.prosail_gui",
            callable_name="PROSAILSimulator",
            description="PROSAIL植被光谱模拟器GUI",
            requires_data_file=False
        ),
        "args_builder": PROSAILGUIArgsBuilder(),
        "call_sequence": ["run_gui"]
    }
}


def get_module_info(task_name: str) -> Optional[ModuleInfo]:
    """Return the ModuleInfo registered for *task_name*, or None if unknown."""
    entry = REGISTRY.get(task_name)
    if not entry:
        return None
    return entry["info"]


def get_args_builder(task_name: str) -> Optional[ArgsBuilder]:
    """Return the ArgsBuilder registered for *task_name*, or None if unknown."""
    entry = REGISTRY.get(task_name)
    if not entry:
        return None
    return entry["args_builder"]


def get_call_sequence(task_name: str) -> Optional[List[str]]:
    """Return the registered call sequence for *task_name*.

    Returns None for unknown tasks. Previously this called .get() on the
    None returned by REGISTRY.get and raised AttributeError, unlike the
    sibling lookup helpers which guard against a missing task.
    """
    task_info = REGISTRY.get(task_name)
    return task_info.get("call_sequence") if task_info else None


def create_module_parser(subparsers, task_name: str):
    """Create and return the argparse sub-parser for *task_name*.

    Returns None when the task is not registered. Common arguments are
    added before the module-specific ones, matching the builder contract.
    """
    entry = REGISTRY.get(task_name)
    if not entry:
        return None

    meta = entry["info"]
    builder = entry["args_builder"]

    # Register the sub-command under the task name.
    sub = subparsers.add_parser(
        task_name,
        help=f"{meta.category}: {meta.description}",
        description=meta.description,
    )

    builder.add_common_args(sub)
    builder.add_module_specific_args(sub)

    return sub


def dynamic_import_module(module_path: str):
    """Import and return the module named *module_path*.

    Raises:
        ImportError: re-raised with a descriptive message; the original
            error is chained as ``__cause__`` for easier debugging.
    """
    try:
        return importlib.import_module(module_path)
    except ImportError as e:
        raise ImportError(f"无法导入模块 {module_path}: {e}") from e


def get_callable_from_module(module, callable_name: str):
    """Return the attribute *callable_name* from *module*.

    Raises:
        AttributeError: re-raised with a descriptive message; the original
            error is chained as ``__cause__`` for easier debugging.
    """
    try:
        return getattr(module, callable_name)
    except AttributeError as e:
        raise AttributeError(f"模块中没有找到可调用对象 {callable_name}: {e}") from e


def execute_task(task_name: str, args):
    """Dispatch and run the task registered under *task_name*.

    Builds the task configuration via its ArgsBuilder, resolves the target
    module/callable, and either runs a task-specific branch or walks the
    generic ``call_sequence`` declared in REGISTRY.

    Args:
        task_name: Key into REGISTRY (e.g. "dim-reduction").
        args: Parsed argparse namespace for the task's sub-command.

    Returns:
        The task result (type varies by task), or None.

    Raises:
        ValueError: If *task_name* is not registered, or an unsupported
            anomaly-detection method is requested.
    """
    # os is needed both for the default output filename and by several task
    # branches below (filtering, morphology). It used to be imported only
    # inside the "no --output-file" branch, which made `os` an unbound local
    # in those branches whenever --output-file was given (UnboundLocalError).
    import os

    task_info = REGISTRY.get(task_name)
    if not task_info:
        raise ValueError(f"未知任务: {task_name}")

    info = task_info["info"]
    args_builder = task_info["args_builder"]
    call_sequence = task_info["call_sequence"]

    # Build the task configuration from the CLI namespace.
    config = args_builder.build_config(args)

    # Resolve the output filename.
    if hasattr(args, 'output_file') and args.output_file:
        output_filename = args.output_file
    else:
        # Derive it from the output directory and prefix.
        output_prefix = getattr(args, 'output_prefix', 'result')
        output_filename = os.path.join(args.output_dir, f"{output_prefix}.dat")

    # Resolve module and callable; several tasks take fully custom paths.
    if task_name == "anomaly-detection":
        # Anomaly detection picks its detector class from the method name.
        if args.method == 'covariance':
            module = dynamic_import_module("Anomaly_method.Covariance")
            callable_obj = get_callable_from_module(module, "CovarianceAnomalyDetector")
        elif args.method == 'one-class-svm':
            module = dynamic_import_module("Anomaly_method.One_Class_SVM")
            callable_obj = get_callable_from_module(module, "OneClassSVMAnomalyDetector")
        elif args.method == 'rx':
            module = dynamic_import_module("Anomaly_method.RX")
            callable_obj = get_callable_from_module(module, "RXAnomalyDetector")
        elif args.method == 'squared-loss-probability':
            module = dynamic_import_module("Anomaly_method.squared_loss_probability")
            callable_obj = get_callable_from_module(module, "SquaredLossAnomalyDetector")
        else:
            raise ValueError(f"不支持的异常检测方法: {args.method}")
    elif task_name == "clustering":
        # Clustering bypasses the call sequence and calls run_clustering().
        from cluster_method.cluster import run_clustering
        method_params = getattr(config, 'method_params', {}) if hasattr(config, 'method_params') else {}
        result = run_clustering(
            input_file=config.input_path,
            n_clusters=config.n_clusters,
            methods=[config.method],
            output_dir=config.output_dir,
            method_params=method_params,
            random_state=42
        )
        return result
    elif task_name == "delta-e":
        # Color-difference task: hand the raw namespace to the app.
        module = dynamic_import_module("color_method.DeltaE")
        callable_obj = get_callable_from_module(module, "DeltaEApp")
        app_instance = callable_obj()
        result = app_instance.run(config)  # config is the object from build_config
        return result
    elif task_name == "spectral-to-color":
        # Spectral-to-color conversion: invoke the script's main() by
        # temporarily rewriting sys.argv.
        import sys
        from color_method.spectral2cie2 import main as spectral_main

        # Save the original sys.argv so it can be restored afterwards.
        original_argv = sys.argv

        # NOTE(review): the --format option is not forwarded here, so main()
        # always sees its own default output format — confirm intended.
        sys.argv = ['spectral2cie2.py', '--input', config.input, '-o', config.output,
                    '-c', getattr(config, 'color_space', 'Lab'),
                    '-i', getattr(config, 'illuminant', 'D65'),
                    '-b', getattr(config, 'observer', '2°')]

        if hasattr(config, 'wavelength_col') and config.wavelength_col:
            sys.argv.extend(['-w', config.wavelength_col])

        if getattr(config, 'no_normalize_xyz', False):
            sys.argv.append('--no_normalize_xyz')

        if getattr(config, 'use_numpy', False):
            sys.argv.append('--use_numpy')

        try:
            result = spectral_main()
        finally:
            # Always restore the original sys.argv.
            sys.argv = original_argv

        return result
    elif task_name == "xyz-to-rgb":
        # XYZ-to-RGB conversion task.
        module = dynamic_import_module("color_method.XYZ2RGB")
        callable_obj = get_callable_from_module(module, "XYZ2RGBApp")
        app_instance = callable_obj()
        result = app_instance.run(config)  # config is the object from build_config
        return result
    elif task_name == "preprocessing":
        # Preprocessing needs a custom load_data call signature.
        from preprocessing_method.Preprocessing import HyperspectralPreprocessor
        instance = HyperspectralPreprocessor(config)

        # Output path depends on the input format and the method name.
        from pathlib import Path
        input_path = Path(config.input_path)
        method_name = config.method.lower()  # lower-cased for the filename
        if input_path.suffix.lower() == '.csv':
            # CSV in -> CSV out
            output_path = os.path.join(config.output_dir, f"{method_name}_{input_path.stem}.csv")
        else:
            # ENVI (or other) in -> ENVI out
            output_path = os.path.join(config.output_dir, f"{method_name}_{input_path.stem}.dat")

        # load_data takes the input path and the spectral start index.
        instance.load_data(config.input_path, config.spectral_start_index)
        result = instance.preprocess(config.method, output_path)
        return result

    elif task_name == "filtering":
        # Filtering: smooth-filter configs carry a filter_type attribute.
        if hasattr(config, 'filter_type'):
            from fliter_method.Smooth_filter import HyperspectralImageFilter
            filter_obj = HyperspectralImageFilter()

            # Per-filter keyword arguments.
            kwargs = {
                'kernel_size': config.kernel_size,
            }

            if config.filter_type == 'gaussian':
                kwargs['sigma'] = config.sigma
            elif config.filter_type == 'bilateral':
                kwargs['sigma_color'] = config.sigma_color
                kwargs['sigma_space'] = config.sigma_space

            # Derive the output path from the filter type.
            output_path = os.path.join(config.output_dir, f"{config.filter_type}_filtered")

            result = filter_obj.process_image(
                input_path=config.input_path,
                output_path=output_path,
                filter_type=config.filter_type,
                band_index=config.band_index,
                **kwargs
            )
            return result
    elif task_name == "prosail-gui":
        # PROSAIL simulator: start the Qt GUI (CLI mode is unimplemented).
        from prosail_method.prosail_gui import PROSAILSimulator, QApplication
        import sys

        if getattr(config, 'no_gui', False):
            print("PROSAIL GUI 命令行模式暂未实现,请使用 GUI 模式:")
            print("python main.py prosail-gui")
            return None
        else:
            app = QApplication(sys.argv)
            app.setStyle('Fusion')  # Modern style
            window = PROSAILSimulator()
            window.show()
            result = app.exec_()
            return result

    elif task_name == "morphological_fliter":
        # NOTE(review): "morphological_fliter" has no REGISTRY entry in this
        # file, so the unknown-task ValueError above fires first and this
        # branch looks unreachable — confirm before removing.
        from fliter_method.morphological_fliter import HyperspectralMorphologyProcessor
        processor = HyperspectralMorphologyProcessor()

        # Load the image, extract the band, and apply the operation.
        data, header = processor.load_hyperspectral_image(config.input_path, config.format_type)
        band_data = processor.extract_band(config.band_index)
        result = processor.apply_morphology_operation(
            band_data, config.operation, config.se_shape, config.se_size
        )

        # Persist the result as ENVI.
        output_path = os.path.join(config.output_dir, f"{config.operation}_result")
        processor.save_as_envi(result, output_path, f"{config.operation.capitalize()} 处理结果 - 波段 {config.band_index}")

        return result
    else:
        module = dynamic_import_module(info.module_path)
        callable_obj = get_callable_from_module(module, info.callable_name)

        # Batch mode is special-cased for dimensionality reduction.
        is_batch = (task_name == "dim-reduction" and
                    hasattr(args, 'batch_methods') and args.batch_methods is not None)

        if is_batch:
            instance = callable_obj(config)
            instance.load_data()
            result = instance.batch_process(args.batch_methods, args.batch_components, args.output_dir)
            return result

    # Walk the generic call sequence declared for this task.
    result = None
    instance = None

    for step in call_sequence:
        if step == "__init__":
            instance = callable_obj(config)
        elif step == "call_function":
            # Plain-function task: call it with the config object.
            result = callable_obj(config)
        elif hasattr(instance, step):
            method = getattr(instance, step)
            if step in ["load_data", "read_hyperspectral_data", "load_hyperspectral_data", "load_csv_data"]:
                method()
            elif step in ["apply_dim_reduction", "apply_segmentation", "apply_edge_detection", "detect_anomalies", "select_features", "batch_calculate_and_visualize", "preprocess"]:
                result = method()
            elif step == "run":
                # App-style classes take the config in run().
                result = method(config)
            elif step == "save_results":
                # save_results signatures differ per task.
                if task_name == "dim-reduction":
                    # result = (reduced_data, band_names)
                    method(result[0], result[1], output_filename, args.method)
                elif task_name == "segmentation":
                    # result = (segmented_data, threshold)
                    method(result[0], result[1], output_filename, args.method)
                elif task_name == "edge-detection":
                    # result = (edge_data, info)
                    method(result[0], result[1], output_filename, args.method)
                elif task_name == "spectral-index":
                    # result = (results, fig); save_results expects (results, output_prefix)
                    method(result[0], args.output_prefix)
                elif task_name == "anomaly-detection":
                    # save_results only takes an optional output_path argument
                    method(output_filename)
                else:
                    # Default calling convention.
                    method(result, output_filename, task_name)
            elif step == "run_analysis_from_config":
                result = method()
            elif step == "calculate_differences":
                result = method()
            elif step == "apply_filter":
                result = method(args.filter_type if hasattr(args, 'filter_type') else 'mean')
            elif step == "save_envi":
                method(output_filename, result)
            else:
                # Remaining known steps with their calling conventions.
                if step == "parse_roi_xml" and hasattr(args, 'roi_file'):
                    method(args.roi_file)
                elif step == "extract_training_samples":
                    result = method()
                elif step == "train_model":
                    method(result[0], result[1])
                elif step == "predict_image":
                    method()
                elif step == "run_classification":
                    method()
                else:
                    method()

    return result