前沿追踪|基于TorchText的PyTorch文本分类( 二 )


# Select the computation device: prefer GPU (CUDA) when available, else CPU.
# (The scraped source had a stray trailing `device` token — a notebook-cell
# echo — fused onto this line, which is a syntax error in a script; removed.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
在下一步中，我们将定义分类的模型。
class TextSentiment(nn.Module):
    """Bag-of-embeddings text classifier.

    Pools (via ``nn.EmbeddingBag``, mean mode by default) the embeddings of
    all tokens in each document and passes the pooled vector through a single
    linear layer to produce per-class scores (logits).
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # sparse=True produces sparse gradients for the embedding table;
        # note this restricts the choice of optimizer (plain SGD works).
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        # Uniform init in [-0.5, 0.5] for weights; zero the classifier bias.
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        # text: 1-D tensor of the concatenated token ids of the whole batch;
        # offsets: 1-D tensor of each document's start index within `text`.
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
# NOTE(review): the scraped source had a stray `print(model)` fused onto the
# end of this class definition, but `model` is only created later in the
# article — removed to avoid a NameError.
现在，我们将初始化超参数并定义用于生成训练 batch 的函数。
# Hyper-parameters derived from the dataset built earlier in the article.
VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUM_CLASS = len(train_dataset.get_labels())  # fixed typo: was NUN_CLASS

model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)


def generate_batch(batch):
    """Collate function for the DataLoader.

    Each entry of ``batch`` is a ``(label, token_id_tensor)`` pair. Returns
    the batch's token ids concatenated into one 1-D tensor, the per-document
    start offsets expected by ``nn.EmbeddingBag``, and the label tensor.
    """
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    # Offsets are the cumulative start positions of each document inside the
    # concatenated tensor; the final cumulative length is dropped ([:-1]).
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text, offsets, label
def train_func(sub_train_):
    """Run one training epoch over ``sub_train_``.

    Returns ``(mean loss per example, accuracy)`` for the epoch and steps the
    LR scheduler once at the end.
    """
    train_loss = 0
    train_acc = 0
    # NOTE(review): in the scraped source this DataLoader call was mangled by
    # a pasted URL (`data = http://.../DataLoader(...)`); restored to a plain
    # DataLoader(...) call.
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()

    # Decay the learning rate once per epoch.
    scheduler.step()

    return train_loss / len(sub_train_), train_acc / len(sub_train_)


def test(data_):
    """Evaluate the model on ``data_``; return ``(mean loss, accuracy)``."""
    total_loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            # Bug fix: the original rebound `loss` to the criterion output and
            # then did `loss += loss.item()`, clobbering the running total so
            # only the last batch's loss survived.
            batch_loss = criterion(output, cls)
            total_loss += batch_loss.item()
            acc += (output.argmax(1) == cls).sum().item()

    return total_loss / len(data_), acc / len(data_)
# Training configuration.
N_EPOCHS = 5
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
# SGD is required here because the embedding uses sparse gradients.
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
# Multiply the LR by 0.9 after every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

# Hold out 5% of the training data for validation.
train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train_func(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)

    secs = int(time.time() - start_time)
    # Integer minutes; the original used float division, which the %d format
    # truncated anyway — // makes the intent explicit.
    mins = secs // 60
    secs = secs % 60

    print('Epoch: %d' % (epoch + 1),
          " | time in %d minutes, %d seconds" % (mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')


推荐阅读