增加模块;增加主调用命令

This commit is contained in:
2026-01-07 16:36:47 +08:00
commit 2d4b170a45
109 changed files with 55763 additions and 0 deletions

View File

@ -0,0 +1,820 @@
import numpy as np
from scipy import signal
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import pandas as pd
import pywt
from copy import deepcopy
import joblib # 用于保存和加载模型
import os
from pathlib import Path
from dataclasses import dataclass
from typing import Optional
@dataclass
class PreprocessingConfig:
    """Configuration container for the preprocessing pipeline.

    Validated in ``__post_init__``: an input path is mandatory, and both the
    preprocessing method and the outlier-detection method must come from the
    supported sets below.  ``output_dir`` defaults to ``./results`` when not
    supplied.
    """
    input_path: Optional[str] = None       # path to the input file (required)
    method: str = 'SS'                     # preprocessing method key
    spectral_start_index: int = 1          # first spectral column (0-based) for CSV input
    handle_outliers: bool = False          # whether outlier handling is requested
    outlier_method: str = 'iqr'            # outlier detection strategy
    output_dir: Optional[str] = None       # where results are written; defaulted below
    # CSV file validation parameters
    validate_csv: bool = False             # run CSV sanity checks on load
    min_rows: int = 1                      # minimum accepted row count
    min_cols: int = 2                      # minimum accepted column count
    check_missing_values: bool = False     # warn about NaNs/nulls
    check_data_types: bool = False         # warn about non-numeric spectral columns
    wavelength_column: Optional[str] = None  # optional column holding wavelengths
    # Parameters for the MA (moving average) method
    ma_window: int = 11
    # Parameters for the SG (Savitzky-Golay) method
    sg_window: int = 15
    sg_poly: int = 2

    def __post_init__(self):
        """Validate parameters and fill in defaults; raises ValueError on bad input."""
        if not self.input_path:
            raise ValueError("必须指定输入文件路径(input_path)")
        valid_methods = ['MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave']
        if self.method not in valid_methods:
            raise ValueError(f"不支持的预处理方法: {self.method}。支持的方法: {valid_methods}")
        valid_outlier_methods = ['iqr', 'isolation-forest', 'lof']
        if self.outlier_method not in valid_outlier_methods:
            raise ValueError(f"不支持的异常值检测方法: {self.outlier_method}。支持的方法: {valid_outlier_methods}")
        if not self.output_dir:
            self.output_dir = './results'
# Optional dependency: the `spectral` package is only needed for reading
# ENVI-format images.  When it is absent we record that fact and fall back
# to the built-in CSV reader (ENVI loading will then be rejected at runtime).
try:
    import spectral
    SPECTRAL_AVAILABLE = True
except ImportError:
    SPECTRAL_AVAILABLE = False
    print("警告: spectral库不可用将使用内置方法读取数据")
