金融市场中的NLP——情感分析( 四 )
- BERT:bert-base-uncased
- ALBERT:albert-base-v2
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)def train_bert(model, tokenizer)# 移动模型到GUP/CPU设备device = 'cuda:0' if torch.cuda.is_available() else 'cpu'model = model.to(device)# 将数据加载到SimpleDataset(自定义数据集类)train_ds = SimpleDataset(x_train, y_train)valid_ds = SimpleDataset(x_valid, y_valid)# 使用DataLoader批量加载数据集中的数据train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)valid_loader = torch.utils.data.DataLoader(valid_ds, batch_size=batch_size, shuffle=False)# 优化器和学习率衰减num_total_opt_steps = int(len(train_loader) * num_epochs)optimizer = AdamW_HF(model.parameters(), lr=learning_rate, correct_bias=False)scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_total_opt_steps*warm_up_proportion, num_training_steps=num_total_opt_steps)# PyTorch scheduler# 训练model.train()# Tokenizer 参数param_tk = {'return_tensors': "pt",'padding': 'max_length','max_length': max_seq_length,'add_special_tokens': True,'truncation': True}# 初始化best_f1 = 0.early_stop = 0train_losses = []valid_losses = []for epoch in tqdm(range(num_epochs), desc="Epoch"):# print('================epoch {}==============='.format(epoch+1))train_loss = 0.for i, batch in enumerate(train_loader):# 传输到设备x_train_bt, y_train_bt = batchx_train_bt = tokenizer(x_train_bt, **param_tk).to(device)y_train_bt = torch.tensor(y_train_bt, dtype=torch.long).to(device)# 重设梯度optimizer.zero_grad()# 前馈预测loss, logits = model(**x_train_bt, labels=y_train_bt)# 反向传播loss.backward()# 损失train_loss += loss.item() / len(train_loader)# 梯度剪切torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)# 更新权重和学习率optimizer.step()scheduler.step()train_losses.append(train_loss)# 评估模式model.eval()# 初始化val_loss = 0.y_valid_pred = np.zeros((len(y_valid), 3))with torch.no_grad():for i, batch in enumerate(valid_loader):# 传输到设备x_valid_bt, y_valid_bt = batchx_valid_bt = tokenizer(x_valid_bt, **param_tk).to(device)y_valid_bt = torch.tensor(y_valid_bt, dtype=torch.long).to(device)loss, logits = model(**x_valid_bt, labels=y_valid_bt)val_loss += loss.item() / len(valid_loader)valid_losses.append(val_loss)# 计算指标acc, f1 = metric(y_valid, np.argmax(y_valid_pred, axis=1))# 如果改进了 , 保存模型 。 如果没有 , 那就提前停止if best_f1 < f1:early_stop = 0best_f1 = f1else:early_stop += 1print('epoch: %d, train loss: %.4f, valid loss: %.4f, acc: %.4f, f1: %.4f, best_f1: %.4f, last lr: %.6f' %(epoch+1, train_loss, val_loss, acc, f1, best_f1, scheduler.get_last_lr()[0]))if device == 'cuda:0':torch.cuda.empty_cache()# 如果达到耐心数 , 提前停止if early_stop >= patience:break# 返回训练模式model.train()return model
评估首先 , 输入数据以8:2分为训练组和测试集 。 测试集保持不变 , 直到所有参数都固定下来 , 并且每个模型只使用一次 。 由于数据集不用于计算交叉集 , 因此验证集不用于计算 。 此外 , 为了克服数据集不平衡和数据集较小的问题 , 采用分层K-Fold交叉验证进行超参数整定 。
推荐阅读
- 柔性电子市场广阔,领头羊柔宇科技获更多关注
- Eyeware Beam使用iPhone追踪玩家在游戏中的眼睛运动
- 又爆炸!联电科技传来一声巨响,或把8 英寸晶圆市场"炸"了
- 线下市场彻底“乱了”!小米宣布新规!华为捆绑加价行为迎争议
- 腾讯游戏发起对华为的挑战,或因后者对国内手机市场的影响力大跌
- 华为P50 Pro渲染图曝光:曲面瀑布屏
- 苹果中国区下架近5万款游戏应用,手游市场面临大洗牌
- 转转:iPhone 12热销 二手市场5G手机交易看涨
- 市场|iPhoneX用户集中卖手机?转转Q4手机行情:iPhone12引领5G换机潮
- OPPO西欧出货量去年增长三倍 高端市场成头部厂商必争之地