Files
HSI/validators.py

499 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
高光谱分析工具包验证模块
提供输入验证、文件检查和参数校验功能
"""
import os
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, List, Dict, Any
import spectral
import pandas as pd
import warnings
# Suppress every Python warning process-wide as soon as this module is imported.
warnings.filterwarnings('ignore')
class FileValidator:
    """Static helpers that validate input files and output locations.

    ``SUPPORTED_EXTENSIONS`` maps a logical file category to the lowercase
    extensions accepted for that category.
    """

    SUPPORTED_EXTENSIONS = {
        'hyperspectral': ['.hdr', '.dat', '.bil', '.bsq', '.bip'],
        'csv': ['.csv'],
        'xml': ['.xml'],
        'shp': ['.shp', '.shx', '.dbf'],
        'image': ['.png', '.jpg', '.jpeg', '.tiff', '.tif']
    }

    @staticmethod
    def validate_file_exists(file_path: str) -> bool:
        """Return True if ``file_path`` exists.

        Raises:
            FileNotFoundError: if the path does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件不存在: {file_path}")
        return True

    @staticmethod
    def validate_file_extension(file_path: str, expected_types: List[str]) -> bool:
        """Check that the file's extension is one of the expected ones.

        Args:
            file_path: Path whose (lower-cased) suffix is checked.
            expected_types: Category names from ``SUPPORTED_EXTENSIONS``;
                any entry that is not a known category is treated as a
                literal extension such as ``'.txt'``.

        Returns:
            True when the extension is accepted.

        Raises:
            ValueError: if the extension is not accepted.
        """
        file_ext = Path(file_path).suffix.lower()
        valid_extensions = []
        for ext_type in expected_types:
            if ext_type in FileValidator.SUPPORTED_EXTENSIONS:
                valid_extensions.extend(FileValidator.SUPPORTED_EXTENSIONS[ext_type])
            else:
                valid_extensions.append(ext_type)
        if file_ext not in valid_extensions:
            raise ValueError(f"不支持的文件类型 {file_ext}。支持的类型: {valid_extensions}")
        return True

    @staticmethod
    def validate_hyperspectral_file(file_path: str) -> Dict[str, Any]:
        """Validate a hyperspectral file and return metadata read from its header.

        Args:
            file_path: Either the ``.hdr`` header or the data file
                (``.dat``/``.bil``/``.bsq``/``.bip``). For a data file the
                header is looked up next to it as ``<stem>.hdr`` and then
                ``<name>.hdr``.

        Returns:
            Dict with keys ``samples``, ``lines``, ``bands``, ``dtype``,
            ``interleave``, ``hdr_path`` and ``data_path``.

        Raises:
            FileNotFoundError: if ``file_path`` itself does not exist.
            ValueError: if the extension is unsupported, the header cannot be
                found, or the image cannot be opened.
        """
        FileValidator.validate_file_exists(file_path)
        FileValidator.validate_file_extension(file_path, ['hyperspectral'])
        file_path_obj = Path(file_path)
        file_ext = file_path_obj.suffix.lower()
        try:
            if file_ext == '.hdr':
                # The input already is the header file.
                hdr_path = file_path
            else:
                # The input is a data file: look for its companion header.
                possible_hdr_paths = [
                    file_path_obj.with_suffix('.hdr'),  # data.bip -> data.hdr
                    file_path_obj.parent / f"{file_path_obj.name}.hdr",  # data.bip -> data.bip.hdr
                ]
                hdr_path = next(
                    (str(candidate) for candidate in possible_hdr_paths if candidate.exists()),
                    None,
                )
                if hdr_path is None:
                    raise FileNotFoundError(
                        f"找不到对应的HDR头文件。尝试的路径: {[str(p) for p in possible_hdr_paths]}"
                    )

            # Opening the image parses the header.
            img = spectral.open_image(hdr_path)
            shape = img.shape

            if file_ext != '.hdr':
                data_path = file_path
            else:
                # Swapping the suffix of a .hdr path for '.hdr' just yields the
                # header again, so prefer the data file the opened image reports.
                # NOTE(review): relies on SPy exposing ``filename`` on the image
                # object; falls back to the previous value when it is absent.
                data_path = getattr(img, 'filename', None) or str(
                    Path(hdr_path).with_suffix(file_ext)
                )

            metadata = {
                # SPy image shapes are (rows, cols, bands); in ENVI terms
                # rows are "lines" and columns are "samples".
                'samples': shape[1],
                'lines': shape[0],
                'bands': shape[2] if len(shape) > 2 else 1,
                'dtype': str(img.dtype),
                'interleave': img.metadata.get('interleave', 'unknown'),
                'hdr_path': hdr_path,
                'data_path': data_path
            }
            del img  # drop the reference to the opened image
            return metadata
        except Exception as e:
            raise ValueError(f"无法读取高光谱文件 {file_path}: {e}") from e

    @staticmethod
    def validate_csv_file(csv_path: str, required_columns: Optional[List[str]] = None) -> Dict[str, Any]:
        """Validate a CSV file and return a summary of its contents.

        Args:
            csv_path: Path to the CSV file.
            required_columns: Column names that must be present, if any.

        Returns:
            Dict with ``rows``, ``columns``, ``dtypes``, ``bands`` (column
            count), ``samples`` (row count), ``lines`` (always 1) and
            ``shape`` (``(rows, columns)``).

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the extension is wrong, the file cannot be parsed,
                or a required column is missing.
        """
        FileValidator.validate_file_exists(csv_path)
        FileValidator.validate_file_extension(csv_path, ['csv'])
        try:
            df = pd.read_csv(csv_path)
        except Exception as e:
            raise ValueError(f"无法读取CSV文件 {csv_path}: {e}") from e

        # Checked outside the try block so a missing column is not reported
        # as a failure to read the file.
        if required_columns:
            missing_cols = [col for col in required_columns if col not in df.columns]
            if missing_cols:
                raise ValueError(f"CSV文件缺少必需列: {missing_cols}")

        n_rows = len(df)
        n_cols = len(df.columns)
        return {
            'rows': n_rows,
            'columns': list(df.columns),
            'dtypes': df.dtypes.to_dict(),
            'bands': n_cols,  # for CSV input, "bands" is the number of columns
            'samples': n_rows,  # for CSV input, "samples" is the number of rows
            'lines': 1,  # a CSV is treated as a single line of samples
            'shape': (n_rows, n_cols)  # (rows, columns)
        }

    @staticmethod
    def validate_roi_file(xml_path: str) -> Dict[str, Any]:
        """Validate an ROI XML file and list the ``Region`` elements it contains.

        Returns:
            Dict with ``regions_count`` and ``regions``; each region is a dict
            of the element's ``id``, ``name``, ``type`` and ``color``
            attributes (missing ``name``/``type``/``color`` default to '').

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the extension is wrong or the XML cannot be parsed.
        """
        FileValidator.validate_file_exists(xml_path)
        FileValidator.validate_file_extension(xml_path, ['xml'])
        try:
            import xml.etree.ElementTree as ET
            root = ET.parse(xml_path).getroot()
            regions = [
                {
                    'id': region.get('id'),
                    'name': region.get('name', ''),
                    'type': region.get('type', ''),
                    'color': region.get('color', '')
                }
                for region in root.findall('.//Region')
            ]
            return {
                'regions_count': len(regions),
                'regions': regions
            }
        except Exception as e:
            raise ValueError(f"无法读取ROI文件 {xml_path}: {e}") from e

    @staticmethod
    def validate_output_directory(output_dir: str) -> bool:
        """Create ``output_dir`` if needed and confirm that it is writable.

        Returns:
            True when the directory exists and a file can be created in it.

        Raises:
            PermissionError: if the directory cannot be created or written to.
        """
        import tempfile
        try:
            Path(output_dir).mkdir(parents=True, exist_ok=True)
            # Probe write access with a uniquely named temporary file so that
            # no existing file in the directory is touched or removed.
            with tempfile.TemporaryFile(dir=str(output_dir)):
                pass
            return True
        except Exception as e:
            raise PermissionError(f"无法创建或写入输出目录 {output_dir}: {e}") from e