class HyperspectralPreprocessor:
    """
    Hyperspectral data preprocessor.

    Supported file formats:
    - CSV files: the 0-based index of the first spectral column must be given
    - ENVI format: .bil, .bsq, .bip, .dat + .hdr files (requires the `spectral` library)

    The preprocessing methods operate on ``self.data``, which is either a
    2-D array (samples, bands) for CSV input or a 3-D array
    (rows, cols, bands) for ENVI images.
    """

    def __init__(self, config: Optional["PreprocessingConfig"] = None):
        """
        Parameters:
            config: optional PreprocessingConfig.  Only the CSV-validation
                options are consulted; when omitted, validation is skipped
                and the instance acts as a plain container for the
                preprocessing methods (this keeps the legacy no-argument
                construction used by the module-level wrapper functions
                working).
        """
        self.config = config
        self.data = None          # loaded spectra: 2-D (samples, bands) or 3-D (rows, cols, bands)
        self.wavelengths = None   # per-band wavelength array, or None before loading
        self.data_shape = None    # shape of the loaded data
        self.input_format = None  # 'csv' or 'envi'

    def load_data(self, file_path, spectral_start_index=None):
        """
        Load a hyperspectral data file.

        Parameters:
            file_path: path to the input file
            spectral_start_index: for CSV files, the 0-based index of the
                first spectral column (mandatory for CSV input)

        Raises:
            ValueError: unsupported format, or missing spectral_start_index
                for a CSV file.
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()
        if suffix == '.csv':
            if spectral_start_index is None:
                raise ValueError("对于CSV文件必须指定spectral_start_index参数")
            self.input_format = 'csv'
            return self._load_csv_data(file_path, spectral_start_index)
        else:
            # ENVI family: binary data file plus .hdr header
            if SPECTRAL_AVAILABLE and suffix in ['.bil', '.bsq', '.bip', '.dat', '.hdr']:
                self.input_format = 'envi'
                return self._load_envi_data(file_path)
            else:
                raise ValueError(f"不支持的文件格式: {suffix}。请使用CSV或ENVI格式文件并确保已安装spectral库。")

    def _validate_csv_data(self, df, spectral_start_index):
        """Validate a CSV DataFrame against the options in self.config.

        Hard failures (size, bad column index, missing wavelength column)
        raise ValueError; soft issues are printed as warnings.  Only called
        when a config with ``validate_csv=True`` is present.
        """
        errors = []
        warnings = []
        # Basic size checks
        if len(df) < self.config.min_rows:
            errors.append(f"CSV文件行数 ({len(df)}) 少于最小要求 ({self.config.min_rows})")
        if len(df.columns) < self.config.min_cols:
            errors.append(f"CSV文件列数 ({len(df.columns)}) 少于最小要求 ({self.config.min_cols})")
        # Validate the spectral start column index
        if spectral_start_index < 0 or spectral_start_index >= len(df.columns):
            errors.append(f"光谱起始列索引 {spectral_start_index} 超出范围 [0, {len(df.columns)-1}]")
        # Missing-value check
        if self.config.check_missing_values:
            missing_count = df.isnull().sum().sum()
            if missing_count > 0:
                warnings.append(f"发现 {missing_count} 个缺失值")
        # Data-type consistency check on the spectral columns
        if self.config.check_data_types:
            spectral_cols = df.columns[spectral_start_index:]
            for col in spectral_cols:
                try:
                    pd.to_numeric(df[col], errors='coerce')
                except Exception as e:
                    warnings.append(f"'{col}' 包含非数值数据: {e}")
        # Wavelength column checks
        if self.config.wavelength_column:
            if self.config.wavelength_column not in df.columns:
                errors.append(f"指定的波长列 '{self.config.wavelength_column}' 不存在")
            else:
                wavelength_data = df[self.config.wavelength_column]
                try:
                    numeric_wavelengths = pd.to_numeric(wavelength_data, errors='coerce')
                    if numeric_wavelengths.isnull().any():
                        warnings.append(f"波长列 '{self.config.wavelength_column}' 包含非数值数据")
                    else:
                        # Wavelengths should be monotonic (either direction)
                        if not (numeric_wavelengths.dropna().is_monotonic_increasing or
                                numeric_wavelengths.dropna().is_monotonic_decreasing):
                            warnings.append(f"波长列 '{self.config.wavelength_column}' 不是单调的")
                except Exception as e:
                    warnings.append(f"波长列 '{self.config.wavelength_column}' 验证失败: {e}")
        # Report results
        if errors:
            error_msg = "CSV验证失败:\n" + "\n".join(f" - {error}" for error in errors)
            raise ValueError(error_msg)
        if warnings:
            print("CSV验证警告:")
            for warning in warnings:
                print(f" - {warning}")
        return True

    def _load_csv_data(self, file_path, spectral_start_index):
        """Load hyperspectral data from a CSV file into self.data."""
        df = pd.read_csv(file_path)
        # Optional CSV validation — only when a config was supplied and asks for it.
        if self.config is not None and self.config.validate_csv:
            print("正在验证CSV数据...")
            self._validate_csv_data(df, spectral_start_index)
        # Validate the column index regardless of the optional checks above
        if spectral_start_index < 0 or spectral_start_index >= len(df.columns):
            raise ValueError(f"列索引 {spectral_start_index} 超出范围 [0, {len(df.columns)-1}]")
        # Extract the spectral block
        spectral_data = df.iloc[:, spectral_start_index:].values
        # Column names double as wavelengths when they parse as numbers
        try:
            self.wavelengths = pd.to_numeric(df.columns[spectral_start_index:], errors='coerce').values
            if np.isnan(self.wavelengths).any():
                print("警告: 部分列名无法转换为波长值,将使用列索引作为波长")
                self.wavelengths = np.arange(len(self.wavelengths), dtype=float)
        except Exception as e:
            print(f"波长解析失败: {e},将使用列索引作为波长")
            self.wavelengths = np.arange(spectral_data.shape[1], dtype=float)
        # Store results
        self.data = spectral_data
        self.data_shape = self.data.shape
        print(f"CSV数据加载成功: {self.data_shape}")
        print(f"光谱起始列索引: {spectral_start_index}")
        print(f"波段数量: {len(self.wavelengths)}")
        if len(self.wavelengths) > 0:
            print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
        return self.data

    def _load_envi_data(self, file_path):
        """Load hyperspectral data in ENVI format via the spectral library."""
        original_path = Path(file_path)
        current_path = original_path
        # When given a .hdr, locate the companion data file (for logging).
        if current_path.suffix.lower() == '.hdr':
            data_file = current_path.with_suffix('')
            if not data_file.exists():
                for ext in ['.dat', '.bil', '.bsq', '.bip']:
                    candidate = current_path.with_suffix(ext)
                    if candidate.exists():
                        data_file = candidate
                        break
            current_path = data_file
        print(f"使用spectral库加载: {current_path}")
        # NOTE(review): open_image is handed the original path, not the
        # resolved data file — spectral expects the header path, so this
        # looks intentional for .hdr input; confirm behaviour for direct
        # .dat/.bil arguments.
        img = spectral.open_image(str(file_path))
        # Materialize all bands into memory
        self.data = img.load()
        # Wavelength metadata, when present, may be a string, list, or array
        if hasattr(img, 'metadata') and 'wavelength' in img.metadata:
            wavelength_data = img.metadata['wavelength']
            try:
                if isinstance(wavelength_data, str):
                    wavelength_data = wavelength_data.strip('{}[]')
                    if ',' in wavelength_data:
                        self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split(',') if w.strip()])
                    else:
                        self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split() if w.strip()])
                elif isinstance(wavelength_data, list):
                    self.wavelengths = np.array([float(w) for w in wavelength_data])
                else:
                    self.wavelengths = np.array(wavelength_data, dtype=float)
                print(f"spectral解析波长: {len(self.wavelengths)} 个波段")
                print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
            except Exception as e:
                print(f"波长解析失败,使用默认值: {e}")
                self.wavelengths = np.arange(self.data.shape[2])
        else:
            self.wavelengths = np.arange(self.data.shape[2])
            print(f"警告: spectral未找到波长信息使用默认值: 0-{self.data.shape[2]-1}")
        self.data_shape = self.data.shape
        print(f"spectral数据加载成功: {self.data_shape}")
        print(f"数据类型: {self.data.dtype}")
        print(f"值范围: [{self.data.min():.3f}, {self.data.max():.3f}]")
        return self.data

    def save_data(self, output_path, data=None, wavelengths=None):
        """
        Save (pre)processed data.

        Parameters:
            output_path: destination file path; the suffix selects the format
            data: array to save (defaults to self.data)
            wavelengths: wavelength metadata (defaults to self.wavelengths)

        Raises:
            ValueError: when the output suffix maps to no supported writer.
        """
        if data is None:
            data = self.data
        if wavelengths is None:
            wavelengths = self.wavelengths
        output_path = Path(output_path)
        suffix = output_path.suffix.lower()
        if suffix == '.csv':
            self._save_csv_data(output_path, data, wavelengths)
        else:
            # ENVI output
            if self.input_format == 'envi' or suffix in ['.bil', '.bsq', '.bip', '.dat']:
                self._save_envi_data(output_path, data, wavelengths)
            else:
                raise ValueError(f"不支持的输出格式: {suffix}")

    def _save_csv_data(self, output_path, data, wavelengths):
        """Write data as CSV; wavelengths (when known) become the column names."""
        if wavelengths is not None:
            df = pd.DataFrame(data, columns=[f"{w:.1f}" for w in wavelengths])
        else:
            df = pd.DataFrame(data)
        df.to_csv(output_path, index=False)
        print(f"已保存CSV文件: {output_path}")

    def _save_envi_data(self, output_path, data, wavelengths):
        """Write data as an ENVI pair: float32 .dat plus a text .hdr."""
        data_file = output_path.with_suffix('.dat')
        hdr_file = output_path.with_suffix('.hdr')
        # Raw binary payload
        data.astype('float32').tofile(str(data_file))
        # Derive the ENVI geometry from the array rank
        if len(data.shape) == 1:
            lines = 1
            samples = data.shape[0]
            bands = 1
        elif len(data.shape) == 2:
            lines = data.shape[0]
            samples = data.shape[1]
            bands = 1
        else:
            lines = data.shape[0]
            samples = data.shape[1]
            bands = data.shape[2]
        header_content = f"""ENVI
