Files
HSI/preprocessing_method/Preprocessing.py

821 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from scipy import signal
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import pandas as pd
import pywt
from copy import deepcopy
import joblib # 用于保存和加载模型
import os
from pathlib import Path
from dataclasses import dataclass
from typing import Optional
@dataclass
class PreprocessingConfig:
    """Configuration for the preprocessing pipeline."""

    input_path: Optional[str] = None          # path of the file to preprocess (required)
    method: str = 'SS'                        # preprocessing method identifier
    spectral_start_index: int = 1             # 0-based index of the first spectral column (CSV)
    handle_outliers: bool = False
    outlier_method: str = 'iqr'
    output_dir: Optional[str] = None          # filled with './results' when empty
    # CSV validation options
    validate_csv: bool = False
    min_rows: int = 1
    min_cols: int = 2
    check_missing_values: bool = False
    check_data_types: bool = False
    wavelength_column: Optional[str] = None
    # Moving-average parameters
    ma_window: int = 11
    # Savitzky-Golay parameters
    sg_window: int = 15
    sg_poly: int = 2

    def __post_init__(self):
        """Validate fields and fill in defaults; raises ValueError on bad input."""
        if not self.input_path:
            raise ValueError("必须指定输入文件路径(input_path)")
        supported = ['MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave']
        if self.method not in supported:
            raise ValueError(f"不支持的预处理方法: {self.method}。支持的方法: {supported}")
        outlier_choices = ['iqr', 'isolation-forest', 'lof']
        if self.outlier_method not in outlier_choices:
            raise ValueError(f"不支持的异常值检测方法: {self.outlier_method}。支持的方法: {outlier_choices}")
        self.output_dir = self.output_dir or './results'
# Optional dependency: the `spectral` library is only needed for ENVI files.
SPECTRAL_AVAILABLE = True
try:
    import spectral
except ImportError:
    SPECTRAL_AVAILABLE = False
    print("警告: spectral库不可用将使用内置方法读取数据")