class DataValidator:
    """Static checks for NumPy array inputs and band indices."""

    @staticmethod
    def validate_array_shape(data: np.ndarray, expected_shape: Optional[Tuple] = None,
                             min_dims: int = 1, max_dims: int = 4) -> bool:
        """Check the type, dimensionality and (optionally) exact shape of an array.

        Args:
            data: Object to check; must be a ``numpy.ndarray``.
            expected_shape: Exact shape required, as a tuple or list.
                ``None`` disables the check.
            min_dims: Smallest allowed ``data.ndim``.
            max_dims: Largest allowed ``data.ndim``.

        Returns:
            True when every check passes.

        Raises:
            TypeError: if ``data`` is not a NumPy array.
            ValueError: if the dimensionality or shape does not match.
        """
        if not isinstance(data, np.ndarray):
            raise TypeError(f"期望numpy数组得到 {type(data)}")
        if data.ndim < min_dims or data.ndim > max_dims:
            raise ValueError(f"数组维度 {data.ndim} 超出范围 [{min_dims}, {max_dims}]")
        # data.shape is a tuple; normalise so a list such as [2, 3] compares equal.
        if expected_shape is not None and data.shape != tuple(expected_shape):
            raise ValueError(f"数组形状 {data.shape} 不匹配期望形状 {expected_shape}")
        return True

    @staticmethod
    def validate_numeric_range(data: np.ndarray, min_val: Optional[float] = None,
                               max_val: Optional[float] = None) -> bool:
        """Check that all (non-NaN) values lie within ``[min_val, max_val]``.

        NaN entries are skipped when computing the extremes so they cannot hide
        out-of-range values.

        Args:
            data: Values to check.
            min_val: Inclusive lower bound, or ``None`` for no lower bound.
            max_val: Inclusive upper bound, or ``None`` for no upper bound.

        Returns:
            True when the data is within range (or no bound was given).

        Raises:
            ValueError: if a bound is violated, or if a bound is given but the
                data is empty.
        """
        if min_val is None and max_val is None:
            return True
        values = np.asarray(data)
        if values.size == 0:
            # Reducing an empty array would fail with an unhelpful NumPy error.
            raise ValueError("数据为空,无法验证数值范围")
        if min_val is not None:
            data_min = np.nanmin(values)
            if data_min < min_val:
                raise ValueError(f"数据最小值 {data_min} 小于允许的最小值 {min_val}")
        if max_val is not None:
            data_max = np.nanmax(values)
            if data_max > max_val:
                raise ValueError(f"数据最大值 {data_max} 大于允许的最大值 {max_val}")
        return True

    @staticmethod
    def validate_no_nan_inf(data: np.ndarray) -> bool:
        """Check that the data contains neither NaN nor infinite values.

        Returns:
            True when all values are finite.

        Raises:
            ValueError: if any value is NaN or +/-Inf.
        """
        if np.any(np.isnan(data)):
            raise ValueError("数据包含NaN值")
        if np.any(np.isinf(data)):
            raise ValueError("数据包含Inf值")
        return True

    @staticmethod
    def validate_band_index(band_index: int, max_bands: int) -> bool:
        """Check that ``band_index`` is an integer in ``[0, max_bands - 1]``.

        Both built-in ``int`` and NumPy integer scalars are accepted.

        Returns:
            True when the index is valid.

        Raises:
            ValueError: if the index is not a non-negative integer or is
                outside the available bands.
        """
        if not isinstance(band_index, (int, np.integer)) or band_index < 0:
            raise ValueError(f"波段索引必须是非负整数,得到 {band_index}")
        if band_index >= max_bands:
            raise ValueError(f"波段索引 {band_index} 超出范围 [0, {max_bands-1}]")
        return True
