C++ ncnn模型验证精度实现代码
樊城 人气:0验证ncnn模型的精度
1、进行pth模型的验证
得到ncnn模型的顺序为:.pth–>.onnx–>ncnn
.pth的精度验证如下:
如进行的是二分类:
model = init_model(model, data_cfg, device=device, mode='eval') ###.pth转.onnx模型 # #--- # input_names = ["x"] # output_names = ["y"] # inp = torch.randn(1, 3, 256, 128) ##错误示例 inp = np.full((1, 3, 160, 320), 0.5).astype(np.float) #(160,320) = (h,w) inp = torch.FloatTensor(inp) out = model(inp) print(out)
没有经过softmax层,out输出为±1的两个值。
2、转为onnx后的精度验证
sess = onnxruntime.InferenceSession("G:\\pycharm_pytorch171\\pytorch_classification\\main\\sim.onnx", providers=["CUDAExecutionProvider"]) # use gpu input_name = sess.get_inputs()[0].name print("input_name: ", input_name) output_name = sess.get_outputs()[0].name print("output_name: ", output_name) # test_images = torch.rand([1, 3, 256, 128]) test_images = np.full((1, 3, 160, 320), 0.5).astype(np.float) #(160,320) = (h,w) test_images = torch.FloatTensor(test_images) print("test_image", test_images) prediction = sess.run([output_name], {input_name: test_images.numpy()}) print(prediction)
3、ncnn精度验证
首先保证mean、norm输出的值与onnx保持一致,因为onnx直接输入值0.5,ncnn模型经过mean、norm计算后的结果与0.5一致就行。
然后就是ncnn模型的计算输出
- 查看输出结果是否是0.5,首先得将输入值1给到img
```cpp constexpr int w = 320; constexpr int h = 160; float cbuf[h][w]; cv::Mat img(h, w, CV_8UC3,(float *)cbuf); //BYTE* iPtr = new BYTE[128 * 256 * 3]; BYTE* iPtr = new BYTE[h * w * 3]; for (int i = 0; i < h; i++) { for (int j = 0; j < w; j++) { for (int k = 0; k < 3; k++) { //iPtr[i * 256 * 3 + j * 3 + k] = img.at<cv::Vec3f>(i, j)[k]; img.at<cv::Vec3b>(i, j)[k] = 1; } } } ``` - 经过上面的赋值,通过了mean、norm计算后,得到的结果进行查看,值为0.5则正确转换。得到的结果送入下面的代码进行输出。 ncnn结果为mat,因此采用该方法进行遍历查看。 ```cpp //输出ncnn mat void ncnn_mat_print(const ncnn::Mat& m) { for (int q = 0; q < m.c; q++) { const float* ptr = m.channel(q); for (int y = 0; y < m.h; y++) { for (int x = 0; x < m.w; x++) { printf("%f ", ptr[x]); } ptr += m.w; printf("\n"); } printf("------------------------\n"); } } ``` 将mat给到模型进行推理得到结果。
4、结果确认
一般情况下,pth模型与onnx模型结果相差不大,ncnn会有点点损失,千分位上的损失,这样精度基本上是一致的。
若不一致,看哪一步结果相差太大,如果是ncnn这一步相差太大,检查是否是值输入有问题,或者是输入的(h,w)弄反了。
加载全部内容