class HyperspectralPreprocessor:
    """
    Hyperspectral data preprocessor.

    Supported file formats:
    - CSV files: the 0-based index of the first spectral column must be supplied
    - ENVI format: .bil, .bsq, .bip, .dat + .hdr files (requires the `spectral` library)
    """

    def __init__(self, config: Optional["PreprocessingConfig"] = None):
        """Create a preprocessor.

        Parameters:
            config: optional PreprocessingConfig. May be None when the
                processor is driven manually (``self.data`` assigned
                directly) — the original signature required a config, which
                broke every in-file caller that constructed the class with
                no argument.
        """
        self.config = config
        self.data = None          # loaded spectra: 2D (samples, bands) or 3D (rows, cols, bands)
        self.wavelengths = None   # 1D array of band wavelengths, or None
        self.data_shape = None    # shape of self.data after loading
        self.input_format = None  # 'csv' or 'envi'

    def load_data(self, file_path, spectral_start_index=None):
        """
        Load a hyperspectral data file.

        Parameters:
            file_path: path of the input file
            spectral_start_index: for CSV files, 0-based index of the first
                spectral column; falls back to the config value when omitted

        Returns:
            The loaded data array (also stored on ``self.data``).

        Raises:
            ValueError: unsupported file format, or missing start index for CSV.
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()
        if suffix == '.csv':
            # Fall back to the configured start index when none is given
            if spectral_start_index is None and self.config is not None:
                spectral_start_index = self.config.spectral_start_index
            if spectral_start_index is None:
                raise ValueError("对于CSV文件必须指定spectral_start_index参数")
            self.input_format = 'csv'
            return self._load_csv_data(file_path, spectral_start_index)
        else:
            # ENVI format requires the optional `spectral` library
            if SPECTRAL_AVAILABLE and suffix in ['.bil', '.bsq', '.bip', '.dat', '.hdr']:
                self.input_format = 'envi'
                return self._load_envi_data(file_path)
            else:
                raise ValueError(f"不支持的文件格式: {suffix}。请使用CSV或ENVI格式文件并确保已安装spectral库。")

    def _validate_csv_data(self, df, spectral_start_index):
        """Validate a CSV DataFrame against the config's constraints.

        Hard problems are collected and raised as a single ValueError;
        soft problems are printed as warnings.
        """
        errors = []
        warnings = []
        # Basic size checks
        if len(df) < self.config.min_rows:
            errors.append(f"CSV文件行数 ({len(df)}) 少于最小要求 ({self.config.min_rows})")
        if len(df.columns) < self.config.min_cols:
            errors.append(f"CSV文件列数 ({len(df.columns)}) 少于最小要求 ({self.config.min_cols})")
        # The spectral start index must address an existing column
        if spectral_start_index < 0 or spectral_start_index >= len(df.columns):
            errors.append(f"光谱起始列索引 {spectral_start_index} 超出范围 [0, {len(df.columns)-1}]")
        # Missing values
        if self.config.check_missing_values:
            missing_count = df.isnull().sum().sum()
            if missing_count > 0:
                warnings.append(f"发现 {missing_count} 个缺失值")
        # Data-type consistency of the spectral columns
        if self.config.check_data_types:
            spectral_cols = df.columns[spectral_start_index:]
            for col in spectral_cols:
                try:
                    # NOTE(review): errors='coerce' never raises, so this
                    # guard is effectively a no-op — kept for behavior parity.
                    pd.to_numeric(df[col], errors='coerce')
                except Exception as e:
                    warnings.append(f"'{col}' 包含非数值数据: {e}")
        # Wavelength column checks
        if self.config.wavelength_column:
            if self.config.wavelength_column not in df.columns:
                errors.append(f"指定的波长列 '{self.config.wavelength_column}' 不存在")
            else:
                wavelength_data = df[self.config.wavelength_column]
                try:
                    numeric_wavelengths = pd.to_numeric(wavelength_data, errors='coerce')
                    if numeric_wavelengths.isnull().any():
                        warnings.append(f"波长列 '{self.config.wavelength_column}' 包含非数值数据")
                    else:
                        # Wavelengths should be monotonic in either direction
                        if not (numeric_wavelengths.dropna().is_monotonic_increasing or
                                numeric_wavelengths.dropna().is_monotonic_decreasing):
                            warnings.append(f"波长列 '{self.config.wavelength_column}' 不是单调的")
                except Exception as e:
                    warnings.append(f"波长列 '{self.config.wavelength_column}' 验证失败: {e}")
        # Report
        if errors:
            error_msg = "CSV验证失败:\n" + "\n".join(f" - {error}" for error in errors)
            raise ValueError(error_msg)
        if warnings:
            print("CSV验证警告:")
            for warning in warnings:
                print(f" - {warning}")
        return True

    def _load_csv_data(self, file_path, spectral_start_index):
        """Load hyperspectral data from a CSV file into ``self.data``."""
        df = pd.read_csv(file_path)
        # Optional validation pass (guarded: config may be None when the
        # processor is used directly with in-memory data)
        if self.config is not None and self.config.validate_csv:
            print("正在验证CSV数据...")
            self._validate_csv_data(df, spectral_start_index)
        # Validate the column index
        if spectral_start_index < 0 or spectral_start_index >= len(df.columns):
            raise ValueError(f"列索引 {spectral_start_index} 超出范围 [0, {len(df.columns)-1}]")
        # Extract the spectral block
        spectral_data = df.iloc[:, spectral_start_index:].values
        # Column names are interpreted as wavelengths when numeric
        try:
            self.wavelengths = pd.to_numeric(df.columns[spectral_start_index:], errors='coerce').values
            if np.isnan(self.wavelengths).any():
                print("警告: 部分列名无法转换为波长值,将使用列索引作为波长")
                self.wavelengths = np.arange(len(self.wavelengths), dtype=float)
        except Exception as e:
            print(f"波长解析失败: {e},将使用列索引作为波长")
            self.wavelengths = np.arange(spectral_data.shape[1], dtype=float)
        # Store results
        self.data = spectral_data
        self.data_shape = self.data.shape
        print(f"CSV数据加载成功: {self.data_shape}")
        print(f"光谱起始列索引: {spectral_start_index}")
        print(f"波段数量: {len(self.wavelengths)}")
        if len(self.wavelengths) > 0:
            print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
        return self.data

    def _load_envi_data(self, file_path):
        """Load hyperspectral data in ENVI format via the `spectral` library."""
        original_path = Path(file_path)
        current_path = original_path
        # If a .hdr file was given, locate the matching data file (used for logging)
        if current_path.suffix.lower() == '.hdr':
            data_file = current_path.with_suffix('')
            if not data_file.exists():
                for ext in ['.dat', '.bil', '.bsq', '.bip']:
                    candidate = current_path.with_suffix(ext)
                    if candidate.exists():
                        data_file = candidate
                        break
            current_path = data_file
        print(f"使用spectral库加载: {current_path}")
        # NOTE(review): the image is opened from the path as given, not the
        # resolved data file; spectral.open_image expects the header path —
        # confirm callers pass the .hdr file.
        img = spectral.open_image(str(file_path))
        # Load all bands into memory
        self.data = img.load()
        # Parse wavelength metadata when present
        if hasattr(img, 'metadata') and 'wavelength' in img.metadata:
            wavelength_data = img.metadata['wavelength']
            try:
                if isinstance(wavelength_data, str):
                    # Strings may be brace-wrapped and comma- or space-separated
                    wavelength_data = wavelength_data.strip('{}[]')
                    if ',' in wavelength_data:
                        self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split(',') if w.strip()])
                    else:
                        self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split() if w.strip()])
                elif isinstance(wavelength_data, list):
                    self.wavelengths = np.array([float(w) for w in wavelength_data])
                else:
                    self.wavelengths = np.array(wavelength_data, dtype=float)
                print(f"spectral解析波长: {len(self.wavelengths)} 个波段")
                print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
            except Exception as e:
                print(f"波长解析失败,使用默认值: {e}")
                self.wavelengths = np.arange(self.data.shape[2])
        else:
            self.wavelengths = np.arange(self.data.shape[2])
            print(f"警告: spectral未找到波长信息使用默认值: 0-{self.data.shape[2]-1}")
        self.data_shape = self.data.shape
        print(f"spectral数据加载成功: {self.data_shape}")
        print(f"数据类型: {self.data.dtype}")
        print(f"值范围: [{self.data.min():.3f}, {self.data.max():.3f}]")
        return self.data

    def save_data(self, output_path, data=None, wavelengths=None):
        """
        Save preprocessed data.

        Parameters:
            output_path: output file path (suffix selects the format)
            data: data to save (defaults to self.data)
            wavelengths: wavelength info (defaults to self.wavelengths)

        Raises:
            ValueError: unsupported output format.
        """
        if data is None:
            data = self.data
        if wavelengths is None:
            wavelengths = self.wavelengths
        output_path = Path(output_path)
        suffix = output_path.suffix.lower()
        if suffix == '.csv':
            self._save_csv_data(output_path, data, wavelengths)
        else:
            # ENVI output either when the input was ENVI or the suffix says so
            if self.input_format == 'envi' or suffix in ['.bil', '.bsq', '.bip', '.dat']:
                self._save_envi_data(output_path, data, wavelengths)
            else:
                raise ValueError(f"不支持的输出格式: {suffix}")

    def _save_csv_data(self, output_path, data, wavelengths):
        """Save data as CSV, using wavelengths as the column names when known."""
        if wavelengths is not None:
            df = pd.DataFrame(data, columns=[f"{w:.1f}" for w in wavelengths])
        else:
            df = pd.DataFrame(data)
        df.to_csv(output_path, index=False)
        print(f"已保存CSV文件: {output_path}")

    def _save_envi_data(self, output_path, data, wavelengths):
        """Save data as an ENVI .dat binary plus a .hdr header file."""
        data_file = output_path.with_suffix('.dat')
        hdr_file = output_path.with_suffix('.hdr')
        # Raw binary payload, float32 (matches `data type = 4` in the header)
        data.astype('float32').tofile(str(data_file))
        # Derive header dimensions from the array rank
        if len(data.shape) == 1:
            lines = 1
            samples = data.shape[0]
            bands = 1
        elif len(data.shape) == 2:
            lines = data.shape[0]
            samples = data.shape[1]
            bands = 1
        else:
            lines = data.shape[0]
            samples = data.shape[1]
            bands = data.shape[2]
        header_content = f"""ENVI