class ParameterValidator:
    """Static checks for scalar task parameters."""

    @staticmethod
    def validate_method_choice(method: str, valid_methods: List[str]) -> bool:
        """Check that ``method`` is one of ``valid_methods``.

        Raises:
            ValueError: if the method is not in the list.
        """
        if method not in valid_methods:
            raise ValueError(f"不支持的方法 '{method}'。有效方法: {valid_methods}")
        return True

    @staticmethod
    def validate_numeric_parameter(value: Any, param_name: str,
                                   min_val: Optional[float] = None,
                                   max_val: Optional[float] = None,
                                   value_type: type = int) -> bool:
        """Check that ``value`` is (convertible to) ``value_type`` and in range.

        Values that are not already instances of ``value_type`` are converted
        with ``value_type(value)``; the converted value is used only for the
        checks and is not returned.

        Args:
            value: Value to validate.
            param_name: Name used in error messages.
            min_val: Inclusive lower bound, or ``None``.
            max_val: Inclusive upper bound, or ``None``.
            value_type: Expected numeric type (``int`` by default).

        Returns:
            True when the value is valid.

        Raises:
            TypeError: if the value cannot be converted, or is a non-integral
                float where an ``int`` is expected.
            ValueError: if the value is NaN or outside ``[min_val, max_val]``.
        """
        type_message = f"参数 {param_name} 必须是 {value_type.__name__} 类型"
        try:
            converted = value if isinstance(value, value_type) else value_type(value)
        except (ValueError, TypeError, OverflowError) as exc:
            # OverflowError covers int(float('inf')).
            raise TypeError(type_message) from exc

        # int(3.7) == 3 would silently accept a value the caller still holds as 3.7.
        if value_type is int and isinstance(value, float) and converted != value:
            raise TypeError(type_message)

        # NaN compares False against every bound, so reject it explicitly.
        if isinstance(converted, float) and converted != converted:
            raise ValueError(f"参数 {param_name} 不能为 NaN")

        if min_val is not None and converted < min_val:
            raise ValueError(f"参数 {param_name} 必须 >= {min_val}")
        if max_val is not None and converted > max_val:
            raise ValueError(f"参数 {param_name} 必须 <= {max_val}")
        return True

    @staticmethod
    def validate_probability(value: float, param_name: str = "probability") -> bool:
        """Check that ``value`` is a float in ``[0.0, 1.0]``.

        Raises:
            TypeError: if the value cannot be converted to ``float``.
            ValueError: if the value is NaN or outside ``[0.0, 1.0]``.
        """
        return ParameterValidator.validate_numeric_parameter(
            value, param_name, min_val=0.0, max_val=1.0, value_type=float
        )
