""" 回归预测工具包 支持使用训练好的回归模型对高光谱图像进行预测 包含图像读取、遮罩处理、批量预测、结果可视化等功能 """ import os import numpy as np import matplotlib.pyplot as plt from pathlib import Path import glob import joblib import json import warnings import argparse from typing import Optional, List, Dict, Any, Union, Tuple from dataclasses import dataclass, field import time # 导入回归分析器和所有必要的类(为了pickle兼容性) try: # 当作为模块导入时使用相对导入 from .regression import ( RegressionAnalyzer, ExtremeLearningMachine, GeneralizedAdditiveModel, LSTMRegressor, GRURegressor ) except ImportError: # 当直接运行脚本时使用绝对导入 import sys import os current_dir = os.path.dirname(os.path.abspath(__file__)) if current_dir not in sys.path: sys.path.insert(0, current_dir) from regression import ( RegressionAnalyzer, ExtremeLearningMachine, GeneralizedAdditiveModel, LSTMRegressor, GRURegressor ) warnings.filterwarnings('ignore') @dataclass class PredictionConfig: """预测配置类""" image_path: str = "" # 高光谱图像路径 mask_path: Optional[str] = None # 遮罩文件路径 model_path: Union[str, List[str]] = "" # 模型文件路径(单文件或列表) output_dir: str = "prediction_results" # 输出目录 use_mask: bool = True # 是否使用遮罩 batch_mode: bool = False # 是否批量处理 colormap: str = 'viridis' # 颜色映射 dpi: int = 300 # 输出图像DPI save_individual: bool = True # 是否保存单个预测结果 @dataclass class ImageReader: """高光谱图像读取器""" def __init__(self): pass def read_envi_header(self, header_path: Path) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int], str, int]: """读取ENVI头文件获取图像尺寸和格式信息""" rows, cols, bands = None, None, None data_type = None interleave = 'bsq' # 默认值 byte_order = 0 # 默认小端字节序 try: with open(header_path, 'r', encoding='utf-8') as f: content = f.read() # 解析关键参数 lines = content.split('\n') for line in lines: line = line.strip() if '=' in line: key, value = line.split('=', 1) key = key.strip() value = value.strip() if key.lower() == 'samples': cols = int(value) elif key.lower() == 'lines': rows = int(value) elif key.lower() == 'bands': bands = int(value) elif key.lower() == 'data type': data_type = int(value) elif key.lower() == 'interleave': interleave = value.lower() # bsq, bil, bip elif key.lower() == 'byte order': byte_order = int(value) if rows and cols and bands: print(f"从头文件读取: {rows}x{cols}x{bands}, interleave={interleave}, data_type={data_type}") except Exception as e: print(f"读取头文件时出错: {e}") return rows, cols, bands, data_type, interleave, byte_order def get_numpy_dtype(self, envi_data_type: int) -> np.dtype: """将ENVI数据类型转换为numpy数据类型""" dtype_map = { 1: np.uint8, # 8-bit byte 2: np.int16, # 16-bit signed integer 3: np.int32, # 32-bit signed long integer 4: np.float32, # 32-bit floating point 5: np.float64, # 64-bit double precision floating point 12: np.uint16, # 16-bit unsigned integer 13: np.uint32, # 32-bit unsigned long integer 14: np.int64, # 64-bit signed long integer 15: np.uint64, # 64-bit unsigned long integer } return dtype_map.get(envi_data_type, np.float32) def read_envi_file(self, data_path: Path, rows: Optional[int] = None, cols: Optional[int] = None, bands: Optional[int] = None) -> Optional[np.ndarray]: """ 读取ENVI格式文件,支持多种读取方式 优先使用spectral库,如果不可用则使用自定义方法 """ """ 读取ENVI格式文件(支持BSQ, BIL, BIP格式) 参数: data_path: 数据文件路径 rows: 图像行数(高度) cols: 图像列数(宽度) bands: 波段数 返回: 图像数据数组,形状为(rows, cols, bands) """ # 尝试使用spectral库读取ENVI文件(如果可用) try: import spectral print("使用spectral库读取ENVI文件") # spectral库可以自动处理头文件和数据文件的匹配 img = spectral.open_image(str(data_path)) data = img.load() print(f"spectral库成功读取,数据形状: {data.shape}, 数据类型: {data.dtype}") return data except ImportError: print("spectral库不可用,使用自定义方法读取") except Exception as e: print(f"spectral库读取失败: {e},使用自定义方法") # 使用自定义方法读取 # 尝试读取头文件获取尺寸信息 # 首先检查输入路径是否已经是头文件 if data_path.suffix.lower() == '.hdr': header_path = data_path else: # 输入是数据文件,尝试寻找对应的头文件 # 首先尝试直接替换扩展名为.hdr header_path = data_path.with_suffix('.hdr') # 如果不存在,尝试其他可能的命名方式 if not header_path.exists(): # 对于.bil文件,头文件可能是.bil.hdr alt_header_path = data_path.parent / f"{data_path.name}.hdr" if alt_header_path.exists(): header_path = alt_header_path else: # 或者文件名.hdr(去掉扩展名) base_header_path = data_path.parent / f"{data_path.stem}.hdr" if base_header_path.exists(): header_path = base_header_path if header_path.exists(): print(f"正在读取头文件: {header_path}") rows, cols, bands, data_type, interleave, byte_order = self.read_envi_header(header_path) if data_type is None: data_type = 4 # 默认float32 dtype = self.get_numpy_dtype(data_type) # 确定数据文件路径 if data_path.suffix.lower() == '.hdr': # 输入是头文件,对应的数据文件是去掉.hdr扩展名的文件 actual_data_path = data_path.with_suffix('') print(f"输入是头文件,计算出的数据文件路径: {actual_data_path}") # 检查常见的数据文件扩展名 if not actual_data_path.exists(): # 尝试添加.bil扩展名 bil_path = actual_data_path.with_suffix('.bil') if bil_path.exists(): actual_data_path = bil_path print(f"找到BIL数据文件: {actual_data_path}") else: print(f"未找到数据文件,尝试的文件: {actual_data_path}, {bil_path}") else: # 输入是数据文件 actual_data_path = data_path else: # 如果没有头文件,使用默认值或参数值 if rows is None or cols is None or bands is None: print(f"找不到头文件 {header_path},无法读取图像") return None print(f"未找到头文件,使用手动指定的尺寸: {rows}x{cols}x{bands}") dtype = np.float32 interleave = 'bsq' # 默认 actual_data_path = data_path try: # 读取二进制数据 print(f"正在读取数据文件: {actual_data_path}, 尺寸: {rows}x{cols}x{bands}, 格式: {interleave}, 数据类型: {dtype}") # 使用更安全的方式读取数据 file_size = actual_data_path.stat().st_size expected_bytes = rows * cols * bands * dtype(0).itemsize print(f"文件大小: {file_size} 字节, 期望: {expected_bytes} 字节") if file_size != expected_bytes: print(f"文件大小不匹配!文件: {file_size} 字节, 计算: {expected_bytes} 字节") # 尝试继续读取,可能头文件有错误 print("尝试继续读取,但结果可能不正确...") # 读取数据 with open(actual_data_path, 'rb') as f: raw_data = np.frombuffer(f.read(), dtype=dtype) # 检查数据大小是否匹配 expected_size = rows * cols * bands if len(raw_data) != expected_size: print(f"数据大小不匹配。期望: {expected_size}, 实际: {len(raw_data)}") if len(raw_data) < expected_size: print("数据不完整,无法继续处理") return None else: print("数据超出期望大小,将截取所需部分") raw_data = raw_data[:expected_size] # 根据数据交织方式重塑数据 if interleave == 'bsq': # BSQ: Band Sequential (波段顺序) # 存储顺序: 波段1的所有行所有列 -> 波段2的所有行所有列 -> ... data = raw_data.reshape(bands, rows, cols) # 转换为标准格式: (rows, cols, bands) data = np.transpose(data, (1, 2, 0)) elif interleave == 'bil': # BIL: Band Interleaved by Line (按行交织) # 存储顺序: 行1的所有波段 -> 行2的所有波段 -> ... data = raw_data.reshape(rows, bands, cols) # 转换为标准格式: (rows, cols, bands) data = np.transpose(data, (0, 2, 1)) elif interleave == 'bip': # BIP: Band Interleaved by Pixel (按像素交织) # 存储顺序: 像素1的所有波段 -> 像素2的所有波段 -> ... data = raw_data.reshape(rows, cols, bands) # 已经是标准格式,无需转置 else: print(f"未知的交织格式: {interleave},尝试BSQ格式") data = raw_data.reshape(bands, rows, cols) data = np.transpose(data, (1, 2, 0)) print(f"成功读取,数据形状: {data.shape}, 数据类型: {data.dtype}") return data except Exception as e: print(f"读取数据文件时出错: {e}") return None def read_mask_file(self, mask_path: Path, reference_image_path: Optional[Path] = None) -> Optional[np.ndarray]: """读取遮罩文件""" try: # 支持多种遮罩文件格式 if mask_path.suffix.lower() in ['.bil', '.bsq', '.bip']: # ENVI格式遮罩 mask_data = self.read_envi_file(mask_path) if mask_data is not None: # 转换为二值遮罩 mask_data = mask_data.squeeze() mask_data = (mask_data > 0).astype(np.uint8) elif mask_path.suffix.lower() in ['.tif', '.tiff']: # TIFF格式(需要安装rasterio或PIL) try: from PIL import Image mask_img = Image.open(mask_path) mask_data = np.array(mask_img) # 转换为二值 mask_data = (mask_data > 0).astype(np.uint8) except ImportError: print("需要安装PIL(Pillow)来读取TIFF文件") return None elif mask_path.suffix.lower() in ['.png', '.jpg', '.jpeg']: # 图像格式遮罩 try: from PIL import Image mask_img = Image.open(mask_path) mask_data = np.array(mask_img) # 如果是彩色图像,转换为灰度 if len(mask_data.shape) == 3: mask_data = np.mean(mask_data, axis=2) # 转换为二值 mask_data = (mask_data > 0).astype(np.uint8) except ImportError: print("需要安装PIL(Pillow)来读取图像文件") return None elif mask_path.suffix.lower() == '.shp': # Shapefile格式,需要栅格化 mask_data = self._rasterize_shapefile(mask_path, reference_image_path) elif mask_path.suffix.lower() == '.dat': # DAT格式文件 mask_data = self._read_dat_file(mask_path) else: print(f"不支持的遮罩文件格式: {mask_path.suffix}") return None if mask_data is not None: print(f"成功读取遮罩文件: {mask_path.name}, 形状: {mask_data.shape}") return mask_data except Exception as e: print(f"读取遮罩文件时出错: {e}") return None def _rasterize_shapefile(self, shp_path: Path, reference_image_path: Optional[Path] = None) -> Optional[np.ndarray]: """将Shapefile栅格化""" try: import geopandas as gpd import rasterio from rasterio import features from rasterio.transform import from_bounds import numpy as np # 读取shapefile gdf = gpd.read_file(shp_path) if gdf.empty: print("Shapefile为空") return None # 获取参考图像的地理信息 if reference_image_path and reference_image_path.exists(): # 处理ENVI文件的路径问题 rasterio_path = str(reference_image_path) if reference_image_path.suffix.lower() == '.hdr': # 如果是头文件,尝试找到对应的数据文件 data_path = reference_image_path.with_suffix('') if data_path.exists(): rasterio_path = str(data_path) print(f"使用ENVI数据文件作为参考: {rasterio_path}") else: bil_path = reference_image_path.with_suffix('.bil') if bil_path.exists(): rasterio_path = str(bil_path) print(f"使用ENVI .bil文件作为参考: {rasterio_path}") else: print(f"找不到对应的ENVI数据文件,跳过参考图像") reference_image_path = None if reference_image_path: # 检查是否仍然有效 try: # 从参考图像获取尺寸和地理变换 with rasterio.open(rasterio_path) as ref_src: height, width = ref_src.height, ref_src.width transform = ref_src.transform crs = ref_src.crs print(f"使用参考图像的地理信息: {width}x{height}") except Exception as e: print(f"读取参考图像失败: {e},将使用Shapefile边界") reference_image_path = None else: # 如果没有参考图像,使用shapefile的边界 bounds = gdf.total_bounds # [minx, miny, maxx, maxy] width, height = 1000, 1000 # 默认尺寸 transform = from_bounds(bounds[0], bounds[1], bounds[2], bounds[3], width, height) crs = gdf.crs print(f"使用Shapefile边界创建栅格: {width}x{height}") # 创建栅格化掩码 mask_data = np.zeros((height, width), dtype=np.uint8) # 将几何图形栅格化 shapes = ((geom, 1) for geom in gdf.geometry) burned = features.rasterize( shapes=shapes, out=mask_data, transform=transform, all_touched=True, dtype=np.uint8 ) return burned except ImportError as e: print(f"栅格化Shapefile需要安装必要的库: {e}") print("请安装: pip install geopandas rasterio") return None except Exception as e: print(f"栅格化Shapefile时出错: {e}") return None def _read_dat_file(self, dat_path: Path) -> Optional[np.ndarray]: """读取DAT格式文件""" try: # DAT文件可能是二进制数据文件,尝试不同的读取方式 # 首先尝试作为文本文件读取 try: # 尝试读取为CSV格式 data = np.loadtxt(dat_path, delimiter=',') print(f"作为CSV格式读取DAT文件,形状: {data.shape}") # 如果是多列数据,取第一列作为掩码 if len(data.shape) > 1: mask_data = data[:, 0] if data.shape[1] > 1 else data.flatten() else: mask_data = data.flatten() except ValueError: # 如果不是文本格式,尝试作为二进制文件读取 print("尝试作为二进制文件读取DAT文件") # 读取二进制数据 with open(dat_path, 'rb') as f: raw_data = f.read() # 尝试推断数据类型和尺寸(这可能需要调整) # 假设是float32类型,需要根据实际情况调整 try: data = np.frombuffer(raw_data, dtype=np.float32) print(f"作为float32二进制读取,元素数量: {len(data)}") # 尝试推断二维结构(这只是一个假设) # 可能需要手动指定尺寸或从其他地方获取 n = int(np.sqrt(len(data))) if n * n == len(data): mask_data = data.reshape(n, n) print(f"重塑为方阵: {n}x{n}") else: # 如果无法形成方阵,保持一维 mask_data = data print("保持一维数组") except Exception as e: print(f"二进制读取失败: {e}") return None # 转换为二值掩码 mask_data = (mask_data > 0).astype(np.uint8) return mask_data except Exception as e: print(f"读取DAT文件时出错: {e}") return None @dataclass class RegressionPredictor: """回归预测器类""" def __init__(self, config: PredictionConfig): """ 初始化回归预测器 参数: config: 预测配置对象 """ self.config = config self.image_reader = ImageReader() self.models = {} # 存储加载的模型 self.scalers = {} # 存储加载的标准化器 self.model_info = {} # 存储模型信息 # 创建输出目录 self.output_dir = Path(config.output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) def load_models(self) -> bool: """加载模型文件""" try: if isinstance(self.config.model_path, str): # 单模型文件或文件夹 model_path = Path(self.config.model_path) if model_path.is_file(): # 单文件 return self._load_single_model(model_path) elif model_path.is_dir(): # 文件夹中的多个模型 return self._load_models_from_directory(model_path) else: print(f"模型路径不存在: {model_path}") return False elif isinstance(self.config.model_path, list): # 模型文件列表 for model_path in self.config.model_path: if not self._load_single_model(Path(model_path)): print(f"加载模型失败: {model_path}") return False return True else: print("无效的模型路径格式") return False except Exception as e: print(f"加载模型时出错: {e}") return False def _load_single_model(self, model_path: Path) -> bool: """加载单个模型""" try: # 构建相关文件路径 model_file = None scaler_file = None info_file = None if model_path.suffix == '.pkl': # 直接是模型文件 model_file = model_path # 尝试寻找对应的标准化器和信息文件 scaler_file = model_path.parent / f"scaler_{model_path.stem.split('_')[-1]}.pkl" info_file = model_path.parent / f"info_{model_path.stem}.json" else: # 可能是文件夹或错误路径 print(f"不支持的模型文件格式: {model_path}") return False # 加载模型 if model_file.exists(): model = joblib.load(model_file) model_name = model_file.stem self.models[model_name] = model print(f"加载模型: {model_name}") else: print(f"模型文件不存在: {model_file}") return False # 加载标准化器 if scaler_file and scaler_file.exists(): scaler = joblib.load(scaler_file) self.scalers[model_name] = scaler print(f"加载标准化器: {scaler_file.name}") else: print(f"标准化器文件不存在,将不使用标准化: {scaler_file}") # 不使用标准化器,设置为None self.scalers[model_name] = None # 加载模型信息 if info_file and info_file.exists(): with open(info_file, 'r', encoding='utf-8') as f: info = json.load(f) self.model_info[model_name] = info print(f"加载模型信息: {info_file.name}") else: print(f"模型信息文件不存在: {info_file}") # 创建基本信息 self.model_info[model_name] = { 'model_name': model_name, 'timestamp': 'unknown' } return True except Exception as e: print(f"加载单个模型时出错: {e}") return False def _load_models_from_directory(self, model_dir: Path) -> bool: """从目录加载多个模型""" try: # 查找所有.pkl文件 model_files = list(model_dir.glob('*.pkl')) model_files = [f for f in model_files if not f.name.startswith('scaler_')] if not model_files: print(f"在目录 {model_dir} 中未找到模型文件") return False print(f"在目录中找到 {len(model_files)} 个模型文件") # 加载每个模型 for model_file in model_files: if not self._load_single_model(model_file): print(f"加载模型失败: {model_file}") return False return True except Exception as e: print(f"从目录加载模型时出错: {e}") return False def load_image_and_mask(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: """加载图像和遮罩""" # 加载高光谱图像 image_path = Path(self.config.image_path) if not image_path.exists(): print(f"图像文件不存在: {image_path}") return None, None image_data = self.image_reader.read_envi_file(image_path) if image_data is None: print("无法读取图像文件") return None, None # 加载遮罩(如果指定) mask_data = None if self.config.use_mask and self.config.mask_path: mask_path = Path(self.config.mask_path) if mask_path.exists(): # 为Shapefile栅格化传递参考图像路径 reference_path = image_path if mask_path.suffix.lower() == '.shp' else None mask_data = self.image_reader.read_mask_file(mask_path, reference_path) if mask_data is None: print("无法读取遮罩文件,将不使用遮罩") else: print(f"遮罩文件不存在: {mask_path},将不使用遮罩") return image_data, mask_data def preprocess_image(self, image_data: np.ndarray, mask_data: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]: """预处理图像数据""" # 获取图像尺寸 rows, cols, bands = image_data.shape # 重塑为二维数组 (pixels, bands) 用于预测 X_image = image_data.reshape(-1, bands) # 处理遮罩 valid_mask = None if mask_data is not None and self.config.use_mask: # 确保遮罩尺寸匹配 if mask_data.shape[:2] != (rows, cols): print(f"遮罩尺寸不匹配: 图像{rows}x{cols}, 遮罩{mask_data.shape[:2]}") # 尝试调整遮罩尺寸 try: from skimage.transform import resize mask_data_resized = resize(mask_data.astype(float), (rows, cols), order=0, preserve_range=True) mask_data = mask_data_resized.astype(np.uint8) print("已调整遮罩尺寸以匹配图像") except ImportError: print("需要安装scikit-image来调整遮罩尺寸") mask_data = None if mask_data is not None: # 创建有效像素掩码 valid_mask = mask_data.reshape(-1) > 0 print(f"使用遮罩,有效像素: {valid_mask.sum()}/{len(valid_mask)}") # 只保留有效像素 X_image = X_image[valid_mask] else: print("不使用遮罩,处理所有像素") return X_image, valid_mask def predict_image(self, image_data: np.ndarray, mask_data: Optional[np.ndarray] = None) -> Dict[str, np.ndarray]: """对图像进行预测""" # 预处理图像 X_image, valid_mask = self.preprocess_image(image_data, mask_data) if len(X_image) == 0: print("没有有效像素进行预测") return {} predictions = {} # 对每个加载的模型进行预测 for model_name, model in self.models.items(): try: print(f"使用模型 {model_name} 进行预测...") # 获取标准化器 scaler = self.scalers.get(model_name) if scaler: # 标准化数据 X_scaled = scaler.transform(X_image) else: X_scaled = X_image # 进行预测 y_pred = model.predict(X_scaled) # 重建预测图像 rows, cols, _ = image_data.shape pred_image = np.full((rows * cols,), np.nan, dtype=np.float32) if valid_mask is not None: # 只填充有效像素 pred_image[valid_mask] = y_pred else: # 填充所有像素 pred_image[:] = y_pred # 重塑为图像格式 pred_image = pred_image.reshape(rows, cols) predictions[model_name] = pred_image print(f"模型 {model_name} 预测完成") except Exception as e: print(f"使用模型 {model_name} 预测时出错: {e}") continue return predictions def save_prediction_results(self, predictions: Dict[str, np.ndarray], image_data: np.ndarray, mask_data: Optional[np.ndarray] = None) -> List[Path]: """保存预测结果""" saved_files = [] image_name = Path(self.config.image_path).stem for model_name, pred_image in predictions.items(): try: # 保存PNG可视化结果 if self.config.save_individual: output_name = f"{image_name}_{model_name}_prediction.png" else: output_name = f"{image_name}_prediction.png" output_path = self.output_dir / output_name # 可视化预测结果 self._visualize_prediction(pred_image, model_name, output_path, mask_data) saved_files.append(output_path) print(f"保存预测结果PNG: {output_path}") # 保存ENVI格式的单波段文件 envi_files = self._save_envi_prediction(pred_image, model_name, image_name, image_data.shape) saved_files.extend(envi_files) except Exception as e: print(f"保存模型 {model_name} 的预测结果时出错: {e}") continue return saved_files def _save_envi_prediction(self, pred_image: np.ndarray, model_name: str, image_name: str, original_shape: tuple) -> List[Path]: """保存预测结果为ENVI格式的单波段文件""" saved_files = [] try: # 创建输出文件名 if self.config.save_individual: base_name = f"{image_name}_{model_name}_prediction" else: base_name = f"{image_name}_prediction" bil_path = self.output_dir / f"{base_name}.bil" hdr_path = self.output_dir / f"{base_name}.hdr" # 处理NaN值 pred_image_clean = pred_image.copy() if np.any(np.isnan(pred_image_clean)): # 用最小值填充NaN min_val = np.nanmin(pred_image_clean) pred_image_clean = np.nan_to_num(pred_image_clean, nan=min_val) # 确保数据类型适合保存 pred_image_clean = pred_image_clean.astype(np.float32) # 保存为BIL格式 (Band Interleaved by Line) rows, cols = pred_image_clean.shape # BIL格式: 按行存储,每个像素的所有波段连续存储 # 对于单波段数据,BIL格式就是简单的行优先存储 pred_image_clean.tofile(bil_path) print(f"保存ENVI数据文件: {bil_path}") # 生成对应的头文件 self._create_envi_header(hdr_path, rows, cols, 1, bil_path, pred_image_clean.dtype) saved_files.extend([bil_path, hdr_path]) print(f"保存ENVI头文件: {hdr_path}") except Exception as e: print(f"保存ENVI格式预测结果时出错: {e}") return saved_files def _create_envi_header(self, hdr_path: Path, rows: int, cols: int, bands: int, data_path: Path, dtype: np.dtype): """创建ENVI格式的头文件""" try: # 获取数据类型映射 envi_data_type = self._get_envi_data_type(dtype) header_content = f"""ENVI description = {{ Regression prediction result saved as ENVI format}} samples = {cols} lines = {rows} bands = {bands} header offset = 0 file type = ENVI Standard data type = {envi_data_type} interleave = bil sensor type = Unknown byte order = 0 """ with open(hdr_path, 'w', encoding='utf-8') as f: f.write(header_content) except Exception as e: print(f"创建ENVI头文件时出错: {e}") def _get_envi_data_type(self, dtype: np.dtype) -> int: """将numpy数据类型转换为ENVI数据类型""" dtype_map = { np.uint8: 1, # 8-bit byte np.int16: 2, # 16-bit signed integer np.int32: 3, # 32-bit signed long integer np.float32: 4, # 32-bit floating point np.float64: 5, # 64-bit double precision floating point np.uint16: 12, # 16-bit unsigned integer np.uint32: 13, # 32-bit unsigned long integer np.int64: 14, # 64-bit signed long integer np.uint64: 15, # 64-bit unsigned long integer } return dtype_map.get(dtype.type, 4) # 默认float32 def _visualize_prediction(self, pred_image: np.ndarray, model_name: str, output_path: Path, mask_data: Optional[np.ndarray] = None): """可视化预测结果""" # 处理NaN值 pred_image_clean = pred_image.copy() if np.any(np.isnan(pred_image_clean)): # 用最小值填充NaN min_val = np.nanmin(pred_image_clean) pred_image_clean = np.nan_to_num(pred_image_clean, nan=min_val) # 创建图形 fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6)) # 原始预测结果 im1 = ax1.imshow(pred_image_clean, cmap=self.config.colormap) ax1.set_title(f'{model_name} - Prediction Results') ax1.axis('off') plt.colorbar(im1, ax=ax1, shrink=0.8) # 如果有遮罩,显示遮罩区域 if mask_data is not None: # 创建RGB图像叠加遮罩 rgb_image = np.zeros((*pred_image_clean.shape, 3)) # 归一化预测结果到0-1 pred_norm = (pred_image_clean - np.nanmin(pred_image_clean)) / (np.nanmax(pred_image_clean) - np.nanmin(pred_image_clean)) pred_norm = np.nan_to_num(pred_norm, nan=0) # 应用颜色映射 cmap = plt.cm.get_cmap(self.config.colormap) colored_pred = cmap(pred_norm) rgb_image = colored_pred[:, :, :3] # 叠加遮罩边界 from skimage import measure try: contours = measure.find_contours(mask_data, 0.5) for contour in contours: # 绘制遮罩边界 ax2.plot(contour[:, 1], contour[:, 0], color='red', linewidth=1, alpha=0.7) except ImportError: pass ax2.imshow(rgb_image) ax2.set_title(f'{model_name} - With Mask Overlay') else: # 显示预测结果的直方图 ax2.hist(pred_image_clean.flatten(), bins=50, alpha=0.7, color='blue', edgecolor='black') ax2.set_title(f'{model_name} - Prediction Distribution') ax2.set_xlabel('Predicted Value') ax2.set_ylabel('Frequency') ax2.axis('off') if mask_data is not None else None # 添加统计信息 stats_text = f""" Statistics: Min: {np.nanmin(pred_image):.4f} Max: {np.nanmax(pred_image):.4f} Mean: {np.nanmean(pred_image):.4f} Std: {np.nanstd(pred_image):.4f} Valid Pixels: {np.sum(~np.isnan(pred_image))}/{pred_image.size} """ fig.suptitle(f'Regression Prediction Results - {model_name}', fontsize=14, fontweight='bold') plt.figtext(0.02, 0.02, stats_text, fontsize=10, verticalalignment='bottom', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8)) plt.tight_layout() # 保存图像 plt.savefig(output_path, dpi=self.config.dpi, bbox_inches='tight') plt.close() def predict_single_image(self) -> bool: """预测单个图像""" print(f"开始预测图像: {self.config.image_path}") # 加载图像和遮罩 image_data, mask_data = self.load_image_and_mask() if image_data is None: return False # 加载模型 if not self.load_models(): return False # 进行预测 predictions = self.predict_image(image_data, mask_data) if not predictions: return False # 保存结果 saved_files = self.save_prediction_results(predictions, image_data, mask_data) print(f"预测完成!生成了 {len(saved_files)} 个结果文件") print(f"输出目录: {self.output_dir}") return True def predict_batch_images(self) -> bool: """批量预测多个图像""" print("开始批量预测...") # 获取图像列表 image_path = Path(self.config.image_path) if image_path.is_file(): image_files = [image_path] elif image_path.is_dir(): # 查找目录中的所有支持的图像文件 extensions = ['.bil', '.bsq', '.bip'] image_files = [] for ext in extensions: image_files.extend(list(image_path.glob(f'*{ext}'))) image_files = sorted(list(set(image_files))) # 去重并排序 else: print(f"无效的图像路径: {image_path}") return False if not image_files: print("未找到图像文件") return False print(f"找到 {len(image_files)} 个图像文件") # 加载模型(只加载一次) if not self.load_models(): return False total_success = 0 for i, image_file in enumerate(image_files, 1): print(f"\n处理图像 {i}/{len(image_files)}: {image_file.name}") print("-" * 50) # 更新配置中的图像路径 self.config.image_path = str(image_file) # 尝试预测单个图像 if self.predict_single_image(): total_success += 1 else: print(f"预测失败: {image_file.name}") print(f"\n批量预测完成!") print(f"成功预测: {total_success}/{len(image_files)} 个图像") return total_success > 0 def run_prediction(self) -> bool: """运行预测流程""" start_time = time.time() try: if self.config.batch_mode: success = self.predict_batch_images() else: success = self.predict_single_image() elapsed_time = time.time() - start_time return success except Exception as e: elapsed_time = time.time() - start_time print(f"错误详情: {e}") return False def create_default_config(image_path: str, model_path: str, mask_path: Optional[str] = None, output_dir: str = "prediction_results", use_mask: bool = True) -> PredictionConfig: """创建默认预测配置""" return PredictionConfig( image_path=image_path, mask_path=mask_path, model_path=model_path, output_dir=output_dir, use_mask=use_mask, batch_mode=False, colormap='viridis', dpi=300, save_individual=True ) # def main(): # """主函数""" # parser = argparse.ArgumentParser(description='回归预测工具 - 对高光谱图像进行含量预测') # # 必需参数 # parser.add_argument('image_path', help='高光谱图像文件路径或目录') # parser.add_argument('model_path', help='模型文件路径(单个文件或目录)') # # 可选参数 # parser.add_argument('--mask_path', help='遮罩文件路径') # parser.add_argument('--output_dir', default='prediction_results', help='输出目录') # parser.add_argument('--no_mask', action='store_true', help='不使用遮罩') # parser.add_argument('--batch', action='store_true', help='批量处理模式') # parser.add_argument('--colormap', default='viridis', help='颜色映射') # parser.add_argument('--dpi', type=int, default=300, help='输出图像DPI') # parser.add_argument('--no_individual', action='store_true', help='不保存单个模型结果') # args = parser.parse_args() # # 创建配置 # config = PredictionConfig( # image_path=args.image_path, # mask_path=args.mask_path, # model_path=args.model_path, # output_dir=args.output_dir, # use_mask=not args.no_mask, # batch_mode=args.batch, # colormap=args.colormap, # dpi=args.dpi, # save_individual=not args.no_individual # ) # # 创建预测器并运行 # predictor = RegressionPredictor(config) # success = predictor.run_prediction() # return 0 if success else 1 def main(): """主函数""" # 必需参数 image_path = r"D:\resonon\RegressionTutorial\7_minute.bil" model_path = r"E:\code\spectronon\single_classsfication\rgression_method\models" # 可选参数 mask_path = r"E:\code\spectronon\single_classsfication\rgression_method\data\roi.shp" output_dir = r"E:\code\spectronon\single_classsfication\output" mask = True batch = True colormap = 'viridis' dpi = 300 no_individual = False # 创建配置 config = PredictionConfig( image_path=image_path, mask_path=mask_path, model_path=model_path, output_dir=output_dir, use_mask=mask, batch_mode=batch, colormap=colormap, dpi=dpi, save_individual=not no_individual ) # 创建预测器并运行 predictor = RegressionPredictor(config) success = predictor.run_prediction() return 0 if success else 1 if __name__ == "__main__": exit(main())