samples = {samples}
lines = {lines}
bands = {bands}
header offset = 0
file type = ENVI Standard
data type = 4
interleave = bip
byte order = 0
wavelength units = nm
"""
        # Only write wavelengths when they match the band count
        if wavelengths is not None and len(wavelengths) == bands:
            wavelength_str = ', '.join([f"{w:.1f}" for w in wavelengths])
            header_content += f"wavelength = {{{wavelength_str}}}\n"
        with open(hdr_file, 'w', encoding='utf-8') as f:
            f.write(header_content)
        print(f"已保存ENVI文件: {data_file}{hdr_file}")

    def preprocess(self, method, output_path, **kwargs):
        """
        Run a preprocessing method and save the result.

        Parameters:
            method: method name ('MMS', 'SS', 'CT', 'SNV', 'MA', 'SG',
                'D1', 'D2', 'DT', 'MSC', 'wave')
            output_path: output file path
            **kwargs: forwarded to the preprocessing method

        Returns:
            The processed data array.

        Raises:
            ValueError: no data loaded, or unknown method.
        """
        if self.data is None:
            raise ValueError("请先加载数据")
        processed_data = self._apply_preprocessing(method, **kwargs)
        # self.wavelengths may have been trimmed by D1/D2
        self.save_data(output_path, processed_data, self.wavelengths)
        return processed_data

    def _apply_preprocessing(self, method, **kwargs):
        """Dispatch a preprocessing method by name."""
        method_funcs = {
            'MMS': self._MMS,
            'SS': self._SS,
            'CT': self._CT,
            'SNV': self._SNV,
            'MA': self._MA,
            'SG': self._SG,
            'D1': self._D1,
            'D2': self._D2,
            'DT': self._DT,
            'MSC': self._MSC,
            'wave': self._wave
        }
        if method not in method_funcs:
            raise ValueError(f"不支持的预处理方法: {method}")
        return method_funcs[method](**kwargs)

    # ----- preprocessing method implementations -----

    def _MMS(self, **kwargs):
        """Min-max normalization per band (sklearn MinMaxScaler)."""
        print("执行最大最小值归一化...")
        scaler = MinMaxScaler()
        if len(self.data.shape) == 2:
            # CSV layout: (samples, bands)
            return scaler.fit_transform(self.data)
        else:
            # Image layout: (rows, cols, bands) — flatten pixels, scale, restore
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            normalized = scaler.fit_transform(reshaped)
            return normalized.reshape(original_shape)

    def _SS(self, save_path=None, **kwargs):
        """Standardization per band (sklearn StandardScaler).

        Parameters:
            save_path: optional path; when given, the fitted scaler is
                persisted with joblib for later reuse.
        """
        print("执行标准化...")
        scaler = StandardScaler()
        if len(self.data.shape) == 2:
            result = scaler.fit_transform(self.data)
        else:
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            result = scaler.fit_transform(reshaped)
            result = result.reshape(original_shape)
        if save_path:
            joblib.dump(scaler, save_path)
            print(f"Scaler参数已保存到: {save_path}")
        return result

    def _CT(self, **kwargs):
        """Mean centering of each spectrum (subtract its own mean)."""
        print("执行均值中心化...")
        if len(self.data.shape) == 2:
            # 2D: center each sample over its bands
            mean_vals = np.mean(self.data, axis=1, keepdims=True)
            return self.data - mean_vals
        else:
            # 3D: center each pixel's spectrum
            mean_vals = np.mean(self.data, axis=2, keepdims=True)
            return self.data - mean_vals

    def _SNV(self, **kwargs):
        """Standard normal variate transform (per-spectrum z-score)."""
        print("执行标准正态变换...")
        if len(self.data.shape) != 2:
            raise ValueError("SNV方法只支持2D数据")
        data_average = np.mean(self.data, axis=1, keepdims=True)
        data_std = np.std(self.data, axis=1, keepdims=True)
        # Avoid division by zero for constant spectra
        data_std = np.where(data_std == 0, 1, data_std)
        return (self.data - data_average) / data_std

    def _MA(self, WSZ=11, **kwargs):
        """Moving-average smoothing with shrinking-window edge handling.

        Parameters:
            WSZ: window size (odd values expected by the edge formula).
        """
        print(f"执行移动平均平滑 (窗口大小: {WSZ})...")

        def smooth(spectrum):
            # Central region: plain moving average; the edges use
            # progressively smaller symmetric windows so the output keeps
            # the input length.
            out0 = np.convolve(spectrum, np.ones(WSZ, dtype=int), 'valid') / WSZ
            r = np.arange(1, WSZ - 1, 2)
            start = np.cumsum(spectrum[:WSZ - 1])[::2] / r
            stop = (np.cumsum(spectrum[:-WSZ:-1])[::2] / r)[::-1]
            return np.concatenate((start, out0, stop))

        # NOTE(review): for integer-typed input the in-place assignment
        # truncates the smoothed floats — presumably data is float; confirm.
        output_data = deepcopy(self.data)
        if len(self.data.shape) == 2:
            for i in range(output_data.shape[0]):
                output_data[i] = smooth(output_data[i])
        else:
            for i in range(output_data.shape[0]):
                for j in range(output_data.shape[1]):
                    output_data[i, j, :] = smooth(output_data[i, j, :])
        return output_data

    def _SG(self, w=15, p=2, **kwargs):
        """Savitzky-Golay smoothing.

        Parameters:
            w: window length; p: polynomial order.
        """
        print(f"执行Savitzky-Golay平滑 (窗口: {w}, 阶数: {p})...")
        if len(self.data.shape) == 2:
            return signal.savgol_filter(self.data, w, p)
        else:
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            filtered = signal.savgol_filter(reshaped, w, p)
            return filtered.reshape(original_shape)

    def _D1(self, **kwargs):
        """First derivative (finite difference along the spectral axis).

        Side effect: trims ``self.wavelengths`` by one band.
        """
        print("执行一阶导数...")
        if len(self.data.shape) == 2:
            # np.diff replaces the original per-row loop; same result
            output_data = np.diff(self.data, axis=1)
        else:
            output_data = np.diff(self.data, axis=2)
        # One band is lost by differencing
        if self.wavelengths is not None and len(self.wavelengths) > 1:
            self.wavelengths = self.wavelengths[:-1]
        return output_data

    def _D2(self, **kwargs):
        """Second derivative (two successive finite differences).

        Side effect: trims ``self.wavelengths`` by two bands.
        """
        print("执行二阶导数...")
        if len(self.data.shape) == 2:
            output_data = np.diff(np.diff(self.data, axis=1), axis=1)
        elif len(self.data.shape) == 3:
            output_data = np.diff(np.diff(self.data, axis=2), axis=2)
        else:
            raise ValueError("不支持的数据维度")
        # Two bands are lost by double differencing
        if self.wavelengths is not None and len(self.wavelengths) > 2:
            self.wavelengths = self.wavelengths[:-2]
        return output_data

    def _DT(self, **kwargs):
        """Detrending: subtract a least-squares linear fit from each spectrum."""
        print("执行趋势校正...")
        output_data = np.array(self.data)
        if len(self.data.shape) == 2:
            length = output_data.shape[1]
            x = np.asarray(range(length), dtype=np.float32)
            model = LinearRegression()
            for i in range(output_data.shape[0]):
                model.fit(x.reshape(-1, 1), output_data[i].reshape(-1, 1))
                slope = model.coef_.item()
                intercept = model.intercept_.item()
                # Remove the fitted trend in one vectorized step
                output_data[i] = output_data[i] - (x * slope + intercept)
        else:
            length = output_data.shape[2]
            x = np.asarray(range(length), dtype=np.float32)
            model = LinearRegression()
            for i in range(output_data.shape[0]):
                for j in range(output_data.shape[1]):
                    spectrum = output_data[i, j, :]
                    model.fit(x.reshape(-1, 1), spectrum.reshape(-1, 1))
                    slope = model.coef_.item()
                    intercept = model.intercept_.item()
                    output_data[i, j, :] = spectrum - (x * slope + intercept)
        return output_data

    def _MSC(self, **kwargs):
        """Multiplicative scatter correction against the mean spectrum."""
        print("执行多元散射校正...")
        if len(self.data.shape) == 2:
            n, p = self.data.shape
            output_data = np.ones((n, p))
            # Reference spectrum: mean over all samples
            mean = np.mean(self.data, axis=0)
            for i in range(n):
                y = self.data[i, :]
                model = LinearRegression()
                model.fit(mean.reshape(-1, 1), y.reshape(-1, 1))
                k = model.coef_
                b = model.intercept_
                # Guard against a zero slope (added for parity with the 3D branch)
                k = np.where(k == 0, 1e-10, k)
                output_data[i, :] = (y - b) / k
        elif len(self.data.shape) == 3:
            rows, cols, bands = self.data.shape
            output_data = np.zeros((rows, cols, bands), dtype=np.float32)
            # Reference spectrum: mean over all pixels
            mean_spectrum = np.mean(self.data, axis=(0, 1))
            for i in range(rows):
                for j in range(cols):
                    y = self.data[i, j, :]
                    model = LinearRegression()
                    model.fit(mean_spectrum.reshape(-1, 1), y.reshape(-1, 1))
                    k = model.coef_
                    b = model.intercept_
                    k = np.where(k == 0, 1e-10, k)
                    output_data[i, j, :] = (y - b) / k
        else:
            raise ValueError("不支持的数据维度")
        return output_data

    def _wave(self, **kwargs):
        """Wavelet (db8) denoising of each spectrum."""
        print("执行小波变换...")

        def wave_single(spectrum):
            # Decompose to the maximum level, threshold the detail
            # coefficients, then reconstruct.
            w = pywt.Wavelet('db8')
            maxlev = pywt.dwt_max_level(len(spectrum), w.dec_len)
            coeffs = pywt.wavedec(spectrum, 'db8', level=maxlev)
            threshold = 0.04
            for i in range(1, len(coeffs)):
                coeffs[i] = pywt.threshold(coeffs[i], threshold * max(coeffs[i]))
            return pywt.waverec(coeffs, 'db8')

        if len(self.data.shape) == 2:
            # Build all rows first, then stack once (the original appended
            # with np.vstack inside the loop, which is quadratic, and
            # returned a 1D array for single-sample input).
            output_data = np.vstack([wave_single(self.data[i])
                                     for i in range(self.data.shape[0])])
        else:
            original_shape = self.data.shape
            reshaped = self.data.reshape(-1, original_shape[2])
            output_data = np.vstack([wave_single(reshaped[i])
                                     for i in range(reshaped.shape[0])])
            # NOTE(review): waverec can return len+1 samples for odd-length
            # spectra, in which case this reshape fails (as it would have in
            # the original code).
            output_data = output_data.reshape(original_shape)
        return output_data
# ===== 便捷函数 =====
def preprocess_file(input_file, output_file, method, spectral_start_index=None, **kwargs):
    """
    Convenience function: preprocess a single file.

    Parameters:
        input_file: input file path
        output_file: output file path
        method: preprocessing method name
        spectral_start_index: 0-based index of the first spectral column (CSV files)
        **kwargs: extra arguments forwarded to the preprocessing method

    Returns:
        The processed data array.
    """
    # Build a minimal config: the original code called
    # HyperspectralPreprocessor() with no argument, which raised a
    # TypeError because the constructor requires a config.
    config = PreprocessingConfig(input_path=str(input_file), method=method)
    processor = HyperspectralPreprocessor(config)
    processor.load_data(input_file, spectral_start_index)
    return processor.preprocess(method, output_file, **kwargs)
# 保持向后兼容的函数
def MMS(input_spectrum):
    """Min-max normalization (backward-compatible wrapper).

    Fixes two bugs in the original: the processor was constructed without
    its required config, and the spectrum was passed positionally into
    ``_MMS(**kwargs)`` instead of being assigned to ``processor.data``.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path='<memory>', method='MMS'))
    processor.data = input_spectrum
    return processor._MMS()