class TaskValidator:
    """Static parameter checks for individual analysis tasks."""

    @staticmethod
    def validate_dim_reduction_params(method: str, n_components: int, data_shape: Tuple) -> bool:
        """Check that ``n_components`` is usable for data of ``data_shape``.

        Args:
            method: Reduction method name (not inspected here).
            n_components: Requested number of components.
            data_shape: Shape whose first two entries are read as
                ``(n_samples, n_features)``.

        Returns:
            True when the parameters are valid.

        Raises:
            ValueError: if ``data_shape`` has fewer than two entries, or
                ``n_components`` is not positive, or is not smaller than both
                the feature count and the sample count.
        """
        if len(data_shape) < 2:
            raise ValueError(f"数据形状 {data_shape} 至少需要两个维度 (样本数, 特征数)")
        n_samples, n_features = data_shape[:2]
        if n_components <= 0:
            raise ValueError(f"降维组件数 {n_components} 必须大于 0")
        if n_components >= n_features:
            raise ValueError(f"降维组件数 {n_components} 不能 >= 原始特征数 {n_features}")
        if n_components >= n_samples:
            raise ValueError(f"降维组件数 {n_components} 不能 >= 样本数 {n_samples}")
        return True

    @staticmethod
    def validate_clustering_params(method: str, n_clusters: int, data_shape: Tuple) -> bool:
        """Check that ``n_clusters`` is usable for data of ``data_shape``.

        Args:
            method: Clustering method name (not inspected here).
            n_clusters: Requested number of clusters.
            data_shape: Shape whose first entry is read as the sample count.

        Returns:
            True when the parameters are valid.

        Raises:
            ValueError: if ``n_clusters`` is below 1 or not smaller than the
                sample count.
        """
        n_samples = data_shape[0]
        if n_clusters < 1:
            raise ValueError(f"聚类数 {n_clusters} 必须 >= 1")
        if n_clusters >= n_samples:
            raise ValueError(f"聚类数 {n_clusters} 不能 >= 样本数 {n_samples}")
        return True

    @staticmethod
    def validate_segmentation_params(method: str, threshold: Optional[float],
                                     band_index: int, max_bands: int) -> bool:
        """Check segmentation parameters.

        Args:
            method: Segmentation method; ``'fixed'`` requires a threshold.
            threshold: Optional threshold, validated as a value in [0, 1].
            band_index: Band to segment on.
            max_bands: Number of available bands.

        Returns:
            True when the parameters are valid.

        Raises:
            ValueError: if the band index is invalid, the ``'fixed'`` method
                has no threshold, or the threshold is out of range.
            TypeError: if the threshold cannot be interpreted as a float.
        """
        DataValidator.validate_band_index(band_index, max_bands)
        if method == 'fixed' and threshold is None:
            raise ValueError("固定阈值分割方法需要指定threshold参数")
        if threshold is not None:
            ParameterValidator.validate_probability(threshold, "threshold")
        return True
