Files
HSI/rgression_method/regression_predict.py

1096 lines
42 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
回归预测工具包
支持使用训练好的回归模型对高光谱图像进行预测
包含图像读取、遮罩处理、批量预测、结果可视化等功能
"""
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import glob
import joblib
import json
import warnings
import argparse
from typing import Optional, List, Dict, Any, Union, Tuple
from dataclasses import dataclass, field
import time
# 导入回归分析器和所有必要的类为了pickle兼容性
try:
# 当作为模块导入时使用相对导入
from .regression import (
RegressionAnalyzer,
ExtremeLearningMachine,
GeneralizedAdditiveModel,
LSTMRegressor,
GRURegressor
)
except ImportError:
# 当直接运行脚本时使用绝对导入
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
sys.path.insert(0, current_dir)
from regression import (
RegressionAnalyzer,
ExtremeLearningMachine,
GeneralizedAdditiveModel,
LSTMRegressor,
GRURegressor
)
warnings.filterwarnings('ignore')
@dataclass
class PredictionConfig:
    """Settings for a prediction run: input paths, masking, batching, output options."""
    image_path: str = ""                    # hyperspectral image file (or a directory in batch mode)
    mask_path: Optional[str] = None         # optional mask file (.bil/.tif/.png/.shp/.dat)
    model_path: Union[str, List[str]] = ""  # model .pkl file, a directory of models, or a list of files
    output_dir: str = "prediction_results"  # where PNG and ENVI results are written
    use_mask: bool = True                   # apply the mask when one is available
    batch_mode: bool = False                # treat image_path as a directory of images
    colormap: str = 'viridis'               # matplotlib colormap for visualization
    dpi: int = 300                          # DPI of saved PNG figures
    save_individual: bool = True            # include the model name in each output filename
class ImageReader:
    """Reader for hyperspectral images (ENVI format) and mask files.

    Note: the stray ``@dataclass`` decorator was removed. The class defines
    its own ``__init__`` (so the decorator generated none), but it silently
    injected ``__eq__``/``__repr__`` that made all instances compare equal.
    """

    def __init__(self):
        # Stateless reader: all methods take explicit paths.
        pass
def read_envi_header(self, header_path: Path) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int], str, int]:
"""读取ENVI头文件获取图像尺寸和格式信息"""
rows, cols, bands = None, None, None
data_type = None
interleave = 'bsq' # 默认值
byte_order = 0 # 默认小端字节序
try:
with open(header_path, 'r', encoding='utf-8') as f:
content = f.read()
# 解析关键参数
lines = content.split('\n')
for line in lines:
line = line.strip()
if '=' in line:
key, value = line.split('=', 1)
key = key.strip()
value = value.strip()
if key.lower() == 'samples':
cols = int(value)
elif key.lower() == 'lines':
rows = int(value)
elif key.lower() == 'bands':
bands = int(value)
elif key.lower() == 'data type':
data_type = int(value)
elif key.lower() == 'interleave':
interleave = value.lower() # bsq, bil, bip
elif key.lower() == 'byte order':
byte_order = int(value)
if rows and cols and bands:
print(f"从头文件读取: {rows}x{cols}x{bands}, interleave={interleave}, data_type={data_type}")
except Exception as e:
print(f"读取头文件时出错: {e}")
return rows, cols, bands, data_type, interleave, byte_order
def get_numpy_dtype(self, envi_data_type: int) -> np.dtype:
"""将ENVI数据类型转换为numpy数据类型"""
dtype_map = {
1: np.uint8, # 8-bit byte
2: np.int16, # 16-bit signed integer
3: np.int32, # 32-bit signed long integer
4: np.float32, # 32-bit floating point
5: np.float64, # 64-bit double precision floating point
12: np.uint16, # 16-bit unsigned integer
13: np.uint32, # 32-bit unsigned long integer
14: np.int64, # 64-bit signed long integer
15: np.uint64, # 64-bit unsigned long integer
}
return dtype_map.get(envi_data_type, np.float32)
    def read_envi_file(self, data_path: Path, rows: Optional[int] = None, cols: Optional[int] = None,
                       bands: Optional[int] = None) -> Optional[np.ndarray]:
        """Read an ENVI-format cube (BSQ/BIL/BIP) as a ``(rows, cols, bands)`` array.

        Tries the ``spectral`` library first; if unavailable or failing, falls
        back to a manual reader that locates the ``.hdr`` companion file,
        parses the dimensions, and reshapes the raw bytes per the interleave.

        Args:
            data_path: data file path, or the ``.hdr`` header file path.
            rows, cols, bands: manual dimensions, used only when no header is found.

        Returns:
            Image array of shape ``(rows, cols, bands)``, or None on failure.
        """
        # --- Preferred path: spectral pairs header and data itself ---
        try:
            import spectral
            print("使用spectral库读取ENVI文件")
            img = spectral.open_image(str(data_path))
            data = img.load()
            print(f"spectral库成功读取数据形状: {data.shape}, 数据类型: {data.dtype}")
            return data
        except ImportError:
            print("spectral库不可用使用自定义方法读取")
        except Exception as e:
            print(f"spectral库读取失败: {e},使用自定义方法")
        # --- Fallback: locate the header file ---
        if data_path.suffix.lower() == '.hdr':
            header_path = data_path
        else:
            # Input is the data file; try the common header naming schemes.
            header_path = data_path.with_suffix('.hdr')
            if not header_path.exists():
                # e.g. image.bil -> image.bil.hdr
                alt_header_path = data_path.parent / f"{data_path.name}.hdr"
                if alt_header_path.exists():
                    header_path = alt_header_path
                else:
                    # e.g. image.bil -> image.hdr
                    base_header_path = data_path.parent / f"{data_path.stem}.hdr"
                    if base_header_path.exists():
                        header_path = base_header_path
        if header_path.exists():
            print(f"正在读取头文件: {header_path}")
            rows, cols, bands, data_type, interleave, byte_order = self.read_envi_header(header_path)
            if data_type is None:
                data_type = 4  # default to float32
            dtype = self.get_numpy_dtype(data_type)
            # Determine the actual data-file path.
            if data_path.suffix.lower() == '.hdr':
                # Input was the header; data file is the same name without .hdr.
                actual_data_path = data_path.with_suffix('')
                print(f"输入是头文件,计算出的数据文件路径: {actual_data_path}")
                if not actual_data_path.exists():
                    # Try the common .bil extension next.
                    bil_path = actual_data_path.with_suffix('.bil')
                    if bil_path.exists():
                        actual_data_path = bil_path
                        print(f"找到BIL数据文件: {actual_data_path}")
                    else:
                        print(f"未找到数据文件,尝试的文件: {actual_data_path}, {bil_path}")
            else:
                # Input was already the data file.
                actual_data_path = data_path
        else:
            # No header anywhere: require manually supplied dimensions.
            if rows is None or cols is None or bands is None:
                print(f"找不到头文件 {header_path},无法读取图像")
                return None
            print(f"未找到头文件,使用手动指定的尺寸: {rows}x{cols}x{bands}")
            dtype = np.float32
            interleave = 'bsq'  # assume BSQ when nothing says otherwise
            actual_data_path = data_path
        try:
            print(f"正在读取数据文件: {actual_data_path}, 尺寸: {rows}x{cols}x{bands}, 格式: {interleave}, 数据类型: {dtype}")
            # Sanity-check the file size against the expected byte count.
            file_size = actual_data_path.stat().st_size
            expected_bytes = rows * cols * bands * dtype(0).itemsize
            print(f"文件大小: {file_size} 字节, 期望: {expected_bytes} 字节")
            if file_size != expected_bytes:
                print(f"文件大小不匹配!文件: {file_size} 字节, 计算: {expected_bytes} 字节")
                # Header may be wrong; proceed anyway and check element count below.
                print("尝试继续读取,但结果可能不正确...")
            with open(actual_data_path, 'rb') as f:
                raw_data = np.frombuffer(f.read(), dtype=dtype)
            expected_size = rows * cols * bands
            if len(raw_data) != expected_size:
                print(f"数据大小不匹配。期望: {expected_size}, 实际: {len(raw_data)}")
                if len(raw_data) < expected_size:
                    print("数据不完整,无法继续处理")
                    return None
                else:
                    # Extra trailing data: keep only the expected prefix.
                    print("数据超出期望大小,将截取所需部分")
                    raw_data = raw_data[:expected_size]
            # Reshape according to the interleave layout.
            if interleave == 'bsq':
                # BSQ: band-sequential — each band stored as a full image plane.
                data = raw_data.reshape(bands, rows, cols)
                # Normalize to (rows, cols, bands).
                data = np.transpose(data, (1, 2, 0))
            elif interleave == 'bil':
                # BIL: band-interleaved-by-line — per row, all bands of that row.
                data = raw_data.reshape(rows, bands, cols)
                # Normalize to (rows, cols, bands).
                data = np.transpose(data, (0, 2, 1))
            elif interleave == 'bip':
                # BIP: band-interleaved-by-pixel — per pixel, all bands together.
                data = raw_data.reshape(rows, cols, bands)
                # Already in the target layout.
            else:
                print(f"未知的交织格式: {interleave}尝试BSQ格式")
                data = raw_data.reshape(bands, rows, cols)
                data = np.transpose(data, (1, 2, 0))
            print(f"成功读取,数据形状: {data.shape}, 数据类型: {data.dtype}")
            return data
        except Exception as e:
            print(f"读取数据文件时出错: {e}")
            return None
    def read_mask_file(self, mask_path: Path, reference_image_path: Optional[Path] = None) -> Optional[np.ndarray]:
        """Read a mask file and return a binary uint8 mask (value > 0 -> 1).

        Supported formats: ENVI (.bil/.bsq/.bip), TIFF and PNG/JPG (via
        Pillow), Shapefile (.shp, rasterized against ``reference_image_path``),
        and .dat. Returns None on unsupported format or read failure.
        """
        try:
            if mask_path.suffix.lower() in ['.bil', '.bsq', '.bip']:
                # ENVI-format mask: reuse the cube reader, then binarize.
                mask_data = self.read_envi_file(mask_path)
                if mask_data is not None:
                    mask_data = mask_data.squeeze()  # drop the singleton band axis
                    mask_data = (mask_data > 0).astype(np.uint8)
            elif mask_path.suffix.lower() in ['.tif', '.tiff']:
                # TIFF masks require Pillow.
                try:
                    from PIL import Image
                    mask_img = Image.open(mask_path)
                    mask_data = np.array(mask_img)
                    mask_data = (mask_data > 0).astype(np.uint8)
                except ImportError:
                    print("需要安装PIL(Pillow)来读取TIFF文件")
                    return None
            elif mask_path.suffix.lower() in ['.png', '.jpg', '.jpeg']:
                # Ordinary image masks.
                try:
                    from PIL import Image
                    mask_img = Image.open(mask_path)
                    mask_data = np.array(mask_img)
                    # Color image: average the channels to grayscale first.
                    if len(mask_data.shape) == 3:
                        mask_data = np.mean(mask_data, axis=2)
                    mask_data = (mask_data > 0).astype(np.uint8)
                except ImportError:
                    print("需要安装PIL(Pillow)来读取图像文件")
                    return None
            elif mask_path.suffix.lower() == '.shp':
                # Vector mask: rasterize onto the reference image grid.
                mask_data = self._rasterize_shapefile(mask_path, reference_image_path)
            elif mask_path.suffix.lower() == '.dat':
                # Raw .dat mask (text or binary heuristics).
                mask_data = self._read_dat_file(mask_path)
            else:
                print(f"不支持的遮罩文件格式: {mask_path.suffix}")
                return None
            if mask_data is not None:
                print(f"成功读取遮罩文件: {mask_path.name}, 形状: {mask_data.shape}")
                return mask_data
            # Reader returned None: fall through to the implicit None return.
        except Exception as e:
            print(f"读取遮罩文件时出错: {e}")
            return None
def _rasterize_shapefile(self, shp_path: Path, reference_image_path: Optional[Path] = None) -> Optional[np.ndarray]:
"""将Shapefile栅格化"""
try:
import geopandas as gpd
import rasterio
from rasterio import features
from rasterio.transform import from_bounds
import numpy as np
# 读取shapefile
gdf = gpd.read_file(shp_path)
if gdf.empty:
print("Shapefile为空")
return None
# 获取参考图像的地理信息
if reference_image_path and reference_image_path.exists():
# 处理ENVI文件的路径问题
rasterio_path = str(reference_image_path)
if reference_image_path.suffix.lower() == '.hdr':
# 如果是头文件,尝试找到对应的数据文件
data_path = reference_image_path.with_suffix('')
if data_path.exists():
rasterio_path = str(data_path)
print(f"使用ENVI数据文件作为参考: {rasterio_path}")
else:
bil_path = reference_image_path.with_suffix('.bil')
if bil_path.exists():
rasterio_path = str(bil_path)
print(f"使用ENVI .bil文件作为参考: {rasterio_path}")
else:
print(f"找不到对应的ENVI数据文件跳过参考图像")
reference_image_path = None
if reference_image_path: # 检查是否仍然有效
try:
# 从参考图像获取尺寸和地理变换
with rasterio.open(rasterio_path) as ref_src:
height, width = ref_src.height, ref_src.width
transform = ref_src.transform
crs = ref_src.crs
print(f"使用参考图像的地理信息: {width}x{height}")
except Exception as e:
print(f"读取参考图像失败: {e}将使用Shapefile边界")
reference_image_path = None
else:
# 如果没有参考图像使用shapefile的边界
bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
width, height = 1000, 1000 # 默认尺寸
transform = from_bounds(bounds[0], bounds[1], bounds[2], bounds[3], width, height)
crs = gdf.crs
print(f"使用Shapefile边界创建栅格: {width}x{height}")
# 创建栅格化掩码
mask_data = np.zeros((height, width), dtype=np.uint8)
# 将几何图形栅格化
shapes = ((geom, 1) for geom in gdf.geometry)
burned = features.rasterize(
shapes=shapes,
out=mask_data,
transform=transform,
all_touched=True,
dtype=np.uint8
)
return burned
except ImportError as e:
print(f"栅格化Shapefile需要安装必要的库: {e}")
print("请安装: pip install geopandas rasterio")
return None
except Exception as e:
print(f"栅格化Shapefile时出错: {e}")
return None
def _read_dat_file(self, dat_path: Path) -> Optional[np.ndarray]:
"""读取DAT格式文件"""
try:
# DAT文件可能是二进制数据文件尝试不同的读取方式
# 首先尝试作为文本文件读取
try:
# 尝试读取为CSV格式
data = np.loadtxt(dat_path, delimiter=',')
print(f"作为CSV格式读取DAT文件形状: {data.shape}")
# 如果是多列数据,取第一列作为掩码
if len(data.shape) > 1:
mask_data = data[:, 0] if data.shape[1] > 1 else data.flatten()
else:
mask_data = data.flatten()
except ValueError:
# 如果不是文本格式,尝试作为二进制文件读取
print("尝试作为二进制文件读取DAT文件")
# 读取二进制数据
with open(dat_path, 'rb') as f:
raw_data = f.read()
# 尝试推断数据类型和尺寸(这可能需要调整)
# 假设是float32类型需要根据实际情况调整
try:
data = np.frombuffer(raw_data, dtype=np.float32)
print(f"作为float32二进制读取元素数量: {len(data)}")
# 尝试推断二维结构(这只是一个假设)
# 可能需要手动指定尺寸或从其他地方获取
n = int(np.sqrt(len(data)))
if n * n == len(data):
mask_data = data.reshape(n, n)
print(f"重塑为方阵: {n}x{n}")
else:
# 如果无法形成方阵,保持一维
mask_data = data
print("保持一维数组")
except Exception as e:
print(f"二进制读取失败: {e}")
return None
# 转换为二值掩码
mask_data = (mask_data > 0).astype(np.uint8)
return mask_data
except Exception as e:
print(f"读取DAT文件时出错: {e}")
return None
class RegressionPredictor:
    """Applies trained regression models to hyperspectral images and saves results.

    Note: the stray ``@dataclass`` decorator was removed. The class defines
    its own ``__init__`` (so the decorator generated none), but it silently
    injected ``__eq__``/``__repr__`` that made every predictor compare equal.
    """

    def __init__(self, config: "PredictionConfig"):
        """
        Args:
            config: prediction settings (paths, colormap, batch mode, ...).
        """
        self.config = config
        self.image_reader = ImageReader()  # handles all ENVI/mask I/O
        self.models = {}      # model name -> fitted estimator
        self.scalers = {}     # model name -> fitted scaler, or None
        self.model_info = {}  # model name -> metadata dict
        # Make sure the output directory exists up front.
        self.output_dir = Path(config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
def load_models(self) -> bool:
"""加载模型文件"""
try:
if isinstance(self.config.model_path, str):
# 单模型文件或文件夹
model_path = Path(self.config.model_path)
if model_path.is_file():
# 单文件
return self._load_single_model(model_path)
elif model_path.is_dir():
# 文件夹中的多个模型
return self._load_models_from_directory(model_path)
else:
print(f"模型路径不存在: {model_path}")
return False
elif isinstance(self.config.model_path, list):
# 模型文件列表
for model_path in self.config.model_path:
if not self._load_single_model(Path(model_path)):
print(f"加载模型失败: {model_path}")
return False
return True
else:
print("无效的模型路径格式")
return False
except Exception as e:
print(f"加载模型时出错: {e}")
return False
    def _load_single_model(self, model_path: Path) -> bool:
        """Load one .pkl model plus its companion scaler and info files.

        Naming convention assumed — NOTE(review): verify against the training
        script: for model ``<stem>.pkl`` the scaler is
        ``scaler_<last-underscore-token-of-stem>.pkl`` and the metadata is
        ``info_<stem>.json``, both in the same directory. A missing scaler or
        info file is tolerated (scaler set to None, minimal info created).
        """
        try:
            model_file = None
            scaler_file = None
            info_file = None
            if model_path.suffix == '.pkl':
                model_file = model_path
                # Derive the companion file paths from the convention above.
                scaler_file = model_path.parent / f"scaler_{model_path.stem.split('_')[-1]}.pkl"
                info_file = model_path.parent / f"info_{model_path.stem}.json"
            else:
                # Anything that is not a .pkl file is rejected here.
                print(f"不支持的模型文件格式: {model_path}")
                return False
            # Load the estimator itself (required).
            if model_file.exists():
                model = joblib.load(model_file)
                model_name = model_file.stem
                self.models[model_name] = model
                print(f"加载模型: {model_name}")
            else:
                print(f"模型文件不存在: {model_file}")
                return False
            # Load the scaler (optional).
            if scaler_file and scaler_file.exists():
                scaler = joblib.load(scaler_file)
                self.scalers[model_name] = scaler
                print(f"加载标准化器: {scaler_file.name}")
            else:
                print(f"标准化器文件不存在,将不使用标准化: {scaler_file}")
                # No scaler: prediction will use raw features.
                self.scalers[model_name] = None
            # Load the metadata (optional).
            if info_file and info_file.exists():
                with open(info_file, 'r', encoding='utf-8') as f:
                    info = json.load(f)
                self.model_info[model_name] = info
                print(f"加载模型信息: {info_file.name}")
            else:
                print(f"模型信息文件不存在: {info_file}")
                # Fabricate a minimal info record so downstream lookups work.
                self.model_info[model_name] = {
                    'model_name': model_name,
                    'timestamp': 'unknown'
                }
            return True
        except Exception as e:
            print(f"加载单个模型时出错: {e}")
            return False
def _load_models_from_directory(self, model_dir: Path) -> bool:
"""从目录加载多个模型"""
try:
# 查找所有.pkl文件
model_files = list(model_dir.glob('*.pkl'))
model_files = [f for f in model_files if not f.name.startswith('scaler_')]
if not model_files:
print(f"在目录 {model_dir} 中未找到模型文件")
return False
print(f"在目录中找到 {len(model_files)} 个模型文件")
# 加载每个模型
for model_file in model_files:
if not self._load_single_model(model_file):
print(f"加载模型失败: {model_file}")
return False
return True
except Exception as e:
print(f"从目录加载模型时出错: {e}")
return False
def load_image_and_mask(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
"""加载图像和遮罩"""
# 加载高光谱图像
image_path = Path(self.config.image_path)
if not image_path.exists():
print(f"图像文件不存在: {image_path}")
return None, None
image_data = self.image_reader.read_envi_file(image_path)
if image_data is None:
print("无法读取图像文件")
return None, None
# 加载遮罩(如果指定)
mask_data = None
if self.config.use_mask and self.config.mask_path:
mask_path = Path(self.config.mask_path)
if mask_path.exists():
# 为Shapefile栅格化传递参考图像路径
reference_path = image_path if mask_path.suffix.lower() == '.shp' else None
mask_data = self.image_reader.read_mask_file(mask_path, reference_path)
if mask_data is None:
print("无法读取遮罩文件,将不使用遮罩")
else:
print(f"遮罩文件不存在: {mask_path},将不使用遮罩")
return image_data, mask_data
def preprocess_image(self, image_data: np.ndarray, mask_data: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""预处理图像数据"""
# 获取图像尺寸
rows, cols, bands = image_data.shape
# 重塑为二维数组 (pixels, bands) 用于预测
X_image = image_data.reshape(-1, bands)
# 处理遮罩
valid_mask = None
if mask_data is not None and self.config.use_mask:
# 确保遮罩尺寸匹配
if mask_data.shape[:2] != (rows, cols):
print(f"遮罩尺寸不匹配: 图像{rows}x{cols}, 遮罩{mask_data.shape[:2]}")
# 尝试调整遮罩尺寸
try:
from skimage.transform import resize
mask_data_resized = resize(mask_data.astype(float), (rows, cols), order=0, preserve_range=True)
mask_data = mask_data_resized.astype(np.uint8)
print("已调整遮罩尺寸以匹配图像")
except ImportError:
print("需要安装scikit-image来调整遮罩尺寸")
mask_data = None
if mask_data is not None:
# 创建有效像素掩码
valid_mask = mask_data.reshape(-1) > 0
print(f"使用遮罩,有效像素: {valid_mask.sum()}/{len(valid_mask)}")
# 只保留有效像素
X_image = X_image[valid_mask]
else:
print("不使用遮罩,处理所有像素")
return X_image, valid_mask
def predict_image(self, image_data: np.ndarray, mask_data: Optional[np.ndarray] = None) -> Dict[str, np.ndarray]:
"""对图像进行预测"""
# 预处理图像
X_image, valid_mask = self.preprocess_image(image_data, mask_data)
if len(X_image) == 0:
print("没有有效像素进行预测")
return {}
predictions = {}
# 对每个加载的模型进行预测
for model_name, model in self.models.items():
try:
print(f"使用模型 {model_name} 进行预测...")
# 获取标准化器
scaler = self.scalers.get(model_name)
if scaler:
# 标准化数据
X_scaled = scaler.transform(X_image)
else:
X_scaled = X_image
# 进行预测
y_pred = model.predict(X_scaled)
# 重建预测图像
rows, cols, _ = image_data.shape
pred_image = np.full((rows * cols,), np.nan, dtype=np.float32)
if valid_mask is not None:
# 只填充有效像素
pred_image[valid_mask] = y_pred
else:
# 填充所有像素
pred_image[:] = y_pred
# 重塑为图像格式
pred_image = pred_image.reshape(rows, cols)
predictions[model_name] = pred_image
print(f"模型 {model_name} 预测完成")
except Exception as e:
print(f"使用模型 {model_name} 预测时出错: {e}")
continue
return predictions
    def save_prediction_results(self, predictions: Dict[str, np.ndarray], image_data: np.ndarray,
                                mask_data: Optional[np.ndarray] = None) -> List[Path]:
        """Save each model's prediction as a PNG figure plus an ENVI .bil/.hdr pair.

        Args:
            predictions: model name -> (rows, cols) prediction map.
            image_data: original cube; only its shape is used for the ENVI export.
            mask_data: optional mask, forwarded to the visualization overlay.

        Returns:
            List of all file paths written; a failing model is skipped.
        """
        saved_files = []
        image_name = Path(self.config.image_path).stem
        for model_name, pred_image in predictions.items():
            try:
                # save_individual embeds the model name in each filename.
                if self.config.save_individual:
                    output_name = f"{image_name}_{model_name}_prediction.png"
                else:
                    output_name = f"{image_name}_prediction.png"
                output_path = self.output_dir / output_name
                # PNG visualization (map + overlay/histogram + stats box).
                self._visualize_prediction(pred_image, model_name, output_path, mask_data)
                saved_files.append(output_path)
                print(f"保存预测结果PNG: {output_path}")
                # Single-band ENVI export of the same map.
                envi_files = self._save_envi_prediction(pred_image, model_name, image_name, image_data.shape)
                saved_files.extend(envi_files)
            except Exception as e:
                print(f"保存模型 {model_name} 的预测结果时出错: {e}")
                continue
        return saved_files