def SS(input_spectrum, save_path=None):
    """Standardization (backward-compatible wrapper).

    The original constructed HyperspectralPreprocessor() without the
    required config argument, raising a TypeError; a minimal config is
    supplied instead.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path='<memory>', method='SS'))
    processor.data = input_spectrum
    return processor._SS(save_path=save_path)
def CT(input_spectrum):
    """Mean centering (backward-compatible wrapper).

    Supplies the config the original constructor call was missing.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path='<memory>', method='CT'))
    processor.data = input_spectrum
    return processor._CT()
def SNV(input_spectrum):
    """Standard normal variate transform (backward-compatible wrapper).

    Supplies the config the original constructor call was missing.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path='<memory>', method='SNV'))
    processor.data = input_spectrum
    return processor._SNV()
def MA(input_spectrum, WSZ=11):
    """Moving-average smoothing (backward-compatible wrapper).

    Supplies the config the original constructor call was missing.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path='<memory>', method='MA'))
    processor.data = input_spectrum
    return processor._MA(WSZ=WSZ)
def SG(input_spectrum, w=15, p=2):
    """Savitzky-Golay smoothing (backward-compatible wrapper).

    Supplies the config the original constructor call was missing.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path='<memory>', method='SG'))
    processor.data = input_spectrum
    return processor._SG(w=w, p=p)
def D1(input_spectrum):
    """First derivative (backward-compatible wrapper).

    Supplies the config the original constructor call was missing; the
    constructor also initializes ``wavelengths``, which ``_D1`` reads.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path='<memory>', method='D1'))
    processor.data = input_spectrum
    return processor._D1()
