import torch.nn.functional as F
def crop_tensor_by_height_width(tensor, height_crop, width_crop):
    """Center-crop a 4-D tensor by a total number of pixels along dims 2 and 3.

    ``height_crop`` pixels are removed from dim 2 and ``width_crop`` pixels from
    dim 3, split as evenly as possible between the two sides. For an odd crop
    the extra pixel is taken from the end (bottom / right). A crop that is
    larger than or equal to the dimension yields an empty dimension.

    Args:
        tensor (torch.Tensor): 4-D tensor; dims 2 and 3 are treated as
            height and width.
        height_crop (int): Total number of pixels to remove along dim 2 (> 0).
        width_crop (int): Total number of pixels to remove along dim 3 (> 0).

    Returns:
        torch.Tensor: A sliced view of ``tensor``.

    Raises:
        AssertionError: If ``tensor`` is not 4-D or either crop is not > 0.
    """
    assert len(tensor.shape) == 4, '输入的tensor应为4维'
    assert height_crop > 0 and width_crop > 0, 'crop应该大于0'
    # floor(crop / 2) comes off the start, ceil(crop / 2) off the end.
    top = height_crop // 2
    bottom = height_crop - top
    left = width_crop // 2
    right = width_crop - left
    original_height, original_width = tensor.shape[2], tensor.shape[3]
    return tensor[:, :, top:original_height - bottom, left:original_width - right]


def _crop_or_pad_trailing_dims(tensor, amounts, pad_value):
    """Pad or crop the END of the last ``len(amounts)`` dims of ``tensor``.

    ``amounts`` is ordered from the outermost to the innermost of those dims.
    A positive amount appends that many ``pad_value`` entries. A negative
    amount removes that many trailing entries, clamped to the dim's size, so
    over-cropping gives an empty dim. Zero leaves the dim unchanged.
    """
    # F.pad expects (before, after) pairs starting from the LAST dimension and
    # moving backwards, so iterate the amounts in reverse.
    pad = []
    for amount in reversed(amounts):
        pad.extend((0, max(amount, 0)))
    padded = F.pad(tensor, pad=tuple(pad), mode='constant', value=pad_value)

    # Crop by slicing; a negative-amount dim received no padding above, so its
    # size in `padded` equals its original size.
    first_dim = len(tensor.shape) - len(amounts)
    index = [slice(None)] * len(padded.shape)
    for offset, amount in enumerate(amounts):
        if amount < 0:
            dim = first_dim + offset
            size = tensor.shape[dim]
            index[dim] = slice(0, size - min(size, -amount))
    return padded[tuple(index)]


def crop_or_pad_tensor_by_height_width(tensor, height_crop, width_crop, pad_value=0):
    """Crop or pad a 4-D tensor at the bottom of dim 2 and the right of dim 3.

    Positive values pad (with ``pad_value``); negative values crop. A crop
    larger than the dimension is clamped, yielding an empty dimension.

    Args:
        tensor (torch.Tensor): 4-D tensor, shaped (batch_size, channels,
            height, width).
        height_crop (int): Rows to add (> 0) or remove (< 0) at the bottom.
        width_crop (int): Columns to add (> 0) or remove (< 0) on the right.
        pad_value (float or int): Fill value used when padding. Defaults to 0.

    Returns:
        torch.Tensor: The cropped and/or padded tensor.

    Raises:
        AssertionError: If ``tensor`` is not 4-D.
    """
    assert len(tensor.shape) == 4, '输入的tensor应为4维'
    return _crop_or_pad_trailing_dims(tensor, (height_crop, width_crop), pad_value)


def crop_or_pad_tensor_by_depth_height_width(tensor, depth_crop, height_crop, width_crop, pad_value=0):
    """Crop or pad a 5-D tensor at the end of dims 2, 3 and 4.

    Positive values pad (with ``pad_value``); negative values crop. A crop
    larger than the dimension is clamped, yielding an empty dimension.

    Args:
        tensor (torch.Tensor): 5-D tensor, shaped (batch_size, channels,
            depth, height, width).
        depth_crop (int): Slices to add (> 0) or remove (< 0) at the end of depth.
        height_crop (int): Rows to add (> 0) or remove (< 0) at the bottom.
        width_crop (int): Columns to add (> 0) or remove (< 0) on the right.
        pad_value (float or int): Fill value used when padding. Defaults to 0.

    Returns:
        torch.Tensor: The cropped and/or padded tensor.

    Raises:
        AssertionError: If ``tensor`` is not 5-D.
    """
    assert len(tensor.shape) == 5, '输入的tensor应为5维'
    return _crop_or_pad_trailing_dims(
        tensor, (depth_crop, height_crop, width_crop), pad_value
    )