def _save_envi_prediction(self, pred_image: np.ndarray, model_name: str, image_name: str,
original_shape: tuple) -> List[Path]:
"""保存预测结果为ENVI格式的单波段文件"""
saved_files = []
try:
# 创建输出文件名
if self.config.save_individual:
base_name = f"{image_name}_{model_name}_prediction"
else:
base_name = f"{image_name}_prediction"
bil_path = self.output_dir / f"{base_name}.bil"
hdr_path = self.output_dir / f"{base_name}.hdr"
# 处理NaN值
pred_image_clean = pred_image.copy()
if np.any(np.isnan(pred_image_clean)):
# 用最小值填充NaN
min_val = np.nanmin(pred_image_clean)
pred_image_clean = np.nan_to_num(pred_image_clean, nan=min_val)
# 确保数据类型适合保存
pred_image_clean = pred_image_clean.astype(np.float32)
# 保存为BIL格式 (Band Interleaved by Line)
rows, cols = pred_image_clean.shape
# BIL格式: 按行存储,每个像素的所有波段连续存储
# 对于单波段数据BIL格式就是简单的行优先存储
pred_image_clean.tofile(bil_path)
print(f"保存ENVI数据文件: {bil_path}")
# 生成对应的头文件
self._create_envi_header(hdr_path, rows, cols, 1, bil_path, pred_image_clean.dtype)
saved_files.extend([bil_path, hdr_path])
print(f"保存ENVI头文件: {hdr_path}")
except Exception as e:
print(f"保存ENVI格式预测结果时出错: {e}")
return saved_files
    def _create_envi_header(self, hdr_path: Path, rows: int, cols: int, bands: int,
                            data_path: Path, dtype: np.dtype):
        """Write an ENVI ``.hdr`` companion file for a BIL data file.

        ``data_path`` is accepted for symmetry with the data writer but is not
        written into the header (the header content never references it).
        Header keys are emitted at column 0 as the ENVI format expects.
        """
        try:
            envi_data_type = self._get_envi_data_type(dtype)
            header_content = f"""ENVI
description = {{
Regression prediction result saved as ENVI format}}
samples = {cols}
lines = {rows}
bands = {bands}
header offset = 0
file type = ENVI Standard
data type = {envi_data_type}
interleave = bil
sensor type = Unknown
byte order = 0
"""
            with open(hdr_path, 'w', encoding='utf-8') as f:
                f.write(header_content)
        except Exception as e:
            print(f"创建ENVI头文件时出错: {e}")
