增加模块;增加主调用命令
This commit is contained in:
728
Feature_Selection_method/batch_feature_selection.py
Normal file
728
Feature_Selection_method/batch_feature_selection.py
Normal file
@ -0,0 +1,728 @@
|
||||
"""
|
||||
批量特征选择工具
|
||||
支持对多个CSV文件或数据集进行批量特征选择
|
||||
"""
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Tuple, Union
|
||||
import argparse
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
import warnings
|
||||
|
||||
# 导入特征选择模块
|
||||
from feture_select import (
|
||||
FeatureSelectionConfig,
|
||||
select_features_from_csv,
|
||||
select_features_from_data
|
||||
)
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def _normalize_index(idx: int, total_columns: int) -> int:
    """Resolve a possibly-negative column index and validate it.

    Negative indices count from the end (like Python sequences).

    Raises:
        ValueError: If the resolved index is outside [0, total_columns-1].
    """
    if idx < 0:
        idx = total_columns + idx
    if idx >= total_columns or idx < 0:
        raise ValueError(f"Column index {idx} out of range [0, {total_columns-1}]")
    return idx


def _parse_part(part: Union[str, int], total_columns: int) -> List[int]:
    """Parse one fragment of a column spec into a list of indices.

    Accepts an int, a half-open slice string "a:b" (either side may be
    empty, like a Python slice), an inclusive dash range "a-b", or a
    single index string (negative allowed).
    """
    if isinstance(part, (int, np.integer)):
        return [_normalize_index(int(part), total_columns)]

    part = part.strip()
    if ':' in part:
        # Half-open slice "a:b" — end is exclusive, empty sides default
        # to 0 / total_columns; negative bounds count from the end.
        start_s, end_s = part.split(':')
        start = int(start_s.strip()) if start_s.strip() else 0
        end = int(end_s.strip()) if end_s.strip() else total_columns
        if start < 0:
            start = total_columns + start
        if end < 0:
            end = total_columns + end
        if start >= total_columns or end > total_columns:
            raise ValueError(f"Range {start}:{end} out of column range [0, {total_columns-1}]")
        return list(range(start, end))

    if not part.startswith('-') and '-' in part:
        # Inclusive dash range "6-8" -> [6, 7, 8], as advertised by the
        # CLI help ("2,4,6-8"). A leading '-' is a negative single index,
        # not a range, so it is excluded here.
        lo_s, hi_s = part.split('-', 1)
        lo = _normalize_index(int(lo_s), total_columns)
        hi = _normalize_index(int(hi_s), total_columns)
        return list(range(lo, hi + 1))

    # Single index string.
    return [_normalize_index(int(part), total_columns)]


def parse_column_range(column_range: Union[str, int, List[Union[str, int]]], total_columns: int) -> List[int]:
    """Parse a column specification into a sorted list of column indices.

    Supported forms:
        - single int index (negative counts from the end)
        - string: "0:5" (half-open, like a slice), "6-8" (inclusive),
          or a comma-separated mix such as "2,4,6-8"
        - list/tuple mixing ints and the string forms above, e.g. [0, "2:4"]

    Args:
        column_range: Column specification in one of the forms above.
        total_columns: Total number of columns available.

    Returns:
        Sorted, de-duplicated list of valid column indices.

    Raises:
        ValueError: If any index/range is out of bounds or the
            specification type is unsupported.
    """
    if isinstance(column_range, (int, np.integer)):
        parts: List[Union[str, int]] = [int(column_range)]
    elif isinstance(column_range, str):
        parts = list(column_range.split(','))
    elif isinstance(column_range, (list, tuple)):
        parts = list(column_range)
    else:
        raise ValueError(f"Unsupported column range format: {type(column_range)}")

    columns: List[int] = []
    for part in parts:
        columns.extend(_parse_part(part, total_columns))

    # De-duplicate and sort for a deterministic result (the original
    # list(set(...)) had arbitrary order).
    return sorted(set(columns))
