返回最大值的index pytorch
catbird233 人气:0
返回最大值的index
import torch

# 4x3 matrix whose values increase along every row, so each row's maximum
# sits in the last column (index 2).
a = torch.tensor([
    [.1, .2, .3],
    [1.1, 1.2, 1.3],
    [2.1, 2.2, 2.3],
    [3.1, 3.2, 3.3],
])

# argmax along dim=1 yields one column index per row.
print(a.argmax(dim=1))
# Without `dim`, argmax returns the index into the flattened tensor
# (12 elements here, so the last one is index 11).
print(a.argmax())
输出:
tensor([ 2, 2, 2, 2])
tensor(11)
pytorch 找最大值
题意:使用神经网络实现,从数组中找出最大值。
提供数据:两个 csv 文件,一个存训练集:m 个 n 维特征自然数数据,另一个存每条数据对应的 label ,就是每条数据中的最大值。
这里将随机构建训练集:
#%%
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data


def GetData(m, n):
    """Build a random dataset of ``m`` samples with ``n`` features each.

    Every feature is a random integer in [0, 9]; the label of a sample is the
    maximum of its features.

    Returns:
        A float32 NumPy array of shape ``(m, n + 1)``: the first ``n`` columns
        are the features and the last column is the label.
    """
    dataset = [[random.randint(0, 9) for _ in range(n)] for _ in range(m)]
    label = [max(row) for row in dataset]
    data_list = np.column_stack((dataset, label))
    return data_list.astype(np.float32)


#%%
# Dataset wrapper: overrides __len__ and __getitem__.
class GetMaxEle(Data.Dataset):
    """Dataset over rows laid out as ``[feature_0, ..., feature_k, label]``."""

    def __init__(self, trainset):
        self.data = trainset

    def __getitem__(self, index):
        """Return ``(features, label)`` for the row at ``index``."""
        item = self.data[index]
        x = item[:-1]
        y = item[-1]
        return x, y

    def __len__(self):
        return len(self.data)


# %% Network definition
class SingleNN(nn.Module):
    """Fully connected network with one hidden layer and a ReLU activation."""

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        self.hidden = nn.Linear(n_feature, n_hidden)
        self.relu = nn.ReLU()
        self.predict = nn.Linear(n_hidden, n_output)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for ``x``."""
        x = self.hidden(x)
        x = self.relu(x)
        x = self.predict(x)
        return x


def train(m, n, batch_size, PATH, total_epoch=100):
    """Train ``SingleNN`` on a freshly generated dataset and save its weights.

    Args:
        m: Number of random training samples to generate.
        n: Number of features per sample (also the network's input width).
        batch_size: Mini-batch size for the DataLoader.
        PATH: File path the trained ``state_dict`` is written to.
        total_epoch: Number of passes over the training set.
    """
    # Randomly generate m training samples with n dimensions each.
    data_list = GetData(m, n)
    dataset = GetMaxEle(data_list)
    trainset = Data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # The input width follows n; the 10 outputs are the classes 0..9 that
    # GetData can produce as labels.
    net = SingleNN(n_feature=n, n_hidden=100, n_output=10)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(total_epoch):
        loss = None
        for input_x, labels in trainset:
            # Labels are stored as float32; CrossEntropyLoss needs int64
            # class indices.
            labels = labels.long()
            optimizer.zero_grad()
            output = net(input_x)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
        # loss stays None only when the loader yielded no batches.
        if loss is not None:
            print(f"Epoch {epoch}, loss:{loss.item()}")

    # %% Save the parameters
    torch.save(net.state_dict(), PATH)


# Evaluation
def test(m, n, batch_size, PATH):
    """Load the weights saved at ``PATH`` and print predictions for one batch.

    Args:
        m: Number of random test samples to generate.
        n: Number of features per sample; must match the value used in train.
        batch_size: Size of the single batch that is evaluated.
        PATH: File path of the ``state_dict`` written by ``train``.
    """
    data_list = GetData(m, n)
    dataset = GetMaxEle(data_list)
    testloader = Data.DataLoader(dataset, batch_size=batch_size)
    # Use the built-in next(): DataLoader iterators have no .next() method in
    # current PyTorch releases.
    input_x, labels = next(iter(testloader))

    net = SingleNN(n_feature=n, n_hidden=100, n_output=10)
    net.load_state_dict(torch.load(PATH))
    net.eval()
    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        outputs = net(input_x)
    _, predicted = torch.max(outputs, 1)
    print("Ground_truth:", labels.numpy())
    print("predicted:", predicted.numpy())


# The name starts with "test", so tell pytest not to collect this helper as a
# test case when the module is imported into a test file.
test.__test__ = False


if __name__ == "__main__":
    m = 1000
    n = 10
    batch_size = 64
    PATH = './max_list.pth'
    train(m, n, batch_size, PATH)
    test(m, n, batch_size, PATH)
初始的想法是使用全连接网络+分类来实现,但是结果不尽人意,主要原因:不同类别之间的样本量差太大——标签为 9 的样本约占 65%(1 − 0.9^10 ≈ 0.65),标签为 8 或 9 的样本合计接近 90%。
比如代码中随机构建 10 个 0~9 的数字构成一个样本[2, 3, 5, 8, 9, 5, 3, 9, 3, 6], 该样本标签是9。
以上为个人经验,希望能给大家一个参考,也希望大家多多支持。
加载全部内容