def detect_file_type(file_path: str) -> str:
    """Guess whether a file is CSV or hyperspectral data.

    The extension decides when it is a known one. Otherwise the content is
    sniffed: a file whose first 100 bytes contain NUL or other non-text
    control bytes, or that is not valid UTF-8, is treated as binary
    hyperspectral data; a text file whose first line is a comma-separated
    list of numbers is treated as CSV.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        ``'csv'``, ``'hyperspectral'`` or ``'unknown'``. Files that cannot be
        opened yield ``'unknown'``.
    """
    file_ext = Path(file_path).suffix.lower()
    if file_ext == '.csv':
        return 'csv'
    if file_ext in ('.hdr', '.dat', '.bil', '.bsq', '.bip'):
        return 'hyperspectral'

    # Unknown extension: sniff the leading bytes first.
    try:
        with open(file_path, 'rb') as f:
            header = f.read(100)
    except (OSError, ValueError):
        return 'unknown'

    # Tab, newline, carriage return, form feed and vertical tab are normal in
    # text files; any other control byte (including NUL) or DEL marks binary.
    text_whitespace = b'\t\n\r\f\v'
    if any(b == 127 or (b < 32 and b not in text_whitespace) for b in header):
        return 'hyperspectral'

    try:
        # utf-8-sig also accepts plain UTF-8 and strips a leading BOM so it
        # does not break the numeric parse of the first field.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            first_line = f.readline().strip()
    except UnicodeDecodeError:
        # Not decodable as text, so treat it as raw binary data.
        return 'hyperspectral'
    except (OSError, ValueError):
        return 'unknown'

    # A first line made only of comma-separated numbers is taken as CSV data.
    if ',' in first_line:
        try:
            for field in first_line.split(','):
                float(field.strip())
        except ValueError:
            pass
        else:
            return 'csv'
    return 'unknown'
def validate_input_file(file_path: str, allowed_types: List[str] = None) -> Dict[str, Any]:
    """Detect the type of an input file and run the matching validator.

    Args:
        file_path: Path of the input file.
        allowed_types: File types that are acceptable; when empty or ``None``
            every detected type is accepted.

    Returns:
        The metadata dict produced by the CSV or hyperspectral validator.

    Raises:
        ValueError: if the detected type is not allowed or is not one of
            ``'csv'`` / ``'hyperspectral'``.
    """
    file_type = detect_file_type(file_path)

    # Reject types the caller did not ask for before doing any file parsing.
    if allowed_types and file_type not in allowed_types:
        raise ValueError(f"文件类型 '{file_type}' 不被允许。支持的类型: {allowed_types}")

    if file_type == 'csv':
        return FileValidator.validate_csv_file(file_path)
    if file_type == 'hyperspectral':
        return FileValidator.validate_hyperspectral_file(file_path)

    # Anything else (i.e. 'unknown') has no validator.
    raise ValueError(f"无法识别的文件类型: {file_path}")