|
||||
|
||||
|
||||
def convert_column_indices_to_names(df: pd.DataFrame, column_indices: List[int]) -> List[str]:
    """Map positional column indices to the corresponding column names.

    Args:
        df: DataFrame whose columns are being addressed.
        column_indices: Positional indices into ``df.columns``.

    Returns:
        Column names in the same order as the given indices.
    """
    names = df.columns
    return [names[position] for position in column_indices]
|
||||
|
||||
|
||||
def resolve_spectral_columns(df: pd.DataFrame, spectral_columns: Union[str, List[Union[str, int]], None]) -> List[str]:
    """Resolve a spectral-column specification to a list of column names.

    Args:
        df: DataFrame to resolve columns against.
        spectral_columns: One of:
            - None: use every column of *df*;
            - "auto": keep numeric-dtype columns where >80% of values
              parse as numbers;
            - anything else: an index/range spec handed to
              parse_column_range (falls back to "auto" on parse errors).

    Returns:
        List of spectral column names.
    """
    if spectral_columns is None:
        # No spec: every column counts as spectral.
        return list(df.columns)

    if isinstance(spectral_columns, str) and spectral_columns == "auto":
        # Auto-detect: numeric columns where most entries are valid numbers
        # (treated as candidate spectral bands).
        detected = []
        threshold = len(df) * 0.8
        for name in df.columns:
            if not pd.api.types.is_numeric_dtype(df[name]):
                continue
            try:
                numeric = pd.to_numeric(df[name], errors='coerce')
                if numeric.notna().sum() > threshold:
                    detected.append(name)
            except Exception:
                continue
        return detected

    # Otherwise interpret the spec as indices/ranges; on failure, report
    # and fall back to auto-detection.
    try:
        indices = parse_column_range(spectral_columns, len(df.columns))
        return convert_column_indices_to_names(df, indices)
    except ValueError as e:
        print(f"解析光谱列时出错: {e}")
        print(f"将使用自动检测模式")
        return resolve_spectral_columns(df, "auto")
|
||||
|
||||
|
||||
def find_csv_files(directory: Union[str, Path], pattern: str = "*.csv") -> List[Path]:
    """Locate files under *directory* matching *pattern*.

    Args:
        directory: Directory to search (non-recursive glob).
        pattern: Glob pattern, defaults to "*.csv".

    Returns:
        Matching file paths, sorted for run-to-run consistency.

    Raises:
        FileNotFoundError: If *directory* does not exist.
    """
    search_root = Path(directory)
    if not search_root.exists():
        raise FileNotFoundError(f"目录不存在: {search_root}")

    # sorted() guarantees a deterministic processing order.
    matches = sorted(search_root.glob(pattern))

    print(f"在目录 {search_root} 中找到 {len(matches)} 个CSV文件")
    return matches
