悟空收录网

在少样本学习中,用SetFit进行文本分类


文章编号:395 / 更新时间:2023-11-30 18:00:20 / 浏览:

在监督(Supervised)机器学习中,大量数据集被用于模型训练,以便磨练模型能够做出精确预测的能力。在完成训练过程之后,我们便可以利用测试数据,来获得模型的预测结果。然而,这种传统的监督学习方法存在着一个显著缺点:它需要大量无差错的训练数据集。但是并非所有领域都能够提供此类无差错数据集。因此,“少样本学习”的概念应运而生。

在深入研究SentenceTransformerfine-tuning(SetFit)之前,我们有必要简要地回顾一下自然语言处理(NaturalLanguageProcessing,NLP)的一个重要方面,也就是:“少样本学习”。

少样本学习是指:使用有限的训练数据集,来训练模型。模型可以从这些被称为支持集的小集合中获取知识。此类学习旨在教会少样本模型,辨别出训练数据中的相同与相异之处。例如,我们并非要指示模型将所给图像分类为猫或狗,而是指示它掌握各种动物之间的共性和区别。可见,这种方法侧重于理解输入数据中的相似点和不同点。因此,它通常也被称为元学习(meta-learning)、或是从学习到学习(learning-to-learn)。

值得一提的是,少样本学习的支持集,也被称为k向(k-way)n样本(n-shot)学习。其中“k”代表支持集里的类别数。例如,在二分类(binaryclassification)中,k等于2。而“n”表示支持集中每个类别的可用样本数。例如,如果正分类有10个数据点,而负分类也有10个数据点,那么n就等于10。总之,这个支持集可以被描述为双向10样本学习。

既然我们已经对少样本学习有了基本的了解,下面让我们通过使用SetFit进行快速学习,并在实际应用中对电商数据集进行文本分类。

由HuggingFace和英特尔实验室的团队联合开发的SetFit,是一款用于少样本照片分类的开源工具。你可以在项目库链接--https://github.com/huggingface/setfit?ref=hackernoon.com中,找到关于SetFit的全面信息。

SetFit的训练速度非常快,效率也极高。与GPT-3和T-FEW等大模型相比,其性能极具竞争力。请参见下图:

SetFit与T-Few3B模型的比较

如下图所示,SetFit在少样本学习方面的表现优于RoBERTa。

为了便于采用少样本的训练方法,我们将从四个类别中各选择八个样本,从而得到总共32个训练样本。而其余样本则将留作测试之用。简言之,我们在此使用的支持集是4向8样本学习。下图展示的是自定义电商数据集的示例:

自定义电商数据集样本

我们采用名为“all-mpnet-base-v2”的SentenceTransformers预训练模型,将文本数据转换为各种向量嵌入。该模型可以为输入文本,生成维度为768的向量嵌入。

如下命令所示,我们将通过在conda环境(是一个开源的软件包管理系统和环境管理系统)中安装所需的软件包,来开始SetFit的实施。

!pip3installSetFit!pip3installsklearn!pip3installtransformers!pip3installsentence-transformers

安装完软件包后,我们便可以通过如下代码加载数据集了。

fromdatasetsimportload_datasetdataset=load_dataset('csv',data_files={"train":'E_Commerce_Dataset_Train.csv',"test":'E_Commerce_Dataset_Test.csv'})

我们来参照下图,看看训练样本和测试样本数。

训练和测试数据

Encoded_Product=le.fit_transform(dataset["train"]['Label'])dataset["train"]=dataset["train"].remove_columns("Label").add_column("Label",Encoded_Product).cast(dataset["train"].features)Encoded_Product=le.fit_transform(dataset["test"]['Label'])dataset["test"]=dataset["test"].remove_columns("Label").add_column("Label",Encoded_Product).cast(dataset["test"].features)

下面,我们将初始化SetFit模型和句子转换器(sentence-transformers)模型。

fromsetfitimportSetFitModel,SetFitTrainerfromsentence_transformers.lossesimportCosineSimilarityLossmodel_id="sentence-transformers/all-mpnet-base-v2"model=SetFitModel.from_pretrained(model_id)trainer=SetFitTrainer(model=model,train_dataset=dataset["train"],eval_dataset=dataset["test"],loss_class=CosineSimilarityLoss,metric="accuracy",batch_size=64,num_iteratinotallow=20,num_epochs=2,column_mapping={"Text":"text","Label":"label"})

初始化完成两个模型后,我们现在便可以调用训练程序了。

trainer.train()

在完成了2个训练轮数(epoch)后,我们将在eval_dataset上,对训练好的模型进行评估。

trainer.evaluate()

经测试,我们的训练模型的最高准确率为87.5%。虽然87.5%的准确率并不算高,但是毕竟我们的模型只用了32个样本进行训练。也就是说,考虑到数据集规模的有限性,在测试数据集上取得87.5%的准确率,实际上是相当可观的。

此外,SetFit还能够将训练好的模型,保存到本地存储器中,以便后续从磁盘加载,用于将来的预测。

trainer.model._save_pretrained(save_directory="SetFit_ECommerce_Output/")model=SetFitModel.from_pretrained("SetFit_ECommerce_Output/",local_files_notallow=True)

如下代码展示了根据新的数据进行的预测结果:

至此,相信您已经基本掌握了“少样本学习”的概念,以及如何使用SetFit来进行文本分类等应用。当然,为了获得更深刻的理解,我强烈建议您选择一个实际场景,创建一个数据集,编写对应的代码,并将该过程延展到零样本学习、以及单样本学习上。

北京市海淀区中关村南1条甲1号ECO中科爱克大厦6-7层

北京市公安局海淀分局备案编号:110108002980号营业执照

我关注的话题
相关标签: 机器学习少样本学习SetFit

本文地址:http://www.wkong.net/article-395.html

上一篇:Spring到底是如何解决循环依赖问题的?​...
下一篇:滴滴崩了18小时,事故危机谁买单...

发表评论

温馨提示

做上本站友情链接,在您站上点击一次,即可自动收录并自动排在本站第一位!
<a href="http://www.wkong.net/" target="_blank">悟空收录网</a>