def D2(input_spectrum):
    """Second derivative (backward-compatible wrapper).

    Supplies the config the original constructor call was missing; the
    constructor also initializes ``wavelengths``, which ``_D2`` reads.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path='<memory>', method='D2'))
    processor.data = input_spectrum
    return processor._D2()
def DT(input_spectrum):
    """Detrending (backward-compatible wrapper).

    Supplies the config the original constructor call was missing.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path='<memory>', method='DT'))
    processor.data = input_spectrum
    return processor._DT()
def MSC(input_spectrum):
    """Multiplicative scatter correction (backward-compatible wrapper).

    Supplies the config the original constructor call was missing.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path='<memory>', method='MSC'))
    processor.data = input_spectrum
    return processor._MSC()
def wave(input_spectrum):
    """Wavelet denoising (backward-compatible wrapper).

    Supplies the config the original constructor call was missing.
    """
    processor = HyperspectralPreprocessor(PreprocessingConfig(input_path='<memory>', method='wave'))
    processor.data = input_spectrum
    return processor._wave()
def main():
    """Command-line entry point: parse arguments, load data, run each method.

    Returns:
        0 on success, 1 on failure (for use as a process exit code).
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='高光谱数据预处理工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
1. 处理ENVI格式图像:
python Preprocessing.py input.hdr -m SNV -o output_snv.dat
2. 处理CSV格式数据:
python Preprocessing.py input.csv -c 1 -m MSC -o output_msc.csv
3. 批量处理多种方法:
python Preprocessing.py input.csv -c wavelength_400 -m SNV MSC MA -o results/
支持的预处理方法:
MMS: 最大最小值归一化
SS: 标准差标准化
CT: 中心化
SNV: 标准正态变量变换
MA: 移动平均
SG: Savitzky-Golay滤波
D1: 一阶导数
D2: 二阶导数
DT: 离散变换
MSC: 多重散射校正
wave: 小波变换
"""
    )
    parser.add_argument('input_file', help='输入文件路径 (.hdr高光谱图像 或 .csv文件)')
    parser.add_argument('-c', '--spectral_col', type=int, default=1,
                        help='CSV文件的谱段起始列索引 (从0开始默认: 1)')
    parser.add_argument('-m', '--methods', nargs='+', required=True,
                        choices=['MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave'],
                        help='预处理方法 (可指定多个)')
    parser.add_argument('-o', '--output', required=True,
                        help='输出文件路径或目录')
    # Method-specific parameters
    parser.add_argument('--ma_window', type=int, default=11,
                        help='移动平均窗口大小 (默认: 11)')
    parser.add_argument('--sg_window', type=int, default=15,
                        help='Savitzky-Golay窗口大小 (默认: 15)')
    parser.add_argument('--sg_poly', type=int, default=2,
                        help='Savitzky-Golay多项式阶数 (默认: 2)')
    args = parser.parse_args()
    try:
        print("=" * 60)
        print("高光谱数据预处理工具")
        print("=" * 60)
        print(f"输入文件: {args.input_file}")
        print(f"预处理方法: {', '.join(args.methods)}")
        print(f"输出路径: {args.output}")
        print()
        # Build a config so the processor constructor gets its required
        # argument (the original code called HyperspectralPreprocessor()
        # with no config, which raised a TypeError).
        config = PreprocessingConfig(
            input_path=args.input_file,
            spectral_start_index=args.spectral_col,
            ma_window=args.ma_window,
            sg_window=args.sg_window,
            sg_poly=args.sg_poly,
        )
        processor = HyperspectralPreprocessor(config)
        # Load the input once; all methods run on the same loaded data
        print("正在加载数据...")
        processor.load_data(args.input_file, spectral_start_index=args.spectral_col)
        for method in args.methods:
            print(f"\n正在执行 {method} 预处理...")
            # Build the output path for this method
            if len(args.methods) == 1:
                output_path = args.output
            else:
                # Multiple methods: derive a distinct name per method
                # (os is imported at module level)
                if os.path.isdir(args.output):
                    base_name = os.path.splitext(os.path.basename(args.input_file))[0]
                    output_path = os.path.join(args.output, f"{base_name}_{method.lower()}.csv")
                else:
                    # A file path was given: suffix it per method
                    base, ext = os.path.splitext(args.output)
                    output_path = f"{base}_{method.lower()}{ext}"
            # Method-specific keyword arguments
            kwargs = {}
            if method == 'MA':
                kwargs['WSZ'] = args.ma_window
            elif method == 'SG':
                kwargs['w'] = args.sg_window
                kwargs['p'] = args.sg_poly
            processor.preprocess(method, output_path, **kwargs)
            print(f"{method} 预处理完成: {output_path}")
        print("\n" + "=" * 60)
        print("所有预处理任务完成!")
        print("=" * 60)
    except Exception as e:
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # Use SystemExit rather than the site-provided exit() helper, which is
    # intended for interactive use and not guaranteed to be available.
    raise SystemExit(main())