|
||||
|
||||
|
||||
def create_batch_configs(csv_files: List[Path],
                         base_config: FeatureSelectionConfig,
                         output_base_dir: Union[str, Path]) -> List[Tuple[Path, FeatureSelectionConfig]]:
    """
    Build a per-file FeatureSelectionConfig for each CSV file.

    For each file, the label column and spectral columns from *base_config*
    are resolved against that file's actual header (reading only a few rows),
    and an independent output sub-directory named after the file is created.
    Files that cannot be read, or that end up with no spectral columns, are
    skipped with a printed warning.

    Args:
        csv_files: CSV files to configure.
        base_config: Template configuration; its per-file fields
            (csv_file_path, label_column, spectral_columns, output paths)
            are overridden for each file.
        output_base_dir: Base directory; each file gets its own
            sub-directory ``output_base_dir/<file_stem>``.

    Returns:
        List of (csv_file, config) tuples, one per usable file.
    """
    configs = []
    output_base_dir = Path(output_base_dir)

    for csv_file in csv_files:
        try:
            # Read only the first 5 rows — enough to inspect the header.
            df = pd.read_csv(csv_file, nrows=5)

            # Resolve the label column (by name or by index); fall back to
            # the first column with a warning if it is missing.
            if isinstance(base_config.label_column, str):
                if base_config.label_column not in df.columns:
                    print(f"警告: 文件 {csv_file.name} 中不存在标签列 '{base_config.label_column}',将尝试使用第一列")
                    resolved_label_column = df.columns[0]
                else:
                    resolved_label_column = base_config.label_column
            else:
                # Positional index case.
                try:
                    resolved_label_column = df.columns[base_config.label_column]
                except IndexError:
                    print(f"警告: 文件 {csv_file.name} 中的列索引 {base_config.label_column} 超出范围,将使用第一列")
                    resolved_label_column = df.columns[0]

            # Resolve the spectral column spec against this file's header.
            resolved_spectral_columns = resolve_spectral_columns(df, base_config.spectral_columns)

            # The label column must never double as a spectral column.
            if resolved_label_column in resolved_spectral_columns:
                resolved_spectral_columns.remove(resolved_label_column)

            if len(resolved_spectral_columns) == 0:
                print(f"警告: 文件 {csv_file.name} 中没有找到有效的光谱列")
                continue

            print(f"文件 {csv_file.name}: 标签列='{resolved_label_column}', 光谱列数={len(resolved_spectral_columns)}")

        except Exception as e:
            # Unreadable/corrupt file: warn and move on to the next one.
            print(f"读取文件 {csv_file.name} 时出错: {e},跳过此文件")
            continue

        # Dedicated output directory per input file.
        file_stem = csv_file.stem
        file_output_dir = output_base_dir / file_stem
        file_output_dir.mkdir(parents=True, exist_ok=True)

        # Clone the base config with the file-specific fields filled in.
        # method_params is copied so per-file runs cannot mutate the shared dict.
        config = FeatureSelectionConfig(
            method=base_config.method,
            method_params=base_config.method_params.copy(),
            csv_file_path=str(csv_file),
            label_column=resolved_label_column,
            spectral_columns=resolved_spectral_columns,
            output_csv=base_config.output_csv,
            output_dir=str(file_output_dir),
            output_filename=f"{file_stem}_selected_features",
            save_plots=base_config.save_plots,
            plot_name_prefix=f"{file_stem}_{base_config.method}",
            plot_dir=str(file_output_dir) if base_config.plot_dir else None
        )

        configs.append((csv_file, config))

    return configs
|
||||
|
||||
|
||||
def process_single_file(csv_file: Path, config: FeatureSelectionConfig) -> Dict:
    """
    Run feature selection on a single CSV file and summarize the outcome.

    Never raises: any exception from the selection run is caught and
    recorded in the result dict, so batch processing can continue.

    Args:
        csv_file: CSV file to process.
        config: Fully-resolved per-file configuration.

    Returns:
        Dict with keys: 'file', 'file_name', 'success', 'error',
        'n_selected_features', 'selected_columns', 'processing_time',
        'output_dir' (and 'n_samples' on success).
    """
    # Pessimistic default; fields are overwritten on success.
    result = {
        'file': str(csv_file),
        'file_name': csv_file.name,
        'success': False,
        'error': None,
        'n_selected_features': 0,
        'selected_columns': [],
        'processing_time': 0,
        'output_dir': config.output_dir
    }

    start_time = time.time()

    try:
        print(f"开始处理文件: {csv_file.name}")

        # Delegate the actual selection to the feture_select module.
        X_selected, y, selected_columns = select_features_from_csv(config)

        result['success'] = True
        result['n_selected_features'] = X_selected.shape[1]
        # selected_columns may be a numpy array (has .tolist) or a plain sequence.
        result['selected_columns'] = selected_columns.tolist() if hasattr(selected_columns, 'tolist') else list(selected_columns)
        result['n_samples'] = X_selected.shape[0]

        print(f"文件 {csv_file.name} 处理完成,选择特征数: {result['n_selected_features']}")

    except Exception as e:
        # Record the failure instead of propagating it.
        result['error'] = str(e)
        print(f"文件 {csv_file.name} 处理失败: {e}")

    finally:
        # Wall-clock duration, recorded for both success and failure.
        result['processing_time'] = time.time() - start_time

    return result