def validate_task_inputs(task_name: str, **kwargs) -> Dict[str, Any]:
"""
统一验证任务输入
Args:
task_name: 任务名称
**kwargs: 任务参数
Returns:
验证结果字典
"""
results = {'valid': True, 'warnings': [], 'errors': []}
try:
# 文件验证
if 'input' in kwargs:
# 根据任务类型决定允许的文件类型
if task_name == 'spectral-index':
# spectral-index任务同时支持CSV和高光谱文件
allowed_types = ['csv', 'hyperspectral']
elif task_name in ['feature-selection', 'regression', 'preprocessing']:
# 这些任务主要接受CSV文件但也可以接受其他格式
allowed_types = ['csv', 'hyperspectral']
else:
# 其他任务主要需要高光谱文件
allowed_types = ['hyperspectral', 'csv'] # 保持向后兼容性
file_info = validate_input_file(kwargs['input'], allowed_types)
results['file_info'] = file_info
if 'roi_file' in kwargs:
roi_info = FileValidator.validate_roi_file(kwargs['roi_file'])
results['roi_info'] = roi_info
if 'output_dir' in kwargs:
FileValidator.validate_output_directory(kwargs['output_dir'])
# 任务特定验证
if task_name == 'dim-reduction':
# 获取数据维度信息
file_info = results.get('file_info', {})
if file_info.get('bands', 0) > 0:
# 构造适合降维的数据形状
if 'shape' in file_info:
shape = file_info['shape']
else:
# 从其他字段推断
samples = file_info.get('samples', 100)
bands = file_info.get('bands', 10)
shape = (samples, bands)
TaskValidator.validate_dim_reduction_params(
kwargs.get('method', 'pca'),
kwargs.get('n_components', 3),
shape
)
elif task_name == 'clustering':
# 获取数据维度信息
file_info = results.get('file_info', {})
if file_info.get('bands', 0) > 0:
# 构造适合聚类的数据形状
if 'shape' in file_info:
shape = file_info['shape']
else:
# 从其他字段推断
samples = file_info.get('samples', 100)
bands = file_info.get('bands', 10)
shape = (samples, bands)
TaskValidator.validate_clustering_params(
kwargs.get('method', 'kmeans'),
kwargs.get('n_clusters', 5),
shape
)
elif task_name == 'segmentation':
max_bands = results.get('file_info', {}).get('bands', 100)
TaskValidator.validate_segmentation_params(
kwargs.get('method', 'otsu'),
kwargs.get('threshold'),
kwargs.get('band_index', 0),
max_bands
)
elif task_name == 'spectral-index':
# 光谱指数计算验证
file_info = results.get('file_info', {})
# 验证波段数
max_bands = file_info.get('bands', 0)
if max_bands == 0:
raise ValueError("无法确定输入文件的波段数量")
# 如果指定了波段索引,验证其有效性
if 'band_index' in kwargs:
band_index = kwargs['band_index']
if not isinstance(band_index, int) or band_index < 0 or band_index >= max_bands:
raise ValueError(f"波段索引 {band_index} 超出范围 [0, {max_bands-1}]")
# 验证光谱指数列表(如果提供)
if 'indices' in kwargs and kwargs['indices'] is not None:
indices = kwargs['indices']
if isinstance(indices, str):
indices = [indices]
if not isinstance(indices, list):
raise ValueError("光谱指数参数必须是字符串或字符串列表")
elif task_name == 'glcm':
# GLCM纹理特征提取验证
band_index = kwargs.get('band_index', 25)
max_bands = results.get('file_info', {}).get('bands', 100)
if band_index >= max_bands:
raise ValueError(f"GLCM波段索引 {band_index} 超出范围 [0, {max_bands-1}]")
slide_window = kwargs.get('slide_window', 7)
if slide_window <= 0 or slide_window % 2 == 0:
raise ValueError("GLCM滑动窗口大小必须为正奇数")
except Exception as e:
results['valid'] = False
results['errors'].append(str(e))
return results