增加模块;增加主调用命令
This commit is contained in:
820
preprocessing_method/Preprocessing.py
Normal file
820
preprocessing_method/Preprocessing.py
Normal file
@ -0,0 +1,820 @@
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from sklearn.linear_model import LinearRegression
|
||||
from sklearn.preprocessing import MinMaxScaler, StandardScaler
|
||||
import pandas as pd
|
||||
import pywt
|
||||
from copy import deepcopy
|
||||
import joblib # 用于保存和加载模型
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@dataclass
class PreprocessingConfig:
    """Configuration for the preprocessing pipeline.

    Bundles the input/output locations, the selected preprocessing method,
    per-method parameters, and the optional CSV validation switches.
    """
    input_path: Optional[str] = None        # path of the file to preprocess (required)
    method: str = 'SS'                      # preprocessing method key
    spectral_start_index: int = 1           # first spectral column in a CSV (0-based)
    handle_outliers: bool = False           # whether outlier handling is requested
    outlier_method: str = 'iqr'             # outlier detection strategy
    output_dir: Optional[str] = None        # destination directory for results

    # CSV validation switches
    validate_csv: bool = False
    min_rows: int = 1
    min_cols: int = 2
    check_missing_values: bool = False
    check_data_types: bool = False
    wavelength_column: Optional[str] = None

    # moving-average (MA) parameter
    ma_window: int = 11
    # Savitzky-Golay (SG) parameters
    sg_window: int = 15
    sg_poly: int = 2

    def __post_init__(self):
        """Validate fields and fill in defaults; raises ValueError on bad input."""
        if not self.input_path:
            raise ValueError("必须指定输入文件路径(input_path)")

        allowed_methods = ['MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave']
        if self.method not in allowed_methods:
            raise ValueError(f"不支持的预处理方法: {self.method}。支持的方法: {allowed_methods}")

        allowed_outlier_methods = ['iqr', 'isolation-forest', 'lof']
        if self.outlier_method not in allowed_outlier_methods:
            raise ValueError(f"不支持的异常值检测方法: {self.outlier_method}。支持的方法: {allowed_outlier_methods}")

        # default results directory when none is given
        if not self.output_dir:
            self.output_dir = './results'
|
||||
|
||||
|
||||
# Optional dependency: the `spectral` package enables reading ENVI images.
# When it is missing we fall back to the built-in readers and only warn.
try:
    import spectral
except ImportError:
    SPECTRAL_AVAILABLE = False
    print("警告: spectral库不可用,将使用内置方法读取数据")
else:
    SPECTRAL_AVAILABLE = True
|
||||
|
||||
|
||||
class HyperspectralPreprocessor:
    """
    Hyperspectral data preprocessor.

    Supported file formats:
    - CSV files: the 0-based index of the first spectral column must be given
    - ENVI format: .bil, .bsq, .bip, .dat + .hdr files (requires the spectral library)
    """

    def __init__(self, config: "Optional[PreprocessingConfig]" = None):
        """Create a preprocessor.

        Bug fix: `config` used to be a required parameter, but every helper in
        this module (preprocess_file, the backward-compatible wrappers, main)
        constructs this class with no arguments, which raised TypeError.
        Making it optional keeps existing config-passing callers working.
        """
        self.config = config        # optional PreprocessingConfig
        self.data = None            # loaded spectra: 2-D (samples, bands) or 3-D cube
        self.wavelengths = None     # per-band wavelengths once data is loaded
        self.data_shape = None      # shape of self.data after loading
        self.input_format = None    # 'csv' or 'envi'