|
||||
|
||||
|
||||
def batch_feature_selection(csv_files: List[Path],
                            base_config: FeatureSelectionConfig,
                            output_base_dir: Union[str, Path],
                            max_workers: Optional[int] = None,
                            parallel: bool = False) -> List[Dict]:
    """Run feature selection over many CSV files, serially or in parallel.

    Args:
        csv_files: CSV files to process.
        base_config: Template configuration (resolved per file).
        output_base_dir: Base directory for per-file output folders.
        max_workers: Worker-process cap for the parallel path.
        parallel: If True (and there is more than one file), use a
            process pool; results then arrive in completion order.

    Returns:
        One result dict per processed file (see process_single_file).
    """
    file_configs = create_batch_configs(csv_files, base_config, output_base_dir)
    results: List[Dict] = []

    if parallel and len(file_configs) > 1:
        print(f"开始并行处理 {len(file_configs)} 个文件 (最大并行数: {max_workers or 'auto'})")

        with ProcessPoolExecutor(max_workers=max_workers) as pool:
            # Fan out one task per file, remembering which file each
            # future belongs to so failures can be attributed.
            pending = {
                pool.submit(process_single_file, path, cfg): (path, cfg)
                for path, cfg in file_configs
            }

            for done in as_completed(pending):
                path, _cfg = pending[done]
                try:
                    results.append(done.result())
                except Exception as e:
                    # Worker-level failure (e.g. pickling/crash): synthesize
                    # a failure record so the batch summary stays complete.
                    print(f"并行处理失败 {path.name}: {e}")
                    results.append({
                        'file': str(path),
                        'file_name': path.name,
                        'success': False,
                        'error': str(e),
                        'processing_time': 0
                    })
    else:
        print(f"开始串行处理 {len(file_configs)} 个文件")
        for path, cfg in file_configs:
            results.append(process_single_file(path, cfg))

    return results
|
||||
|
||||
|
||||
def save_batch_results(results: List[Dict], output_file: Union[str, Path]):
    """Persist batch results as a UTF-8 CSV file.

    Parent directories are created as needed.

    Args:
        results: Result dicts from the batch run (one row each).
        output_file: Destination CSV path.
    """
    destination = Path(output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)

    # One row per processed file.
    pd.DataFrame(results).to_csv(destination, index=False, encoding='utf-8')

    print(f"批量处理结果已保存到: {destination}")
|
||||
|
||||
|
||||
def print_batch_summary(results: List[Dict]):
    """
    Print a human-readable summary of a batch feature-selection run.

    Args:
        results: Result dicts from process_single_file / the parallel
            fallback path. Expected keys: 'success', 'processing_time',
            'n_selected_features', 'file_name', 'error'; missing keys
            are tolerated via .get().
    """
    total_files = len(results)
    successful_files = sum(1 for r in results if r.get('success'))
    failed_files = total_files - successful_files

    total_time = sum(r.get('processing_time', 0) for r in results)
    avg_time = total_time / total_files if total_files > 0 else 0

    print("\n" + "="*60)
    print("批量特征选择处理摘要")
    print("="*60)
    print(f"总文件数: {total_files}")
    print(f"成功处理: {successful_files}")
    print(f"失败处理: {failed_files}")
    # The original code printed the bare literal ".2f" twice — mangled
    # f-strings. Restore the intended timing summary lines.
    print(f"总处理时间: {total_time:.2f}秒")
    print(f"平均处理时间: {avg_time:.2f}秒")

    if successful_files > 0:
        selected_features = [r['n_selected_features'] for r in results if r.get('success')]
        print(f"平均选择的特征数: {np.mean(selected_features):.1f} ± {np.std(selected_features):.1f}")

    if failed_files > 0:
        print(f"\n失败的文件:")
        for result in results:
            if not result.get('success'):
                print(f"  - {result['file_name']}: {result['error']}")

    print("="*60)
