增加模块;增加主调用命令
This commit is contained in:
498
validators.py
Normal file
498
validators.py
Normal file
@ -0,0 +1,498 @@
|
||||
"""
|
||||
高光谱分析工具包验证模块
|
||||
提供输入验证、文件检查和参数校验功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, List, Dict, Any
|
||||
import spectral
|
||||
import pandas as pd
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
class FileValidator:
|
||||
"""文件验证器"""
|
||||
|
||||
SUPPORTED_EXTENSIONS = {
|
||||
'hyperspectral': ['.hdr', '.dat', '.bil', '.bsq', '.bip'],
|
||||
'csv': ['.csv'],
|
||||
'xml': ['.xml'],
|
||||
'shp': ['.shp', '.shx', '.dbf'],
|
||||
'image': ['.png', '.jpg', '.jpeg', '.tiff', '.tif']
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def validate_file_exists(file_path: str) -> bool:
|
||||
"""验证文件是否存在"""
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_file_extension(file_path: str, expected_types: List[str]) -> bool:
|
||||
"""验证文件扩展名"""
|
||||
file_ext = Path(file_path).suffix.lower()
|
||||
valid_extensions = []
|
||||
|
||||
for ext_type in expected_types:
|
||||
if ext_type in FileValidator.SUPPORTED_EXTENSIONS:
|
||||
valid_extensions.extend(FileValidator.SUPPORTED_EXTENSIONS[ext_type])
|
||||
else:
|
||||
valid_extensions.append(ext_type)
|
||||
|
||||
if file_ext not in valid_extensions:
|
||||
raise ValueError(f"不支持的文件类型 {file_ext}。支持的类型: {valid_extensions}")
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_hyperspectral_file(file_path: str) -> Dict[str, Any]:
|
||||
"""验证高光谱文件(支持多种格式)"""
|
||||
FileValidator.validate_file_exists(file_path)
|
||||
FileValidator.validate_file_extension(file_path, ['hyperspectral'])
|
||||
|
||||
file_path_obj = Path(file_path)
|
||||
file_ext = file_path_obj.suffix.lower()
|
||||
file_stem = file_path_obj.stem
|
||||
|
||||
try:
|
||||
# 确定要读取的头文件路径
|
||||
if file_ext == '.hdr':
|
||||
# 输入是头文件
|
||||
hdr_path = file_path
|
||||
else:
|
||||
# 输入是数据文件,尝试找到对应的头文件
|
||||
possible_hdr_paths = [
|
||||
file_path_obj.with_suffix('.hdr'), # 将后缀替换为.hdr (data.bip -> data.hdr)
|
||||
file_path_obj.parent / f"{file_path_obj.name}.hdr", # 直接在文件名后添加.hdr (data.bip -> data.bip.hdr)
|
||||
file_path_obj.parent / f"{file_stem}.hdr", # 去掉扩展名的hdr (data.bip -> data.hdr)
|
||||
]
|
||||
|
||||
hdr_path = None
|
||||
for hdr_candidate in possible_hdr_paths:
|
||||
if hdr_candidate.exists():
|
||||
hdr_path = str(hdr_candidate)
|
||||
break
|
||||
|
||||
if hdr_path is None:
|
||||
raise FileNotFoundError(f"找不到对应的HDR头文件。尝试的路径: {[str(p) for p in possible_hdr_paths]}")
|
||||
|
||||
# 验证头文件存在
|
||||
FileValidator.validate_file_exists(hdr_path)
|
||||
|
||||
# 尝试读取文件头
|
||||
img = spectral.open_image(hdr_path)
|
||||
metadata = {
|
||||
'samples': img.shape[0],
|
||||
'lines': img.shape[1],
|
||||
'bands': img.shape[2] if len(img.shape) > 2 else 1,
|
||||
'dtype': str(img.dtype),
|
||||
'interleave': img.metadata.get('interleave', 'unknown'),
|
||||
'hdr_path': hdr_path,
|
||||
'data_path': file_path if file_ext != '.hdr' else str(Path(hdr_path).with_suffix(file_ext))
|
||||
}
|
||||
del img # 释放资源
|
||||
return metadata
|
||||
except Exception as e:
|
||||
raise ValueError(f"无法读取高光谱文件 {file_path}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def validate_csv_file(csv_path: str, required_columns: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""验证CSV文件"""
|
||||
FileValidator.validate_file_exists(csv_path)
|
||||
FileValidator.validate_file_extension(csv_path, ['csv'])
|
||||
|
||||
try:
|
||||
df = pd.read_csv(csv_path)
|
||||
metadata = {
|
||||
'rows': len(df),
|
||||
'columns': list(df.columns),
|
||||
'dtypes': df.dtypes.to_dict(),
|
||||
'bands': len(df.columns), # 对于CSV,bands表示特征/波段数
|
||||
'samples': len(df), # 对于CSV,samples表示样本数
|
||||
'lines': 1, # CSV被视为单行数据
|
||||
'shape': (len(df), len(df.columns)) # (样本数, 特征数)
|
||||
}
|
||||
|
||||
if required_columns:
|
||||
missing_cols = [col for col in required_columns if col not in df.columns]
|
||||
if missing_cols:
|
||||
raise ValueError(f"CSV文件缺少必需列: {missing_cols}")
|
||||
|
||||
return metadata
|
||||
except Exception as e:
|
||||
raise ValueError(f"无法读取CSV文件 {csv_path}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def validate_roi_file(xml_path: str) -> Dict[str, Any]:
|
||||
"""验证ROI文件"""
|
||||
FileValidator.validate_file_exists(xml_path)
|
||||
FileValidator.validate_file_extension(xml_path, ['xml'])
|
||||
|
||||
try:
|
||||
import xml.etree.ElementTree as ET
|
||||
tree = ET.parse(xml_path)
|
||||
root = tree.getroot()
|
||||
|
||||
# 解析ROI信息
|
||||
regions = []
|
||||
for region in root.findall('.//Region'):
|
||||
region_info = {
|
||||
'id': region.get('id'),
|
||||
'name': region.get('name', ''),
|
||||
'type': region.get('type', ''),
|
||||
'color': region.get('color', '')
|
||||
}
|
||||
regions.append(region_info)
|
||||
|
||||
return {
|
||||
'regions_count': len(regions),
|
||||
'regions': regions
|
||||
}
|
||||
except Exception as e:
|
||||
raise ValueError(f"无法读取ROI文件 {xml_path}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def validate_output_directory(output_dir: str) -> bool:
|
||||
"""验证输出目录"""
|
||||
try:
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
# 测试写入权限
|
||||
test_file = Path(output_dir) / '.test_write'
|
||||
test_file.touch()
|
||||
test_file.unlink()
|
||||
return True
|
||||
except Exception as e:
|
||||
raise PermissionError(f"无法创建或写入输出目录 {output_dir}: {e}")
|
||||
|
||||
|
||||
class DataValidator:
|
||||
"""数据验证器"""
|
||||
|
||||
@staticmethod
|
||||
def validate_array_shape(data: np.ndarray, expected_shape: Optional[Tuple] = None,
|
||||
min_dims: int = 1, max_dims: int = 4) -> bool:
|
||||
"""验证数组形状"""
|
||||
if not isinstance(data, np.ndarray):
|
||||
raise TypeError(f"期望numpy数组,得到 {type(data)}")
|
||||
|
||||
if data.ndim < min_dims or data.ndim > max_dims:
|
||||
raise ValueError(f"数组维度 {data.ndim} 超出范围 [{min_dims}, {max_dims}]")
|
||||
|
||||
if expected_shape and data.shape != expected_shape:
|
||||
raise ValueError(f"数组形状 {data.shape} 不匹配期望形状 {expected_shape}")
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_numeric_range(data: np.ndarray, min_val: Optional[float] = None,
|
||||
max_val: Optional[float] = None) -> bool:
|
||||
"""验证数值范围"""
|
||||
if min_val is not None and np.min(data) < min_val:
|
||||
raise ValueError(f"数据最小值 {np.min(data)} 小于允许的最小值 {min_val}")
|
||||
|
||||
if max_val is not None and np.max(data) > max_val:
|
||||
raise ValueError(f"数据最大值 {np.max(data)} 大于允许的最大值 {max_val}")
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_no_nan_inf(data: np.ndarray) -> bool:
|
||||
"""验证无NaN或Inf值"""
|
||||
if np.any(np.isnan(data)):
|
||||
raise ValueError("数据包含NaN值")
|
||||
|
||||
if np.any(np.isinf(data)):
|
||||
raise ValueError("数据包含Inf值")
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_band_index(band_index: int, max_bands: int) -> bool:
|
||||
"""验证波段索引"""
|
||||
if not isinstance(band_index, int) or band_index < 0:
|
||||
raise ValueError(f"波段索引必须是非负整数,得到 {band_index}")
|
||||
|
||||
if band_index >= max_bands:
|
||||
raise ValueError(f"波段索引 {band_index} 超出范围 [0, {max_bands-1}]")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class ParameterValidator:
|
||||
"""参数验证器"""
|
||||
|
||||
@staticmethod
|
||||
def validate_method_choice(method: str, valid_methods: List[str]) -> bool:
|
||||
"""验证方法选择"""
|
||||
if method not in valid_methods:
|
||||
raise ValueError(f"不支持的方法 '{method}'。有效方法: {valid_methods}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_numeric_parameter(value: Any, param_name: str,
|
||||
min_val: Optional[float] = None,
|
||||
max_val: Optional[float] = None,
|
||||
value_type: type = int) -> bool:
|
||||
"""验证数值参数"""
|
||||
try:
|
||||
if not isinstance(value, value_type):
|
||||
value = value_type(value)
|
||||
except (ValueError, TypeError):
|
||||
raise TypeError(f"参数 {param_name} 必须是 {value_type.__name__} 类型")
|
||||
|
||||
if min_val is not None and value < min_val:
|
||||
raise ValueError(f"参数 {param_name} 必须 >= {min_val}")
|
||||
|
||||
if max_val is not None and value > max_val:
|
||||
raise ValueError(f"参数 {param_name} 必须 <= {max_val}")
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_probability(value: float, param_name: str = "probability") -> bool:
|
||||
"""验证概率参数"""
|
||||
return ParameterValidator.validate_numeric_parameter(
|
||||
value, param_name, min_val=0.0, max_val=1.0, value_type=float
|
||||
)
|
||||
|
||||
|
||||
class TaskValidator:
|
||||
"""任务特定验证器"""
|
||||
|
||||
@staticmethod
|
||||
def validate_dim_reduction_params(method: str, n_components: int, data_shape: Tuple) -> bool:
|
||||
"""验证降维参数"""
|
||||
n_samples, n_features = data_shape[:2]
|
||||
|
||||
if n_components >= n_features:
|
||||
raise ValueError(f"降维组件数 {n_components} 不能 >= 原始特征数 {n_features}")
|
||||
|
||||
if n_components >= n_samples:
|
||||
raise ValueError(f"降维组件数 {n_components} 不能 >= 样本数 {n_samples}")
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_clustering_params(method: str, n_clusters: int, data_shape: Tuple) -> bool:
|
||||
"""验证聚类参数"""
|
||||
n_samples = data_shape[0]
|
||||
|
||||
if n_clusters >= n_samples:
|
||||
raise ValueError(f"聚类数 {n_clusters} 不能 >= 样本数 {n_samples}")
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_segmentation_params(method: str, threshold: Optional[float],
|
||||
band_index: int, max_bands: int) -> bool:
|
||||
"""验证分割参数"""
|
||||
DataValidator.validate_band_index(band_index, max_bands)
|
||||
|
||||
if method == 'fixed' and threshold is None:
|
||||
raise ValueError("固定阈值分割方法需要指定threshold参数")
|
||||
|
||||
if threshold is not None:
|
||||
ParameterValidator.validate_probability(threshold, "threshold")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def detect_file_type(file_path: str) -> str:
|
||||
"""
|
||||
检测文件类型
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
文件类型: 'csv', 'hyperspectral', 或 'unknown'
|
||||
"""
|
||||
file_ext = Path(file_path).suffix.lower()
|
||||
|
||||
# 明确是CSV文件
|
||||
if file_ext == '.csv':
|
||||
return 'csv'
|
||||
|
||||
# 可能是高光谱文件
|
||||
if file_ext in ['.hdr', '.dat', '.bil', '.bsq', '.bip']:
|
||||
return 'hyperspectral'
|
||||
|
||||
# 尝试读取文件内容来判断
|
||||
try:
|
||||
# 检查是否是纯文本文件(可能是CSV)
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
first_line = f.readline().strip()
|
||||
# 如果第一行包含逗号,很可能是CSV
|
||||
if ',' in first_line:
|
||||
# 检查是否都是数字(CSV数据)
|
||||
try:
|
||||
values = [float(x.strip()) for x in first_line.split(',')]
|
||||
return 'csv'
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 检查是否是二进制文件(高光谱数据)
|
||||
with open(file_path, 'rb') as f:
|
||||
# 读取前几个字节
|
||||
header = f.read(100)
|
||||
# 如果包含很多null字节或非ASCII字符,可能是二进制
|
||||
if any(b < 32 or b > 126 for b in header if b != 0):
|
||||
return 'hyperspectral'
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return 'unknown'
|
||||
|
||||
|
||||
def validate_input_file(file_path: str, allowed_types: List[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
根据文件类型自动选择验证方式
|
||||
|
||||
Args:
|
||||
file_path: 输入文件路径
|
||||
allowed_types: 允许的文件类型列表,默认允许所有类型
|
||||
|
||||
Returns:
|
||||
文件信息字典
|
||||
"""
|
||||
file_type = detect_file_type(file_path)
|
||||
|
||||
if allowed_types and file_type not in allowed_types:
|
||||
raise ValueError(f"文件类型 '{file_type}' 不被允许。支持的类型: {allowed_types}")
|
||||
|
||||
if file_type == 'csv':
|
||||
return FileValidator.validate_csv_file(file_path)
|
||||
elif file_type == 'hyperspectral':
|
||||
return FileValidator.validate_hyperspectral_file(file_path)
|
||||
else:
|
||||
raise ValueError(f"无法识别的文件类型: {file_path}")
|
||||
|
||||
|
||||
def validate_task_inputs(task_name: str, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
统一验证任务输入
|
||||
|
||||
Args:
|
||||
task_name: 任务名称
|
||||
**kwargs: 任务参数
|
||||
|
||||
Returns:
|
||||
验证结果字典
|
||||
"""
|
||||
results = {'valid': True, 'warnings': [], 'errors': []}
|
||||
|
||||
try:
|
||||
# 文件验证
|
||||
if 'input' in kwargs:
|
||||
# 根据任务类型决定允许的文件类型
|
||||
if task_name == 'spectral-index':
|
||||
# spectral-index任务同时支持CSV和高光谱文件
|
||||
allowed_types = ['csv', 'hyperspectral']
|
||||
elif task_name in ['feature-selection', 'regression', 'preprocessing']:
|
||||
# 这些任务主要接受CSV文件,但也可以接受其他格式
|
||||
allowed_types = ['csv', 'hyperspectral']
|
||||
else:
|
||||
# 其他任务主要需要高光谱文件
|
||||
allowed_types = ['hyperspectral', 'csv'] # 保持向后兼容性
|
||||
|
||||
file_info = validate_input_file(kwargs['input'], allowed_types)
|
||||
results['file_info'] = file_info
|
||||
|
||||
if 'roi_file' in kwargs:
|
||||
roi_info = FileValidator.validate_roi_file(kwargs['roi_file'])
|
||||
results['roi_info'] = roi_info
|
||||
|
||||
if 'output_dir' in kwargs:
|
||||
FileValidator.validate_output_directory(kwargs['output_dir'])
|
||||
|
||||
# 任务特定验证
|
||||
if task_name == 'dim-reduction':
|
||||
# 获取数据维度信息
|
||||
file_info = results.get('file_info', {})
|
||||
if file_info.get('bands', 0) > 0:
|
||||
# 构造适合降维的数据形状
|
||||
if 'shape' in file_info:
|
||||
shape = file_info['shape']
|
||||
else:
|
||||
# 从其他字段推断
|
||||
samples = file_info.get('samples', 100)
|
||||
bands = file_info.get('bands', 10)
|
||||
shape = (samples, bands)
|
||||
|
||||
TaskValidator.validate_dim_reduction_params(
|
||||
kwargs.get('method', 'pca'),
|
||||
kwargs.get('n_components', 3),
|
||||
shape
|
||||
)
|
||||
|
||||
elif task_name == 'clustering':
|
||||
# 获取数据维度信息
|
||||
file_info = results.get('file_info', {})
|
||||
if file_info.get('bands', 0) > 0:
|
||||
# 构造适合聚类的数据形状
|
||||
if 'shape' in file_info:
|
||||
shape = file_info['shape']
|
||||
else:
|
||||
# 从其他字段推断
|
||||
samples = file_info.get('samples', 100)
|
||||
bands = file_info.get('bands', 10)
|
||||
shape = (samples, bands)
|
||||
|
||||
TaskValidator.validate_clustering_params(
|
||||
kwargs.get('method', 'kmeans'),
|
||||
kwargs.get('n_clusters', 5),
|
||||
shape
|
||||
)
|
||||
elif task_name == 'segmentation':
|
||||
max_bands = results.get('file_info', {}).get('bands', 100)
|
||||
TaskValidator.validate_segmentation_params(
|
||||
kwargs.get('method', 'otsu'),
|
||||
kwargs.get('threshold'),
|
||||
kwargs.get('band_index', 0),
|
||||
max_bands
|
||||
)
|
||||
elif task_name == 'spectral-index':
|
||||
# 光谱指数计算验证
|
||||
file_info = results.get('file_info', {})
|
||||
|
||||
# 验证波段数
|
||||
max_bands = file_info.get('bands', 0)
|
||||
if max_bands == 0:
|
||||
raise ValueError("无法确定输入文件的波段数量")
|
||||
|
||||
# 如果指定了波段索引,验证其有效性
|
||||
if 'band_index' in kwargs:
|
||||
band_index = kwargs['band_index']
|
||||
if not isinstance(band_index, int) or band_index < 0 or band_index >= max_bands:
|
||||
raise ValueError(f"波段索引 {band_index} 超出范围 [0, {max_bands-1}]")
|
||||
|
||||
# 验证光谱指数列表(如果提供)
|
||||
if 'indices' in kwargs and kwargs['indices'] is not None:
|
||||
indices = kwargs['indices']
|
||||
if isinstance(indices, str):
|
||||
indices = [indices]
|
||||
if not isinstance(indices, list):
|
||||
raise ValueError("光谱指数参数必须是字符串或字符串列表")
|
||||
|
||||
elif task_name == 'glcm':
|
||||
# GLCM纹理特征提取验证
|
||||
band_index = kwargs.get('band_index', 25)
|
||||
max_bands = results.get('file_info', {}).get('bands', 100)
|
||||
if band_index >= max_bands:
|
||||
raise ValueError(f"GLCM波段索引 {band_index} 超出范围 [0, {max_bands-1}]")
|
||||
|
||||
slide_window = kwargs.get('slide_window', 7)
|
||||
if slide_window <= 0 or slide_window % 2 == 0:
|
||||
raise ValueError("GLCM滑动窗口大小必须为正奇数")
|
||||
|
||||
except Exception as e:
|
||||
results['valid'] = False
|
||||
results['errors'].append(str(e))
|
||||
|
||||
return results
|
||||
Reference in New Issue
Block a user