|
||||
|
||||
def load_data(self, file_path, spectral_start_index=None):
|
||||
"""
|
||||
加载高光谱数据文件
|
||||
|
||||
参数:
|
||||
file_path: 文件路径
|
||||
spectral_start_index: 对于CSV文件,光谱数据起始列索引(从0开始)
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
if suffix == '.csv':
|
||||
if spectral_start_index is None:
|
||||
raise ValueError("对于CSV文件,必须指定spectral_start_index参数")
|
||||
self.input_format = 'csv'
|
||||
return self._load_csv_data(file_path, spectral_start_index)
|
||||
else:
|
||||
# ENVI格式
|
||||
if SPECTRAL_AVAILABLE and suffix in ['.bil', '.bsq', '.bip', '.dat', '.hdr']:
|
||||
self.input_format = 'envi'
|
||||
return self._load_envi_data(file_path)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件格式: {suffix}。请使用CSV或ENVI格式文件,并确保已安装spectral库。")
|
||||
|
||||
def _validate_csv_data(self, df, spectral_start_index):
|
||||
"""验证CSV数据"""
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
# 基本尺寸验证
|
||||
if len(df) < self.config.min_rows:
|
||||
errors.append(f"CSV文件行数 ({len(df)}) 少于最小要求 ({self.config.min_rows})")
|
||||
|
||||
if len(df.columns) < self.config.min_cols:
|
||||
errors.append(f"CSV文件列数 ({len(df.columns)}) 少于最小要求 ({self.config.min_cols})")
|
||||
|
||||
# 验证光谱起始列索引
|
||||
if spectral_start_index < 0 or spectral_start_index >= len(df.columns):
|
||||
errors.append(f"光谱起始列索引 {spectral_start_index} 超出范围 [0, {len(df.columns)-1}]")
|
||||
|
||||
# 检查缺失值
|
||||
if self.config.check_missing_values:
|
||||
missing_count = df.isnull().sum().sum()
|
||||
if missing_count > 0:
|
||||
warnings.append(f"发现 {missing_count} 个缺失值")
|
||||
|
||||
# 检查数据类型一致性
|
||||
if self.config.check_data_types:
|
||||
spectral_cols = df.columns[spectral_start_index:]
|
||||
for col in spectral_cols:
|
||||
try:
|
||||
pd.to_numeric(df[col], errors='coerce')
|
||||
except Exception as e:
|
||||
warnings.append(f"列 '{col}' 包含非数值数据: {e}")
|
||||
|
||||
# 验证波长列
|
||||
if self.config.wavelength_column:
|
||||
if self.config.wavelength_column not in df.columns:
|
||||
errors.append(f"指定的波长列 '{self.config.wavelength_column}' 不存在")
|
||||
else:
|
||||
wavelength_data = df[self.config.wavelength_column]
|
||||
try:
|
||||
numeric_wavelengths = pd.to_numeric(wavelength_data, errors='coerce')
|
||||
if numeric_wavelengths.isnull().any():
|
||||
warnings.append(f"波长列 '{self.config.wavelength_column}' 包含非数值数据")
|
||||
else:
|
||||
# 检查波长是否单调递增
|
||||
if not (numeric_wavelengths.dropna().is_monotonic_increasing or
|
||||
numeric_wavelengths.dropna().is_monotonic_decreasing):
|
||||
warnings.append(f"波长列 '{self.config.wavelength_column}' 不是单调的")
|
||||
except Exception as e:
|
||||
warnings.append(f"波长列 '{self.config.wavelength_column}' 验证失败: {e}")
|
||||
|
||||
# 输出验证结果
|
||||
if errors:
|
||||
error_msg = "CSV验证失败:\n" + "\n".join(f" - {error}" for error in errors)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if warnings:
|
||||
print("CSV验证警告:")
|
||||
for warning in warnings:
|
||||
print(f" - {warning}")
|
||||
|
||||
return True
|
||||
|
||||
def _load_csv_data(self, file_path, spectral_start_index):
|
||||
"""加载CSV格式的高光谱数据"""
|
||||
df = pd.read_csv(file_path)
|
||||
|
||||
# CSV数据验证
|
||||
if self.config.validate_csv:
|
||||
print("正在验证CSV数据...")
|
||||
self._validate_csv_data(df, spectral_start_index)
|
||||
|
||||
# 验证列索引有效性
|
||||
if spectral_start_index < 0 or spectral_start_index >= len(df.columns):
|
||||
raise ValueError(f"列索引 {spectral_start_index} 超出范围 [0, {len(df.columns)-1}]")
|
||||
|
||||
# 提取光谱数据
|
||||
spectral_data = df.iloc[:, spectral_start_index:].values
|
||||
|
||||
# 提取波长信息(列名转换为数值)
|
||||
try:
|
||||
# 尝试将列名转换为数值作为波长
|
||||
self.wavelengths = pd.to_numeric(df.columns[spectral_start_index:], errors='coerce').values
|
||||
# 检查是否有无效的波长值
|
||||
if np.isnan(self.wavelengths).any():
|
||||
print("警告: 部分列名无法转换为波长值,将使用列索引作为波长")
|
||||
self.wavelengths = np.arange(len(self.wavelengths), dtype=float)
|
||||
except Exception as e:
|
||||
print(f"波长解析失败: {e},将使用列索引作为波长")
|
||||
self.wavelengths = np.arange(spectral_data.shape[1], dtype=float)
|
||||
|
||||
# 保存数据
|
||||
self.data = spectral_data
|
||||
self.data_shape = self.data.shape
|
||||
|
||||
print(f"CSV数据加载成功: {self.data_shape}")
|
||||
print(f"光谱起始列索引: {spectral_start_index}")
|
||||
print(f"波段数量: {len(self.wavelengths)}")
|
||||
if len(self.wavelengths) > 0:
|
||||
print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
|
||||
|
||||
return self.data
|
||||
|
||||
def _load_envi_data(self, file_path):
|
||||
"""加载ENVI格式的高光谱数据"""
|
||||
original_path = Path(file_path)
|
||||
current_path = original_path
|
||||
|
||||
# 如果是.hdr文件,找到对应的数据文件
|
||||
if current_path.suffix.lower() == '.hdr':
|
||||
data_file = current_path.with_suffix('')
|
||||
if not data_file.exists():
|
||||
for ext in ['.dat', '.bil', '.bsq', '.bip']:
|
||||
candidate = current_path.with_suffix(ext)
|
||||
if candidate.exists():
|
||||
data_file = candidate
|
||||
break
|
||||
current_path = data_file
|
||||
|
||||
print(f"使用spectral库加载: {current_path}")
|
||||
|
||||
# 使用spectral打开图像
|
||||
img = spectral.open_image(str(file_path))
|
||||
|
||||
# 读取所有波段的数据
|
||||
self.data = img.load()
|
||||
|
||||
# 获取波长信息
|
||||
if hasattr(img, 'metadata') and 'wavelength' in img.metadata:
|
||||
wavelength_data = img.metadata['wavelength']
|
||||
try:
|
||||
if isinstance(wavelength_data, str):
|
||||
wavelength_data = wavelength_data.strip('{}[]')
|
||||
if ',' in wavelength_data:
|
||||
self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split(',') if w.strip()])
|
||||
else:
|
||||
self.wavelengths = np.array([float(w.strip()) for w in wavelength_data.split() if w.strip()])
|
||||
elif isinstance(wavelength_data, list):
|
||||
self.wavelengths = np.array([float(w) for w in wavelength_data])
|
||||
else:
|
||||
self.wavelengths = np.array(wavelength_data, dtype=float)
|
||||
|
||||
print(f"spectral解析波长: {len(self.wavelengths)} 个波段")
|
||||
print(f"波长范围: {self.wavelengths[0]:.1f} - {self.wavelengths[-1]:.1f} nm")
|
||||
except Exception as e:
|
||||
print(f"波长解析失败,使用默认值: {e}")
|
||||
self.wavelengths = np.arange(self.data.shape[2])
|
||||
else:
|
||||
self.wavelengths = np.arange(self.data.shape[2])
|
||||
print(f"警告: spectral未找到波长信息,使用默认值: 0-{self.data.shape[2]-1}")
|
||||
|
||||
self.data_shape = self.data.shape
|
||||
|
||||
print(f"spectral数据加载成功: {self.data_shape}")
|
||||
print(f"数据类型: {self.data.dtype}")
|
||||
print(f"值范围: [{self.data.min():.3f}, {self.data.max():.3f}]")
|
||||
|
||||
return self.data
|
||||
|
||||
def save_data(self, output_path, data=None, wavelengths=None):
|
||||
"""
|
||||
保存预处理后的数据
|
||||
|
||||
参数:
|
||||
output_path: 输出文件路径
|
||||
data: 要保存的数据 (如果为None,使用self.data)
|
||||
wavelengths: 波长信息 (如果为None,使用self.wavelengths)
|
||||
"""
|
||||
if data is None:
|
||||
data = self.data
|
||||
if wavelengths is None:
|
||||
wavelengths = self.wavelengths
|
||||
|
||||
output_path = Path(output_path)
|
||||
suffix = output_path.suffix.lower()
|
||||
|
||||
if suffix == '.csv':
|
||||
self._save_csv_data(output_path, data, wavelengths)
|
||||
else:
|
||||
# ENVI格式
|
||||
if self.input_format == 'envi' or suffix in ['.bil', '.bsq', '.bip', '.dat']:
|
||||
self._save_envi_data(output_path, data, wavelengths)
|
||||
else:
|
||||
raise ValueError(f"不支持的输出格式: {suffix}")
|
||||
|
||||
def _save_csv_data(self, output_path, data, wavelengths):
|
||||
"""保存为CSV格式"""
|
||||
if wavelengths is not None:
|
||||
# 创建DataFrame,列名使用波长
|
||||
df = pd.DataFrame(data, columns=[f"{w:.1f}" for w in wavelengths])
|
||||
else:
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
df.to_csv(output_path, index=False)
|
||||
print(f"已保存CSV文件: {output_path}")
|
||||
|
||||
def _save_envi_data(self, output_path, data, wavelengths):
|
||||
"""保存为ENVI格式"""
|
||||
data_file = output_path.with_suffix('.dat')
|
||||
hdr_file = output_path.with_suffix('.hdr')
|
||||
|
||||
# 保存二进制数据
|
||||
data.astype('float32').tofile(str(data_file))
|
||||
|
||||
# 创建ENVI头文件
|
||||
if len(data.shape) == 1:
|
||||
lines = 1
|
||||
samples = data.shape[0]
|
||||
bands = 1
|
||||
elif len(data.shape) == 2:
|
||||
lines = data.shape[0]
|
||||
samples = data.shape[1]
|
||||
bands = 1
|
||||
else:
|
||||
lines = data.shape[0]
|
||||
samples = data.shape[1]
|
||||
bands = data.shape[2]
|
||||
|
||||
header_content = f"""ENVI
|
||||
samples = {samples}
|
||||
lines = {lines}
|
||||
bands = {bands}
|
||||
header offset = 0
|
||||
file type = ENVI Standard
|
||||
data type = 4
|
||||
interleave = bip
|
||||
byte order = 0
|
||||
wavelength units = nm
|
||||
"""
|
||||
|
||||
if wavelengths is not None and len(wavelengths) == bands:
|
||||
wavelength_str = ', '.join([f"{w:.1f}" for w in wavelengths])
|
||||
header_content += f"wavelength = {{{wavelength_str}}}\n"
|
||||
|
||||
with open(hdr_file, 'w', encoding='utf-8') as f:
|
||||
f.write(header_content)
|
||||
|
||||
print(f"已保存ENVI文件: {data_file} 和 {hdr_file}")
|
||||
|
||||
def preprocess(self, method, output_path, **kwargs):
|
||||
"""
|
||||
执行预处理并保存结果
|
||||
|
||||
参数:
|
||||
method: 预处理方法名 ('MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave')
|
||||
output_path: 输出文件路径
|
||||
**kwargs: 传递给预处理方法的参数
|
||||
"""
|
||||
if self.data is None:
|
||||
raise ValueError("请先加载数据")
|
||||
|
||||
# 执行预处理
|
||||
processed_data = self._apply_preprocessing(method, **kwargs)
|
||||
|
||||
# 保存结果
|
||||
self.save_data(output_path, processed_data, self.wavelengths)
|
||||
|
||||
return processed_data
|
||||
|
||||
def _apply_preprocessing(self, method, **kwargs):
|
||||
"""应用预处理方法"""
|
||||
method_funcs = {
|
||||
'MMS': self._MMS,
|
||||
'SS': self._SS,
|
||||
'CT': self._CT,
|
||||
'SNV': self._SNV,
|
||||
'MA': self._MA,
|
||||
'SG': self._SG,
|
||||
'D1': self._D1,
|
||||
'D2': self._D2,
|
||||
'DT': self._DT,
|
||||
'MSC': self._MSC,
|
||||
'wave': self._wave
|
||||
}
|
||||
|
||||
if method not in method_funcs:
|
||||
raise ValueError(f"不支持的预处理方法: {method}")
|
||||
|
||||
return method_funcs[method](**kwargs)
|
||||
|
||||
# 预处理方法实现
|
||||
def _MMS(self, **kwargs):
|
||||
"""最大最小值归一化"""
|
||||
print("执行最大最小值归一化...")
|
||||
scaler = MinMaxScaler()
|
||||
if len(self.data.shape) == 2:
|
||||
# CSV格式: (samples, bands)
|
||||
return scaler.fit_transform(self.data)
|
||||
else:
|
||||
# 图像格式: (rows, cols, bands) -> 需要reshape
|
||||
original_shape = self.data.shape
|
||||
reshaped = self.data.reshape(-1, original_shape[2])
|
||||
normalized = scaler.fit_transform(reshaped)
|
||||
return normalized.reshape(original_shape)
|
||||
|
||||
def _SS(self, save_path=None, **kwargs):
|
||||
"""标准化"""
|
||||
print("执行标准化...")
|
||||
scaler = StandardScaler()
|
||||
if len(self.data.shape) == 2:
|
||||
result = scaler.fit_transform(self.data)
|
||||
else:
|
||||
original_shape = self.data.shape
|
||||
reshaped = self.data.reshape(-1, original_shape[2])
|
||||
result = scaler.fit_transform(reshaped)
|
||||
result = result.reshape(original_shape)
|
||||
|
||||
if save_path:
|
||||
joblib.dump(scaler, save_path)
|
||||
print(f"Scaler参数已保存到: {save_path}")
|
||||
|
||||
return result
|
||||
|
||||
def _CT(self, **kwargs):
|
||||
"""均值中心化"""
|
||||
print("执行均值中心化...")
|
||||
if len(self.data.shape) == 2:
|
||||
# 2D数据: (samples, bands) - 按行计算均值并中心化
|
||||
mean_vals = np.mean(self.data, axis=1, keepdims=True)
|
||||
return self.data - mean_vals
|
||||
else:
|
||||
# 3D数据: (rows, cols, bands) - 按每个像素的光谱计算均值并中心化
|
||||
mean_vals = np.mean(self.data, axis=2, keepdims=True)
|
||||
return self.data - mean_vals
|
||||
|
||||
def _SNV(self, **kwargs):
|
||||
"""标准正态变换"""
|
||||
print("执行标准正态变换...")
|
||||
if len(self.data.shape) != 2:
|
||||
raise ValueError("SNV方法只支持2D数据")
|
||||
|
||||
# 计算每行的均值和标准差
|
||||
data_average = np.mean(self.data, axis=1, keepdims=True)
|
||||
data_std = np.std(self.data, axis=1, keepdims=True)
|
||||
|
||||
# 避免除零错误
|
||||
data_std = np.where(data_std == 0, 1, data_std)
|
||||
|
||||
# 标准化
|
||||
return (self.data - data_average) / data_std
|
||||
|
||||
def _MA(self, WSZ=11, **kwargs):
|
||||
"""移动平均平滑"""
|
||||
print(f"执行移动平均平滑 (窗口大小: {WSZ})...")
|
||||
output_data = deepcopy(self.data)
|
||||
if len(self.data.shape) == 2:
|
||||
for i in range(output_data.shape[0]):
|
||||
out0 = np.convolve(output_data[i], np.ones(WSZ, dtype=int), 'valid') / WSZ
|
||||
r = np.arange(1, WSZ - 1, 2)
|
||||
start = np.cumsum(output_data[i, :WSZ - 1])[::2] / r
|
||||
stop = (np.cumsum(output_data[i, :-WSZ:-1])[::2] / r)[::-1]
|
||||
output_data[i] = np.concatenate((start, out0, stop))
|
||||
else:
|
||||
for i in range(output_data.shape[0]):
|
||||
for j in range(output_data.shape[1]):
|
||||
spectrum = output_data[i, j, :]
|
||||
out0 = np.convolve(spectrum, np.ones(WSZ, dtype=int), 'valid') / WSZ
|
||||
r = np.arange(1, WSZ - 1, 2)
|
||||
start = np.cumsum(spectrum[:WSZ - 1])[::2] / r
|
||||
stop = (np.cumsum(spectrum[:-WSZ:-1])[::2] / r)[::-1]
|
||||
output_data[i, j, :] = np.concatenate((start, out0, stop))
|
||||
return output_data
|
||||
|
||||
def _SG(self, w=15, p=2, **kwargs):
|
||||
"""Savitzky-Golay平滑滤波"""
|
||||
print(f"执行Savitzky-Golay平滑 (窗口: {w}, 阶数: {p})...")
|
||||
if len(self.data.shape) == 2:
|
||||
return signal.savgol_filter(self.data, w, p)
|
||||
else:
|
||||
original_shape = self.data.shape
|
||||
reshaped = self.data.reshape(-1, original_shape[2])
|
||||
filtered = signal.savgol_filter(reshaped, w, p)
|
||||
return filtered.reshape(original_shape)
|
||||
|
||||
def _D1(self, **kwargs):
|
||||
"""一阶导数"""
|
||||
print("执行一阶导数...")
|
||||
if len(self.data.shape) == 2:
|
||||
n, p = self.data.shape
|
||||
output_data = np.ones((n, p - 1))
|
||||
for i in range(n):
|
||||
output_data[i] = np.diff(self.data[i])
|
||||
else:
|
||||
original_shape = self.data.shape
|
||||
reshaped = self.data.reshape(-1, original_shape[2])
|
||||
n, p = reshaped.shape
|
||||
diff_data = np.ones((n, p - 1))
|
||||
for i in range(n):
|
||||
diff_data[i] = np.diff(reshaped[i])
|
||||
output_data = diff_data.reshape(original_shape[0], original_shape[1], p - 1)
|
||||
|
||||
# 更新波长信息(减少一个波段)
|
||||
if self.wavelengths is not None and len(self.wavelengths) > 1:
|
||||
self.wavelengths = self.wavelengths[:-1]
|
||||
|
||||
return output_data
|
||||
|
||||
def _D2(self, **kwargs):
|
||||
"""二阶导数"""
|
||||
print("执行二阶导数...")
|
||||
if len(self.data.shape) == 2:
|
||||
# 2D数据: (samples, bands)
|
||||
# 计算二阶导数:对原数据求两次差分
|
||||
first_diff = np.diff(self.data, axis=1) # 一阶导数
|
||||
second_diff = np.diff(first_diff, axis=1) # 二阶导数
|
||||
output_data = second_diff
|
||||
elif len(self.data.shape) == 3:
|
||||
# 3D数据: (rows, cols, bands) - 高光谱图像
|
||||
# 对bands维度进行差分
|
||||
first_diff = np.diff(self.data, axis=2) # 一阶导数
|
||||
second_diff = np.diff(first_diff, axis=2) # 二阶导数
|
||||
output_data = second_diff
|
||||
else:
|
||||
raise ValueError("不支持的数据维度")
|
||||
|
||||
# 更新波长信息(减少两个波段)
|
||||
if self.wavelengths is not None and len(self.wavelengths) > 2:
|
||||
self.wavelengths = self.wavelengths[:-2]
|
||||
|
||||
return output_data
|
||||
|
||||
def _DT(self, **kwargs):
|
||||
"""趋势校正"""
|
||||
print("执行趋势校正...")
|
||||
output_data = np.array(self.data)
|
||||
if len(self.data.shape) == 2:
|
||||
length = output_data.shape[1]
|
||||
x = np.asarray(range(length), dtype=np.float32)
|
||||
l = LinearRegression()
|
||||
for i in range(output_data.shape[0]):
|
||||
l.fit(x.reshape(-1, 1), output_data[i].reshape(-1, 1))
|
||||
k = l.coef_
|
||||
b = l.intercept_
|
||||
for j in range(output_data.shape[1]):
|
||||
output_data[i][j] = output_data[i][j] - (j * k + b)
|
||||
else:
|
||||
length = output_data.shape[2]
|
||||
x = np.asarray(range(length), dtype=np.float32)
|
||||
l = LinearRegression()
|
||||
for i in range(output_data.shape[0]):
|
||||
for j in range(output_data.shape[1]):
|
||||
spectrum = output_data[i, j, :]
|
||||
l.fit(x.reshape(-1, 1), spectrum.reshape(-1, 1))
|
||||
k = l.coef_
|
||||
b = l.intercept_
|
||||
for k_idx in range(output_data.shape[2]):
|
||||
output_data[i, j, k_idx] = output_data[i, j, k_idx] - (k_idx * k + b)
|
||||
return output_data
|
||||
|
||||
def _MSC(self, **kwargs):
|
||||
"""多元散射校正"""
|
||||
print("执行多元散射校正...")
|
||||
if len(self.data.shape) == 2:
|
||||
n, p = self.data.shape
|
||||
output_data = np.ones((n, p))
|
||||
mean = np.mean(self.data, axis=0)
|
||||
for i in range(n):
|
||||
y = self.data[i, :]
|
||||
l = LinearRegression()
|
||||
l.fit(mean.reshape(-1, 1), y.reshape(-1, 1))
|
||||
k = l.coef_
|
||||
b = l.intercept_
|
||||
output_data[i, :] = (y - b) / k
|
||||
elif len(self.data.shape) == 3:
|
||||
# 3D数据: (rows, cols, bands) - 高光谱图像
|
||||
rows, cols, bands = self.data.shape
|
||||
output_data = np.zeros((rows, cols, bands), dtype=np.float32)
|
||||
|
||||
# 计算整个图像的平均光谱
|
||||
mean_spectrum = np.mean(self.data, axis=(0, 1)) # 对所有像素求平均
|
||||
|
||||
# 对每个像素进行MSC校正
|
||||
for i in range(rows):
|
||||
for j in range(cols):
|
||||
y = self.data[i, j, :]
|
||||
l = LinearRegression()
|
||||
l.fit(mean_spectrum.reshape(-1, 1), y.reshape(-1, 1))
|
||||
k = l.coef_
|
||||
b = l.intercept_
|
||||
# 避免除零错误
|
||||
k = max(k, 1e-10) if k == 0 else k
|
||||
output_data[i, j, :] = (y - b) / k
|
||||
else:
|
||||
raise ValueError("不支持的数据维度")
|
||||
return output_data
|
||||
|
||||
def _wave(self, **kwargs):
|
||||
"""小波变换"""
|
||||
print("执行小波变换...")
|
||||
def wave_single(spectrum):
|
||||
w = pywt.Wavelet('db8')
|
||||
maxlev = pywt.dwt_max_level(len(spectrum), w.dec_len)
|
||||
coeffs = pywt.wavedec(spectrum, 'db8', level=maxlev)
|
||||
threshold = 0.04
|
||||
for i in range(1, len(coeffs)):
|
||||
coeffs[i] = pywt.threshold(coeffs[i], threshold * max(coeffs[i]))
|
||||
return pywt.waverec(coeffs, 'db8')
|
||||
|
||||
if len(self.data.shape) == 2:
|
||||
output_data = None
|
||||
for i in range(self.data.shape[0]):
|
||||
processed = wave_single(self.data[i])
|
||||
if i == 0:
|
||||
output_data = processed
|
||||
else:
|
||||
output_data = np.vstack((output_data, processed))
|
||||
else:
|
||||
original_shape = self.data.shape
|
||||
reshaped = self.data.reshape(-1, original_shape[2])
|
||||
output_data = None
|
||||
for i in range(reshaped.shape[0]):
|
||||
processed = wave_single(reshaped[i])
|
||||
if i == 0:
|
||||
output_data = processed
|
||||
else:
|
||||
output_data = np.vstack((output_data, processed))
|
||||
output_data = output_data.reshape(original_shape)
|
||||
|
||||
return output_data
|
||||
|
||||
|
||||
# ===== 便捷函数 =====
|
||||
|
||||
def preprocess_file(input_file, output_file, method, spectral_start_index=None, **kwargs):
    """
    Convenience helper: preprocess a single file end to end.

    Parameters:
        input_file: input file path
        output_file: output file path
        method: preprocessing method key
        spectral_start_index: 0-based index of the first spectral column (CSV)
        **kwargs: forwarded to the preprocessing method

    Returns the processed data array.
    """
    # Bug fix: HyperspectralPreprocessor.__init__ requires a config, so the
    # original no-argument construction raised TypeError. Build a minimal
    # config from this call's own arguments instead.
    config = PreprocessingConfig(input_path=str(input_file), method=method)
    if spectral_start_index is not None:
        config.spectral_start_index = spectral_start_index
    processor = HyperspectralPreprocessor(config)
    processor.load_data(input_file, spectral_start_index)
    return processor.preprocess(method, output_file, **kwargs)
|
||||
|
||||
|
||||
# 保持向后兼容的函数
|
||||
# Backward-compatible functional wrappers


def _compat_processor(input_spectrum):
    """Build a bare HyperspectralPreprocessor holding input_spectrum.

    Bug fix: the wrappers below used HyperspectralPreprocessor(), which
    raises TypeError because __init__ requires a config. Bypass __init__
    and set the attributes the preprocessing methods actually read.
    """
    processor = HyperspectralPreprocessor.__new__(HyperspectralPreprocessor)
    processor.config = None
    processor.data = input_spectrum
    processor.wavelengths = None
    processor.data_shape = None
    processor.input_format = None
    return processor


def MMS(input_spectrum):
    """Min-max normalization (backward compatible).

    Bug fix: the original passed input_spectrum positionally to _MMS, which
    accepts no positional arguments, and never assigned processor.data.
    """
    return _compat_processor(input_spectrum)._MMS()


def SS(input_spectrum, save_path=None):
    """Standardization (backward compatible)."""
    return _compat_processor(input_spectrum)._SS(save_path=save_path)


def CT(input_spectrum):
    """Mean centering (backward compatible)."""
    return _compat_processor(input_spectrum)._CT()


def SNV(input_spectrum):
    """Standard normal variate (backward compatible)."""
    return _compat_processor(input_spectrum)._SNV()


def MA(input_spectrum, WSZ=11):
    """Moving-average smoothing (backward compatible)."""
    return _compat_processor(input_spectrum)._MA(WSZ=WSZ)


def SG(input_spectrum, w=15, p=2):
    """Savitzky-Golay smoothing (backward compatible)."""
    return _compat_processor(input_spectrum)._SG(w=w, p=p)


def D1(input_spectrum):
    """First derivative (backward compatible)."""
    return _compat_processor(input_spectrum)._D1()


def D2(input_spectrum):
    """Second derivative (backward compatible)."""
    return _compat_processor(input_spectrum)._D2()


def DT(input_spectrum):
    """Detrending (backward compatible)."""
    return _compat_processor(input_spectrum)._DT()


def MSC(input_spectrum):
    """Multiplicative scatter correction (backward compatible)."""
    return _compat_processor(input_spectrum)._MSC()


def wave(input_spectrum):
    """Wavelet denoising (backward compatible)."""
    return _compat_processor(input_spectrum)._wave()
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: load one file, run the chosen methods, save.

    Returns 0 on success, 1 on failure (suitable for a process exit code).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='高光谱数据预处理工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # Bug fix in example 3: -c is parsed with type=int, so the previous
        # example value "wavelength_400" would be rejected by argparse.
        epilog="""