|
||||
|
||||
|
||||
def create_example_batch_config() -> FeatureSelectionConfig:
    """
    Create an example base configuration for batch runs.

    The per-file fields (csv_file_path, label_column, spectral_columns,
    output paths) are deliberately left unset; create_batch_configs fills
    them in for each file.

    Returns:
        A FeatureSelectionConfig preset for the CARS method.
    """
    return FeatureSelectionConfig(
        method="CARS",  # available methods: Cars, Lars, Uve, Spa, GA, ReliefF, RandomFrog, SiPLS
        method_params={
            # CARS parameters — presumably N = sampling runs, f = max latent
            # variables, cv = CV folds; TODO confirm against feture_select.
            'N': 50,
            'f': 20,
            'cv': 10
        },
        # NOTE: csv_file_path, label_column, spectral_columns are set per file.
        output_csv=True,
        save_plots=True,
        plot_name_prefix="batch_fs"
    )
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for batch feature selection.

    Parses CLI arguments, builds a base configuration, runs the batch over
    every CSV file in the input directory, saves and summarizes the results.

    Returns:
        Process exit code: 0 if at least one file succeeded, else 1.
    """
    parser = argparse.ArgumentParser(description='批量特征选择工具')

    # Required positional arguments.
    parser.add_argument('input_dir', help='包含CSV文件的输入目录')
    parser.add_argument('output_dir', help='输出目录')

    # Optional arguments.
    # NOTE(review): the default 'CARS' is not among `choices` (which lists
    # 'Cars'); argparse does not validate defaults, so running without
    # --method passes 'CARS' through — confirm feture_select accepts it.
    parser.add_argument('--method', default='CARS',
                        choices=['Cars', 'Lars', 'Uve', 'Spa', 'GA', 'ReliefF', 'RandomFrog', 'SiPLS'],
                        help='特征选择方法 (默认: CARS)')
    parser.add_argument('--label_column', required=True,
                        help='标签列名或列索引 (例如: "concentration" 或 0)')
    parser.add_argument('--spectral_columns', required=True,
                        help='光谱列配置,支持: 列名列表 "col1 col2 col3", 列号范围 "1:10", 混合 "2,4,6-8", 或 "auto" 自动检测')
    parser.add_argument('--parallel', action='store_true', help='启用并行处理')
    parser.add_argument('--max_workers', type=int, help='最大并行工作数')
    parser.add_argument('--no_csv_output', action='store_true', help='不输出CSV文件')
    parser.add_argument('--no_plots', action='store_true', help='不生成可视化图')
    parser.add_argument('--results_file', default='batch_results.csv', help='结果文件路径')

    args = parser.parse_args()

    try:
        # Decide how to hand --spectral_columns downstream.
        if args.spectral_columns == "auto":
            spectral_columns = "auto"
        elif ':' in str(args.spectral_columns) or ',' in str(args.spectral_columns):
            # Range syntax: keep as a string for parse_column_range.
            spectral_columns = args.spectral_columns
        else:
            # Otherwise assume a space-separated list of column names.
            spectral_columns = args.spectral_columns.split()

        # Label column: positional index if it parses as an int, else a name.
        try:
            label_column = int(args.label_column)
        except ValueError:
            label_column = args.label_column

        # Base configuration shared by every file; per-file fields are
        # filled in later by create_batch_configs.
        base_config = FeatureSelectionConfig(
            method=args.method,
            method_params={},  # use the method's default parameters
            label_column=label_column,
            spectral_columns=spectral_columns,
            output_csv=not args.no_csv_output,
            save_plots=not args.no_plots,
            plot_name_prefix=f"batch_{args.method}"
        )

        # Collect the input files.
        csv_files = find_csv_files(args.input_dir)
        if not csv_files:
            print("未找到CSV文件")
            return 1

        # Run the batch.
        results = batch_feature_selection(
            csv_files=csv_files,
            base_config=base_config,
            output_base_dir=args.output_dir,
            max_workers=args.max_workers,
            parallel=args.parallel
        )

        # Persist per-file results next to the outputs.
        results_file = Path(args.output_dir) / args.results_file
        save_batch_results(results, results_file)

        # Human-readable summary.
        print_batch_summary(results)

        successful = sum(1 for r in results if r['success'])
        return 0 if successful > 0 else 1

    except Exception as e:
        print(f"批量处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
|
||||
|
||||
|
||||
|
||||
def example_usage():
    """
    Print a usage guide for the batch feature-selection tool
    (column-selection syntax, CLI examples, API examples, supported
    methods, and method-parameter examples).
    """
    print("=" * 80)
    print("批量特征选择工具 - 使用指南")
    print("=" * 80)

    # Section 1: column-selection syntax.
    print("\n1. 列范围选择功能:")
    print(" 支持多种列选择方式:")
    print(" - 列号范围: '1:10' 表示列1到列10")
    print(" - 混合选择: '2,4,6-8' 表示列2,4,6,7,8")
    print(" - 自动检测: 'auto' 自动选择数值列作为光谱列")
    print(" - 列名列表: 'wavelength_400 wavelength_410 wavelength_420'")

    # Section 2: command-line examples.
    print("\n2. 命令行使用示例:")
    print(" # 使用列号范围")
    print(" python batch_feature_selection.py input_dir output_dir --label_column 0 --spectral_columns 1:50")
    print("")
    print(" # 使用混合范围")
    print(" python batch_feature_selection.py input_dir output_dir --label_column concentration --spectral_columns 2,4,6-8")
    print("")
    print(" # 自动检测光谱列")
    print(" python batch_feature_selection.py input_dir output_dir --label_column Label --spectral_columns auto")

    # Section 3: Python API example.
    print("\n3. Python代码使用示例:")
    print("""