def _get_envi_data_type(self, dtype: np.dtype) -> int:
"""将numpy数据类型转换为ENVI数据类型"""
dtype_map = {
np.uint8: 1, # 8-bit byte
np.int16: 2, # 16-bit signed integer
np.int32: 3, # 32-bit signed long integer
np.float32: 4, # 32-bit floating point
np.float64: 5, # 64-bit double precision floating point
np.uint16: 12, # 16-bit unsigned integer
np.uint32: 13, # 32-bit unsigned long integer
np.int64: 14, # 64-bit signed long integer
np.uint64: 15, # 64-bit unsigned long integer
}
return dtype_map.get(dtype.type, 4) # 默认float32
def _visualize_prediction(self, pred_image: np.ndarray, model_name: str, output_path: Path,
mask_data: Optional[np.ndarray] = None):
"""可视化预测结果"""
# 处理NaN值
pred_image_clean = pred_image.copy()
if np.any(np.isnan(pred_image_clean)):
# 用最小值填充NaN
min_val = np.nanmin(pred_image_clean)
pred_image_clean = np.nan_to_num(pred_image_clean, nan=min_val)
# 创建图形
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
# 原始预测结果
im1 = ax1.imshow(pred_image_clean, cmap=self.config.colormap)
ax1.set_title(f'{model_name} - Prediction Results')
ax1.axis('off')
plt.colorbar(im1, ax=ax1, shrink=0.8)
# 如果有遮罩,显示遮罩区域
if mask_data is not None:
# 创建RGB图像叠加遮罩
rgb_image = np.zeros((*pred_image_clean.shape, 3))
# 归一化预测结果到0-1
pred_norm = (pred_image_clean - np.nanmin(pred_image_clean)) / (np.nanmax(pred_image_clean) - np.nanmin(pred_image_clean))
pred_norm = np.nan_to_num(pred_norm, nan=0)
# 应用颜色映射
cmap = plt.cm.get_cmap(self.config.colormap)
colored_pred = cmap(pred_norm)
rgb_image = colored_pred[:, :, :3]
# 叠加遮罩边界
from skimage import measure
try:
contours = measure.find_contours(mask_data, 0.5)
for contour in contours:
# 绘制遮罩边界
ax2.plot(contour[:, 1], contour[:, 0], color='red', linewidth=1, alpha=0.7)
except ImportError:
pass
ax2.imshow(rgb_image)
ax2.set_title(f'{model_name} - With Mask Overlay')
else:
# 显示预测结果的直方图
ax2.hist(pred_image_clean.flatten(), bins=50, alpha=0.7, color='blue', edgecolor='black')
ax2.set_title(f'{model_name} - Prediction Distribution')
ax2.set_xlabel('Predicted Value')
ax2.set_ylabel('Frequency')
ax2.axis('off') if mask_data is not None else None
# 添加统计信息
stats_text = f"""
Statistics:
Min: {np.nanmin(pred_image):.4f}
Max: {np.nanmax(pred_image):.4f}
Mean: {np.nanmean(pred_image):.4f}
Std: {np.nanstd(pred_image):.4f}
Valid Pixels: {np.sum(~np.isnan(pred_image))}/{pred_image.size}
"""
fig.suptitle(f'Regression Prediction Results - {model_name}', fontsize=14, fontweight='bold')
plt.figtext(0.02, 0.02, stats_text, fontsize=10, verticalalignment='bottom',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
plt.tight_layout()
# 保存图像
plt.savefig(output_path, dpi=self.config.dpi, bbox_inches='tight')
plt.close()
def predict_single_image(self) -> bool:
"""预测单个图像"""
print(f"开始预测图像: {self.config.image_path}")
# 加载图像和遮罩
image_data, mask_data = self.load_image_and_mask()
if image_data is None:
return False
# 加载模型
if not self.load_models():
return False
# 进行预测
predictions = self.predict_image(image_data, mask_data)
if not predictions:
return False
# 保存结果
saved_files = self.save_prediction_results(predictions, image_data, mask_data)
print(f"预测完成!生成了 {len(saved_files)} 个结果文件")
print(f"输出目录: {self.output_dir}")
return True
def predict_batch_images(self) -> bool:
"""批量预测多个图像"""
print("开始批量预测...")
# 获取图像列表
image_path = Path(self.config.image_path)
if image_path.is_file():
image_files = [image_path]
elif image_path.is_dir():
# 查找目录中的所有支持的图像文件
extensions = ['.bil', '.bsq', '.bip']
image_files = []
for ext in extensions:
image_files.extend(list(image_path.glob(f'*{ext}')))
image_files = sorted(list(set(image_files))) # 去重并排序
else:
print(f"无效的图像路径: {image_path}")
return False
if not image_files:
print("未找到图像文件")
return False
print(f"找到 {len(image_files)} 个图像文件")
# 加载模型(只加载一次)
if not self.load_models():
return False
total_success = 0
for i, image_file in enumerate(image_files, 1):
print(f"\n处理图像 {i}/{len(image_files)}: {image_file.name}")
print("-" * 50)
# 更新配置中的图像路径
self.config.image_path = str(image_file)
# 尝试预测单个图像
if self.predict_single_image():
total_success += 1
else:
print(f"预测失败: {image_file.name}")
print(f"\n批量预测完成!")
print(f"成功预测: {total_success}/{len(image_files)} 个图像")
return total_success > 0
def run_prediction(self) -> bool:
"""运行预测流程"""
start_time = time.time()
try:
if self.config.batch_mode:
success = self.predict_batch_images()
else:
success = self.predict_single_image()
elapsed_time = time.time() - start_time
return success
except Exception as e:
elapsed_time = time.time() - start_time
print(f"错误详情: {e}")
return False
def create_default_config(image_path: str, model_path: str, mask_path: Optional[str] = None,
                          output_dir: str = "prediction_results", use_mask: bool = True) -> PredictionConfig:
    """Build a PredictionConfig with single-image defaults.

    Fixed defaults: no batch mode, 'viridis' colormap, 300 DPI, individual
    result files per model.
    """
    settings = dict(
        image_path=image_path,
        mask_path=mask_path,
        model_path=model_path,
        output_dir=output_dir,
        use_mask=use_mask,
        batch_mode=False,
        colormap='viridis',
        dpi=300,
        save_individual=True,
    )
    return PredictionConfig(**settings)
# def main():
# """主函数"""
# parser = argparse.ArgumentParser(description='回归预测工具 - 对高光谱图像进行含量预测')
# # 必需参数
# parser.add_argument('image_path', help='高光谱图像文件路径或目录')
# parser.add_argument('model_path', help='模型文件路径(单个文件或目录)')
# # 可选参数
# parser.add_argument('--mask_path', help='遮罩文件路径')
# parser.add_argument('--output_dir', default='prediction_results', help='输出目录')
# parser.add_argument('--no_mask', action='store_true', help='不使用遮罩')
# parser.add_argument('--batch', action='store_true', help='批量处理模式')
# parser.add_argument('--colormap', default='viridis', help='颜色映射')
# parser.add_argument('--dpi', type=int, default=300, help='输出图像DPI')
# parser.add_argument('--no_individual', action='store_true', help='不保存单个模型结果')
# args = parser.parse_args()
# # 创建配置
# config = PredictionConfig(
# image_path=args.image_path,
# mask_path=args.mask_path,
# model_path=args.model_path,
# output_dir=args.output_dir,
# use_mask=not args.no_mask,
# batch_mode=args.batch,
# colormap=args.colormap,
# dpi=args.dpi,
# save_individual=not args.no_individual
# )
# # 创建预测器并运行
# predictor = RegressionPredictor(config)
# success = predictor.run_prediction()
# return 0 if success else 1
def main():
    """Script entry point: build a config and run the predictor.

    The original hard-coded machine-specific paths; those values are kept as
    defaults (backward compatible) but can now be overridden via the
    environment variables HSI_IMAGE_PATH, HSI_MODEL_PATH, HSI_MASK_PATH and
    HSI_OUTPUT_DIR.

    Returns:
        0 on success, 1 on failure (suitable for a process exit code).
    """
    image_path = os.environ.get("HSI_IMAGE_PATH", r"D:\resonon\RegressionTutorial\7_minute.bil")
    model_path = os.environ.get("HSI_MODEL_PATH", r"E:\code\spectronon\single_classsfication\rgression_method\models")
    mask_path = os.environ.get("HSI_MASK_PATH", r"E:\code\spectronon\single_classsfication\rgression_method\data\roi.shp")
    output_dir = os.environ.get("HSI_OUTPUT_DIR", r"E:\code\spectronon\single_classsfication\output")
    config = PredictionConfig(
        image_path=image_path,
        mask_path=mask_path,
        model_path=model_path,
        output_dir=output_dir,
        use_mask=True,         # was: mask = True
        batch_mode=True,       # was: batch = True
        colormap='viridis',
        dpi=300,
        save_individual=True   # was: not no_individual (no_individual = False)
    )
    predictor = RegressionPredictor(config)
    success = predictor.run_prediction()
    return 0 if success else 1
if __name__ == "__main__":
    # raise SystemExit instead of the site-injected exit(): exit() is a REPL
    # convenience that is absent when Python runs without the `site` module
    # (python -S, some frozen builds); SystemExit is always available.
    raise SystemExit(main())