samples = {samples}
lines = {lines}
bands = {bands}
header offset = 0
file type = ENVI Standard
data type = 4
interleave = bip
byte order = 0
wavelength units = nm
"""
        # Only embed wavelengths when they match the band count
        if wavelengths is not None and len(wavelengths) == bands:
            wavelength_str = ', '.join([f"{w:.1f}" for w in wavelengths])
            header_content += f"wavelength = {{{wavelength_str}}}\n"
        with open(hdr_file, 'w', encoding='utf-8') as f:
            f.write(header_content)
        print(f"已保存ENVI文件: {data_file}{hdr_file}")

    def preprocess(self, method, output_path, **kwargs):
        """
        Run a preprocessing method and save the result.

        Parameters:
            method: method key ('MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave')
            output_path: destination file path
            **kwargs: forwarded to the selected preprocessing method

        Returns:
            The processed array.
        """
        if self.data is None:
            raise ValueError("请先加载数据")
        processed_data = self._apply_preprocessing(method, **kwargs)
        self.save_data(output_path, processed_data, self.wavelengths)
        return processed_data

    def _apply_preprocessing(self, method, **kwargs):
        """Dispatch a method key to its implementation."""
        method_funcs = {
            'MMS': self._MMS,
            'SS': self._SS,
            'CT': self._CT,
            'SNV': self._SNV,
            'MA': self._MA,
            'SG': self._SG,
            'D1': self._D1,
            'D2': self._D2,
            'DT': self._DT,
            'MSC': self._MSC,
            'wave': self._wave
        }
        if method not in method_funcs:
            raise ValueError(f"不支持的预处理方法: {method}")
        return method_funcs[method](**kwargs)

    # ---- preprocessing method implementations ----

    def _MMS(self, **kwargs):
        """Min-max normalization to [0, 1], fitted per band."""
        print("执行最大最小值归一化...")
        scaler = MinMaxScaler()
        if self.data.ndim == 2:
            # CSV layout: (samples, bands)
            return scaler.fit_transform(self.data)
        else:
            # Image layout: (rows, cols, bands) -> flatten pixels, scale, restore
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            normalized = scaler.fit_transform(reshaped)
            return normalized.reshape(original_shape)

    def _SS(self, save_path=None, **kwargs):
        """Standardization (zero mean, unit variance per band).

        Parameters:
            save_path: optional path; when given, the fitted StandardScaler
                is persisted with joblib for reuse on new data.
        """
        print("执行标准化...")
        scaler = StandardScaler()
        if self.data.ndim == 2:
            result = scaler.fit_transform(self.data)
        else:
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            result = scaler.fit_transform(reshaped)
            result = result.reshape(original_shape)
        if save_path:
            joblib.dump(scaler, save_path)
            print(f"Scaler参数已保存到: {save_path}")
        return result

    def _CT(self, **kwargs):
        """Mean centering: subtract each spectrum's own mean."""
        print("执行均值中心化...")
        if self.data.ndim == 2:
            # 2-D: per-sample (row) mean
            mean_vals = np.mean(self.data, axis=1, keepdims=True)
            return self.data - mean_vals
        else:
            # 3-D: per-pixel spectral mean
            mean_vals = np.mean(self.data, axis=2, keepdims=True)
            return self.data - mean_vals

    def _SNV(self, **kwargs):
        """Standard normal variate transform (per-row z-score)."""
        print("执行标准正态变换...")
        if self.data.ndim != 2:
            raise ValueError("SNV方法只支持2D数据")
        data_average = np.mean(self.data, axis=1, keepdims=True)
        data_std = np.std(self.data, axis=1, keepdims=True)
        # Flat spectra would divide by zero; substitute 1 so they map to 0
        data_std = np.where(data_std == 0, 1, data_std)
        return (self.data - data_average) / data_std

    def _MA(self, WSZ=11, **kwargs):
        """Moving-average smoothing with shrinking-window edge handling.

        Parameters:
            WSZ: window size; assumed odd (the edge correction below mirrors
                MATLAB's `smooth` and relies on an odd window).
        """
        print(f"执行移动平均平滑 (窗口大小: {WSZ})...")
        # Work on a float copy so smoothing integer data does not truncate.
        output_data = np.array(self.data, dtype=np.float64)
        if output_data.ndim == 2:
            for i in range(output_data.shape[0]):
                out0 = np.convolve(output_data[i], np.ones(WSZ, dtype=int), 'valid') / WSZ
                r = np.arange(1, WSZ - 1, 2)
                start = np.cumsum(output_data[i, :WSZ - 1])[::2] / r
                stop = (np.cumsum(output_data[i, :-WSZ:-1])[::2] / r)[::-1]
                output_data[i] = np.concatenate((start, out0, stop))
        else:
            for i in range(output_data.shape[0]):
                for j in range(output_data.shape[1]):
                    spectrum = output_data[i, j, :]
                    out0 = np.convolve(spectrum, np.ones(WSZ, dtype=int), 'valid') / WSZ
                    r = np.arange(1, WSZ - 1, 2)
                    start = np.cumsum(spectrum[:WSZ - 1])[::2] / r
                    stop = (np.cumsum(spectrum[:-WSZ:-1])[::2] / r)[::-1]
                    output_data[i, j, :] = np.concatenate((start, out0, stop))
        return output_data

    def _SG(self, w=15, p=2, **kwargs):
        """Savitzky-Golay smoothing along the band axis (window w, order p)."""
        print(f"执行Savitzky-Golay平滑 (窗口: {w}, 阶数: {p})...")
        if self.data.ndim == 2:
            return signal.savgol_filter(self.data, w, p)
        else:
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            filtered = signal.savgol_filter(reshaped, w, p)
            return filtered.reshape(original_shape)

    def _D1(self, **kwargs):
        """First derivative (finite difference along the band axis).

        The output has one band fewer than the input; as a side effect,
        self.wavelengths is shortened to stay aligned.
        """
        print("执行一阶导数...")
        if self.data.ndim == 2:
            output_data = np.diff(self.data, axis=1).astype(np.float64)
        else:
            output_data = np.diff(self.data, axis=2).astype(np.float64)
        # Keep the wavelength vector in sync with the reduced band count
        if self.wavelengths is not None and len(self.wavelengths) > 1:
            self.wavelengths = self.wavelengths[:-1]
        return output_data

    def _D2(self, **kwargs):
        """Second derivative: two successive finite differences along bands.

        The output loses two bands; self.wavelengths is shortened to match.
        """
        print("执行二阶导数...")
        if self.data.ndim == 2:
            output_data = np.diff(np.diff(self.data, axis=1), axis=1)
        elif self.data.ndim == 3:
            output_data = np.diff(np.diff(self.data, axis=2), axis=2)
        else:
            raise ValueError("不支持的数据维度")
        if self.wavelengths is not None and len(self.wavelengths) > 2:
            self.wavelengths = self.wavelengths[:-2]
        return output_data

    def _DT(self, **kwargs):
        """Detrending: fit a straight line (band index vs. value) to each
        spectrum and subtract it."""
        print("执行趋势校正...")
        # Float copy: the fitted trend is fractional, so integer input must
        # be promoted before subtraction.
        output_data = np.array(self.data, dtype=np.float64)
        reg = LinearRegression()
        if output_data.ndim == 2:
            length = output_data.shape[1]
            x = np.arange(length, dtype=np.float64)
            for i in range(output_data.shape[0]):
                reg.fit(x.reshape(-1, 1), output_data[i].reshape(-1, 1))
                # .item(): sklearn returns (1, 1)/(1,) arrays; assigning those
                # into scalar array elements is an error on modern NumPy.
                slope = reg.coef_.item()
                intercept = reg.intercept_.item()
                output_data[i] -= x * slope + intercept
        else:
            length = output_data.shape[2]
            x = np.arange(length, dtype=np.float64)
            for i in range(output_data.shape[0]):
                for j in range(output_data.shape[1]):
                    spectrum = output_data[i, j, :]
                    reg.fit(x.reshape(-1, 1), spectrum.reshape(-1, 1))
                    slope = reg.coef_.item()
                    intercept = reg.intercept_.item()
                    output_data[i, j, :] = spectrum - (x * slope + intercept)
        return output_data

    def _MSC(self, **kwargs):
        """Multiplicative scatter correction against the mean spectrum."""
        print("执行多元散射校正...")
        if self.data.ndim == 2:
            n, p = self.data.shape
            output_data = np.ones((n, p))
            mean = np.mean(self.data, axis=0)  # reference spectrum
            for i in range(n):
                y = self.data[i, :]
                reg = LinearRegression()
                reg.fit(mean.reshape(-1, 1), y.reshape(-1, 1))
                slope = reg.coef_.item()
                intercept = reg.intercept_.item()
                if slope == 0:
                    slope = 1e-10  # guard against division by zero
                output_data[i, :] = (y - intercept) / slope
        elif self.data.ndim == 3:
            rows, cols, bands = self.data.shape
            output_data = np.zeros((rows, cols, bands), dtype=np.float32)
            # Reference spectrum: mean over every pixel in the image
            mean_spectrum = np.mean(self.data, axis=(0, 1))
            for i in range(rows):
                for j in range(cols):
                    y = self.data[i, j, :]
                    reg = LinearRegression()
                    reg.fit(mean_spectrum.reshape(-1, 1), y.reshape(-1, 1))
                    slope = reg.coef_.item()
                    intercept = reg.intercept_.item()
                    if slope == 0:
                        slope = 1e-10  # guard against division by zero
                    output_data[i, j, :] = (y - intercept) / slope
        else:
            raise ValueError("不支持的数据维度")
        return output_data

    def _wave(self, **kwargs):
        """Wavelet denoising: 'db8' decomposition with thresholded detail
        coefficients, reconstructed per spectrum."""
        print("执行小波变换...")

        def wave_single(spectrum):
            wavelet = pywt.Wavelet('db8')
            max_level = pywt.dwt_max_level(len(spectrum), wavelet.dec_len)
            coeffs = pywt.wavedec(spectrum, 'db8', level=max_level)
            threshold = 0.04
            for k in range(1, len(coeffs)):
                coeffs[k] = pywt.threshold(coeffs[k], threshold * max(coeffs[k]))
            # waverec can return one extra sample for odd-length input; trim
            # so every reconstructed spectrum keeps the original band count.
            return pywt.waverec(coeffs, 'db8')[:len(spectrum)]

        if self.data.ndim == 2:
            output_data = np.vstack([wave_single(self.data[i]) for i in range(self.data.shape[0])])
        else:
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            output_data = np.vstack([wave_single(reshaped[i]) for i in range(reshaped.shape[0])])
            output_data = output_data.reshape(original_shape)
        return output_data