使用示例:
  1. 处理ENVI格式图像:
     python Preprocessing.py input.hdr -m SNV -o output_snv.dat

  2. 处理CSV格式数据:
     python Preprocessing.py input.csv -c 1 -m MSC -o output_msc.csv

  3. 批量处理多种方法:
     python Preprocessing.py input.csv -c 1 -m SNV MSC MA -o results/

支持的预处理方法:
  MMS: 最大最小值归一化
  SS: 标准差标准化
  CT: 中心化
  SNV: 标准正态变量变换
  MA: 移动平均
  SG: Savitzky-Golay滤波
  D1: 一阶导数
  D2: 二阶导数
  DT: 离散变换
  MSC: 多重散射校正
  wave: 小波变换
"""
    )

    parser.add_argument('input_file', help='输入文件路径 (.hdr高光谱图像 或 .csv文件)')
    parser.add_argument('-c', '--spectral_col', type=int, default=1,
                        help='CSV文件的谱段起始列索引 (从0开始,默认: 1)')
    parser.add_argument('-m', '--methods', nargs='+', required=True,
                        choices=['MMS', 'SS', 'CT', 'SNV', 'MA', 'SG', 'D1', 'D2', 'DT', 'MSC', 'wave'],
                        help='预处理方法 (可指定多个)')
    parser.add_argument('-o', '--output', required=True,
                        help='输出文件路径或目录')

    # method-specific knobs
    parser.add_argument('--ma_window', type=int, default=11,
                        help='移动平均窗口大小 (默认: 11)')
    parser.add_argument('--sg_window', type=int, default=15,
                        help='Savitzky-Golay窗口大小 (默认: 15)')
    parser.add_argument('--sg_poly', type=int, default=2,
                        help='Savitzky-Golay多项式阶数 (默认: 2)')

    args = parser.parse_args()

    try:
        print("=" * 60)
        print("高光谱数据预处理工具")
        print("=" * 60)
        print(f"输入文件: {args.input_file}")
        print(f"预处理方法: {', '.join(args.methods)}")
        print(f"输出路径: {args.output}")
        print()

        # Bug fix: HyperspectralPreprocessor.__init__ requires a config; the
        # original no-argument call raised TypeError before any work was done.
        # Build a config from the parsed CLI arguments.
        config = PreprocessingConfig(
            input_path=args.input_file,
            method=args.methods[0],
            spectral_start_index=args.spectral_col,
        )
        processor = HyperspectralPreprocessor(config)

        print("正在加载数据...")
        processor.load_data(args.input_file, spectral_start_index=args.spectral_col)

        for method in args.methods:
            print(f"\n正在执行 {method} 预处理...")

            # Output path: a single method uses -o as-is; several methods get
            # per-method file names (inside -o when it is a directory).
            # (Removed the redundant `import os` — os is imported at file top.)
            if len(args.methods) == 1:
                output_path = args.output
            elif os.path.isdir(args.output):
                base_name = os.path.splitext(os.path.basename(args.input_file))[0]
                output_path = os.path.join(args.output, f"{base_name}_{method.lower()}.csv")
            else:
                base, ext = os.path.splitext(args.output)
                output_path = f"{base}_{method.lower()}{ext}"

            # method-specific parameters from the CLI
            kwargs = {}
            if method == 'MA':
                kwargs['WSZ'] = args.ma_window
            elif method == 'SG':
                kwargs['w'] = args.sg_window
                kwargs['p'] = args.sg_poly

            processor.preprocess(method, output_path, **kwargs)
            print(f"✓ {method} 预处理完成: {output_path}")

        print("\n" + "=" * 60)
        print("所有预处理任务完成!")
        print("=" * 60)

    except Exception as e:
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # propagate main()'s status code as the process exit code
    raise SystemExit(main())
|
||||
|
||||
|
||||
0
preprocessing_method/plot.py
Normal file
0
preprocessing_method/plot.py
Normal file
Reference in New Issue
Block a user