from batch_feature_selection import batch_feature_selection, create_example_batch_config, find_csv_files

# 查找CSV文件
csv_files = find_csv_files('your/data/directory')

# 创建配置
base_config = create_example_batch_config()
base_config.label_column = 'concentration'  # 标签列名
base_config.spectral_columns = "5:25"  # 列5到25作为光谱列

# 执行批量处理
results = batch_feature_selection(
    csv_files=csv_files,
    base_config=base_config,
    output_base_dir='output/directory',
    parallel=True
)
""")

    # Section 4: supported methods.
    print("\n4. 支持的特征选择方法:")
    methods = ['CARS', 'Lars', 'Uve', 'Spa', 'GA', 'ReliefF', 'RandomFrog', 'SiPLS']
    for method in methods:
        print(f" - {method}")

    # Section 5: method-parameter examples.
    print("\n5. 方法参数配置示例:")
    print("""
# CARS方法
config.method_params = {'N': 50, 'f': 20, 'cv': 10}

# UVE方法
config.method_params = {'ncomp': 20, 'cv': 5}

# SPA方法
config.method_params = {'m_min': 2, 'm_max': 50, 'autoscaling': 1}
""")

    print("=" * 80)
|
||||
|
||||
# ---------------------------------------------------------------------------
# Script entry point: run every available feature-selection method over all
# CSV files in a fixed data directory.
#
# The __main__ guard is essential here: batch_feature_selection(parallel=True)
# uses ProcessPoolExecutor, which re-imports this module in every worker
# process (always, with the "spawn" start method used on Windows). Without
# the guard, each worker would re-launch the entire batch recursively.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Locate the input CSV files. Raw string: the original literal relied on
    # invalid escape sequences ("\s", "\d"), which raise SyntaxWarning on
    # modern Python; r"..." keeps the identical path value.
    csv_files = find_csv_files(r"E:\code\spectronon\single_classsfication\data")

    # Every available feature-selection method with its parameters.
    methods_config = [
        {
            'method': 'Cars',
            'method_params': {'N': 50, 'f': 20, 'cv': 10},
            'description': 'Competitive Adaptive Reweighted Sampling'
        },
        {
            'method': 'Uve',
            'method_params': {'ncomp': 20, 'cv': 5},
            'description': 'Uninformative Variable Elimination'
        },
        {
            'method': 'Spa',
            'method_params': {'m_min': 2, 'm_max': 50, 'autoscaling': 1},
            'description': 'Successive Projections Algorithm'
        },
        {
            'method': 'GA',
            'method_params': {'population_size': 10},
            'description': 'Genetic Algorithm'
        },
        {
            'method': 'ReliefF',
            'method_params': {'n_neighbors': 20, 'n_features_to_keep': 20},
            'description': 'ReliefF Algorithm'
        },
        {
            'method': 'RandomFrog',
            'method_params': {'n_frogs': 50, 'n_memeplexes': 5, 'n_evolution_steps': 10, 'n_shuffle_iterations': 10, 'cv': 5},
            'description': 'Random Frog Leaping Algorithm'
        },
        {
            'method': 'SiPLS',
            'method_params': {'n_intervals_list': [10, 15, 20]},
            'description': 'Synergy Interval Partial Least Squares'
        }
    ]

    print("=" * 80)
    print("开始批量特征选择 - 使用所有可用方法")
    print(f"找到 {len(csv_files)} 个CSV文件待处理")
    print(f"将使用 {len(methods_config)} 种特征选择方法")
    print("=" * 80)

    all_results = {}

    # Run the full batch once per method; failures of one method must not
    # stop the others.
    for i, method_cfg in enumerate(methods_config, 1):
        method_name = method_cfg['method']
        description = method_cfg['description']

        print(f"\n{'='*60}")
        print(f"方法 {i}/{len(methods_config)}: {method_name}")
        print(f"描述: {description}")
        print(f"{'='*60}")

        try:
            # Configure this method on top of the example base config.
            method_config = create_example_batch_config()
            method_config.method = method_name
            method_config.method_params = method_cfg['method_params']
            method_config.label_column = 'Label'  # label column name
            method_config.spectral_columns = "1:"  # columns 1..end are spectral
            method_config.plot_name_prefix = f"{method_name.lower()}_batch_fs"

            # Each method writes into its own results sub-directory.
            method_results = batch_feature_selection(
                csv_files=csv_files,
                base_config=method_config,
                output_base_dir=f'E:\\code\\spectronon\\single_classsfication\\Feature_Selection_method\\directory\\{method_name.lower()}_results',
                parallel=True
            )

            all_results[method_name] = {
                'results': method_results,
                'description': description,
                'config': method_cfg
            }

            print(f"✅ {method_name} 方法处理完成")

        except Exception as e:
            print(f"❌ {method_name} 方法处理失败: {str(e)}")
            all_results[method_name] = {
                'error': str(e),
                'description': description,
                'config': method_cfg
            }

    # Cross-method summary.
    print(f"\n{'='*80}")
    print("批量特征选择处理完成汇总")
    print(f"{'='*80}")

    successful_methods = []
    failed_methods = []

    for method_name, result in all_results.items():
        if 'error' in result:
            failed_methods.append(f"{method_name}: {result['error']}")
            print(f"❌ {method_name}: 失败 - {result['error']}")
        else:
            successful_methods.append(method_name)
            print(f"✅ {method_name}: 成功")

    print(f"\n总计: {len(successful_methods)}/{len(methods_config)} 种方法成功处理")
    print(f"成功的方法: {', '.join(successful_methods)}")

    if failed_methods:
        print(f"失败的方法: {len(failed_methods)} 种")
        for failed in failed_methods:
            print(f" - {failed}")

    print(f"\n结果文件保存在: E:\\code\\spectronon\\single_classsfication\\Feature_Selection_method\\directory\\")
    print("每个方法都有独立的子目录存储结果")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user