69、ncnn 学习：onnx2ncnn 不支持的三维张量相乘（Gemm/MatMul）与 Repeat 算子的转换方法
基本思想：学习如何转换 onnx2ncnn 不支持的、带 channel 维度的 Mat 相乘（三维 MatMul）。
一、测试代码
import cv2
import torch
from torch import nn
import onnxruntime as ort
import ncnn
import numpy as np
class network(nn.Module):
    """Toy model that multiplies its squeezed input by itself.

    It is used to produce a MatMul on a 3-D (channel-carrying) tensor in the
    exported ONNX graph.

    ``torch.squeeze`` is called without ``dim``, so it removes *every* size-1
    axis. A (1, 3, 6, 6) input therefore becomes (3, 6, 6), and
    ``torch.matmul`` performs a batched product over the leading axis.

    NOTE(review): an input with other singleton axes (e.g. (1, 1, 6, 6))
    collapses further, here to a plain 2-D (6, 6) product.
    """

    def __init__(self):
        super(network, self).__init__()

    def forward(self, x):
        """Return ``matmul(squeeze(x), squeeze(x))``."""
        a = torch.squeeze(x)
        b = torch.squeeze(x)
        y = torch.matmul(a, b)
        return y


net = network()
# Random 4-D input of shape (1, 3, 6, 6); the network squeezes it to (3, 6, 6).
mat_data = np.random.rand(1, 3, 6, 6).astype(np.float32)
print(mat_data)

# from_numpy keeps the float32 dtype set above.
# NOTE(review): the original note maps this layout to ncnn as c,d,h,w — confirm.
img_tensor = torch.from_numpy(mat_data)
print(img_tensor.shape)

result = net(img_tensor)
print("torch result :", result, result.shape)

# Round-trip the whole pickled module through a file before export.
torch.save(net, "example.pt")
try:
    # The checkpoint is a full nn.Module, not a state_dict. PyTorch >= 2.6
    # defaults to weights_only=True and would refuse to unpickle it.
    # weights_only=False is acceptable here only because this script just
    # wrote the file itself; never do this for untrusted checkpoints.
    model = torch.load("example.pt", weights_only=False)
except TypeError:
    # Older PyTorch (< 1.13) has no weights_only keyword.
    model = torch.load("example.pt")
model.eval()
torch_out = torch.onnx.export(model, img_tensor, "example.onnx", verbose=True, input_names=["input"],output_names=["output"], opset_version=11)ort_session = or