# ===== Convenience functions =====
def preprocess_file(input_file, output_file, method, spectral_start_index=None, **kwargs):
    """
    Convenience function: preprocess a single file end to end.

    Parameters:
        input_file: input file path
        output_file: output file path
        method: preprocessing method key
        spectral_start_index: 0-based start column of the spectral data (CSV only)
        **kwargs: forwarded to the preprocessing method
    """
    # A config is required by the processor's constructor; the file path is
    # enough — the method/validation options default sensibly.
    config = PreprocessingConfig(input_path=str(input_file), method=method)
    processor = HyperspectralPreprocessor(config)
    processor.load_data(input_file, spectral_start_index)
    return processor.preprocess(method, output_file, **kwargs)
# Backward-compatible functional wrappers
def MMS(input_spectrum):
    """Min-max normalization (backward-compatible wrapper).

    The original one-liner passed the spectrum positionally to a
    keyword-only method and never set processor.data; build the processor
    properly instead.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path="<in-memory>"))
    processor.data = input_spectrum
    return processor._MMS()
def SS(input_spectrum, save_path=None):
    """Standardization (backward-compatible wrapper)."""
    # The constructor requires a config; supply a minimal placeholder.
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path="<in-memory>"))
    processor.data = input_spectrum
    return processor._SS(save_path=save_path)
def CT(input_spectrum):
    """Mean centering (backward-compatible wrapper)."""
    # The constructor requires a config; supply a minimal placeholder.
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path="<in-memory>"))
    processor.data = input_spectrum
    return processor._CT()
def SNV(input_spectrum):
    """Standard normal variate transform (backward-compatible wrapper)."""
    # The constructor requires a config; supply a minimal placeholder.
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path="<in-memory>"))
    processor.data = input_spectrum
    return processor._SNV()
def MA(input_spectrum, WSZ=11):
    """Moving-average smoothing (backward-compatible wrapper)."""
    # The constructor requires a config; supply a minimal placeholder.
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path="<in-memory>"))
    processor.data = input_spectrum
    return processor._MA(WSZ=WSZ)
def SG(input_spectrum, w=15, p=2):
    """Savitzky-Golay smoothing (backward-compatible wrapper)."""
    # The constructor requires a config; supply a minimal placeholder.
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path="<in-memory>"))
    processor.data = input_spectrum
    return processor._SG(w=w, p=p)
def D1(input_spectrum):
    """First derivative (backward-compatible wrapper)."""
    # The constructor requires a config; supply a minimal placeholder.
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path="<in-memory>"))
    processor.data = input_spectrum
    return processor._D1()
def D2(input_spectrum):
    """Second derivative (backward-compatible wrapper)."""
    # The constructor requires a config; supply a minimal placeholder.
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path="<in-memory>"))
    processor.data = input_spectrum
    return processor._D2()
def DT(input_spectrum):
    """Detrending (backward-compatible wrapper)."""
    # The constructor requires a config; supply a minimal placeholder.
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path="<in-memory>"))
    processor.data = input_spectrum
    return processor._DT()
def MSC(input_spectrum):
    """Multiplicative scatter correction (backward-compatible wrapper)."""
    # The constructor requires a config; supply a minimal placeholder.
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path="<in-memory>"))
    processor.data = input_spectrum
    return processor._MSC()
def wave(input_spectrum):
    """Wavelet denoising (backward-compatible wrapper)."""
    # The constructor requires a config; supply a minimal placeholder.
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path="<in-memory>"))
    processor.data = input_spectrum
    return processor._wave()
def main():
    """Command-line entry point: parse arguments, load data, run one or
    more preprocessing methods, and save each result.

    Returns:
        0 on success, 1 on any failure (exception printed with traceback).
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='高光谱数据预处理工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
1. 处理ENVI格式图像:
python Preprocessing.py input.hdr -m SNV -o output_snv.dat
2. 处理CSV格式数据:
python Preprocessing.py input.csv -c 1 -m MSC -o output_msc.csv
3. 批量处理多种方法:
python Preprocessing.py input.csv -c wavelength_400 -m SNV MSC MA -o results/
支持的预处理方法:
MMS: 最大最小值归一化
SS: 标准差标准化
CT: 中心化
SNV: 标准正态变量变换
MA: 移动平均
SG: Savitzky-Golay滤波
D1: 一阶导数
D2: 二阶导数
DT: 离散变换
MSC: 多重散射校正
wave: 小波变换
"""
    )
    parser.add_argument('input_file', help='输入文件路径 (.hdr高光谱图像 或 .csv文件)')
    parser.add_argument('-c', '--spectral_col', type=int, default=1,
                        help='CSV文件的谱段起始列索引 (从0开始默认: 1)')
    parser.add_argument('-m', '--methods', nargs='+', required=True,
                        choices=['MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave'],
                        help='预处理方法 (可指定多个)')
    parser.add_argument('-o', '--output', required=True,
                        help='输出文件路径或目录')
    # Method-specific tuning parameters
    parser.add_argument('--ma_window', type=int, default=11,
                        help='移动平均窗口大小 (默认: 11)')
    parser.add_argument('--sg_window', type=int, default=15,
                        help='Savitzky-Golay窗口大小 (默认: 15)')
    parser.add_argument('--sg_poly', type=int, default=2,
                        help='Savitzky-Golay多项式阶数 (默认: 2)')
    args = parser.parse_args()
    try:
        print("=" * 60)
        print("高光谱数据预处理工具")
        print("=" * 60)
        print(f"输入文件: {args.input_file}")
        print(f"预处理方法: {', '.join(args.methods)}")
        print(f"输出路径: {args.output}")
        print()
        # Build the processor with a config derived from the CLI arguments
        # (the bare constructor requires one).
        processor = HyperspectralPreprocessor(
            PreprocessingConfig(input_path=args.input_file,
                                spectral_start_index=args.spectral_col,
                                ma_window=args.ma_window,
                                sg_window=args.sg_window,
                                sg_poly=args.sg_poly))
        # Load the input data
        print("正在加载数据...")
        processor.load_data(args.input_file, spectral_start_index=args.spectral_col)
        # Run each requested method
        for method in args.methods:
            print(f"\n正在执行 {method} 预处理...")
            # Build the output path for this method
            if len(args.methods) == 1:
                output_path = args.output
            else:
                # Multiple methods: derive distinct file names
                # (os is already imported at module level)
                if os.path.isdir(args.output):
                    base_name = os.path.splitext(os.path.basename(args.input_file))[0]
                    output_path = os.path.join(args.output, f"{base_name}_{method.lower()}.csv")
                else:
                    # A file path was given: suffix it per method
                    base, ext = os.path.splitext(args.output)
                    output_path = f"{base}_{method.lower()}{ext}"
            # Method-specific keyword arguments
            kwargs = {}
            if method == 'MA':
                kwargs['WSZ'] = args.ma_window
            elif method == 'SG':
                kwargs['w'] = args.sg_window
                kwargs['p'] = args.sg_poly
            # Run and save
            processor.preprocess(method, output_path, **kwargs)
            print(f"{method} 预处理完成: {output_path}")
        print("\n" + "=" * 60)
        print("所有预处理任务完成!")
        print("=" * 60)
    except Exception as e:
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # raise SystemExit instead of the site-module exit() helper, which is
    # meant for interactive use and may be absent (e.g. under `python -S`).
    raise SystemExit(main())

View File