在少样本学习中,用SetFit进行文本分类 译文

2023-11-27
关注

译者 | 陈峻

审校 | 重楼

在本文中,我将向您介绍“少样本(Few-shot)学习”的相关概念,并重点讨论被广泛应用于文本分类的SetFit方法。

传统的机器学习(ML)

在监督(Supervised)机器学习中,大量数据集被用于模型训练,以便磨练模型能够做出精确预测的能力。在完成训练过程之后,我们便可以利用测试数据,来获得模型的预测结果。然而,这种传统的监督学习方法存在着一个显著缺点:它需要大量无差错的训练数据集。但是并非所有领域都能够提供此类无差错数据集。因此,“少样本学习”的概念应运而生。

在深入研究Sentence Transformer fine-tuning(SetFit)之前,我们有必要简要地回顾一下自然语言处理(Natural Language Processing,NLP)的一个重要方面,也就是:“少样本学习”。

少样本学习

少样本学习是指:使用有限的训练数据集,来训练模型。模型可以从这些被称为支持集的小集合中获取知识。此类学习旨在教会少样本模型,辨别出训练数据中的相同与相异之处。例如,我们并非要指示模型将所给图像分类为猫或狗,而是指示它掌握各种动物之间的共性和区别。可见,这种方法侧重于理解输入数据中的相似点和不同点。因此,它通常也被称为元学习(meta-learning)、或是从学习到学习(learning-to-learn)。

值得一提的是,少样本学习的支持集,也被称为k向(k-way)n样本(n-shot)学习。其中“k”代表支持集里的类别数。例如,在二分类(binary classification)中,k 等于 2。而“n”表示支持集中每个类别的可用样本数。例如,如果正分类有10个数据点,而负分类也有10个数据点,那么 n就等于10。总之,这个支持集可以被描述为双向10样本学习。

既然我们已经对少样本学习有了基本的了解,下面让我们通过使用SetFit进行快速学习,并在实际应用中对电商数据集进行文本分类。

SetFit架构

由Hugging Face和英特尔实验室的团队联合开发的SetFit,是一款用于少样本照片分类的开源工具。你可以在项目库链接--https://github.com/huggingface/setfit?ref=hackernoon.com中,找到关于SetFit的全面信息。

就输出而言,SetFit仅用到了客户评论(Customer Reviews,CR)情感分析数据集里、每个类别的八个标注示例。其结果就能够与由三千个示例组成的完整训练集上,经调优的RoBERTa Large的结果相同。值得强调的是,就体积而言,经微优的RoBERTa模型比SetFit模型大三倍。下图展示的是SetFit架构:

图片来源:https://www.sbert.net/docs/training/overview.html?ref=hackernoon.com

用SetFit实现快速学习

SetFit的训练速度非常快,效率也极高。与GPT-3和T-FEW等大模型相比,其性能极具竞争力。请参见下图:

SetFit与T-Few 3B模型的比较

如下图所示,SetFit在少样本学习方面的表现优于RoBERTa。

SetFit与RoBERT的比较,图片来源:https://huggingface.co/blog/setfit?ref=hackernoon.com

数据集

下面,我们将用到由四个不同类别组成的独特电商数据集,它们分别是:书籍、服装与配件、电子产品、以及家居用品。该数据集的主要目的是将来自电商网站的产品描述归类到指定的标签下。

为了便于采用少样本的训练方法,我们将从四个类别中各选择八个样本,从而得到总共32个训练样本。而其余样本则将留作测试之用。简言之,我们在此使用的支持集是4向8样本学习。下图展示的是自定义电商数据集的示例:

自定义电商数据集样本

我们采用名为“all-mpnet-base-v2”的Sentence Transformers预训练模型,将文本数据转换为各种向量嵌入。该模型可以为输入文本,生成维度为768的向量嵌入。

如下命令所示,我们将通过在conda环境(是一个开源的软件包管理系统和环境管理系统)中安装所需的软件包,来开始SetFit的实施。

!pip3 install SetFit 
!pip3 install sklearn 
!pip3 install transformers 
!pip3 install sentence-transformers

安装完软件包后,我们便可以通过如下代码加载数据集了。

from datasets import load_dataset
dataset = load_dataset('csv', data_files={
"train": 'E_Commerce_Dataset_Train.csv',
"test": 'E_Commerce_Dataset_Test.csv'
})

我们来参照下图,看看训练样本和测试样本数。

训练和测试数据

我们使用sklearn软件包中的LabelEncoder,将文本标签转换为编码标签。

from sklearn.preprocessing import LabelEncoder 
le = LabelEncoder()

通过LabelEncoder,我们将对训练和测试数据集进行编码,并将编码后的标签添加到数据集的“标签”列中。请参见如下代码:

Encoded_Product = le.fit_transform(dataset["train"]['Label']) 
dataset["train"] = dataset["train"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["train"].features)
Encoded_Product = le.fit_transform(dataset["test"]['Label']) 
dataset["test"] = dataset["test"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["test"].features)

下面,我们将初始化SetFit模型和句子转换器(sentence-transformers)模型。

from setfit import SetFitModel, SetFitTrainer 
from sentence_transformers.losses import CosineSimilarityLoss
model_id = "sentence-transformers/all-mpnet-base-v2" 
model = SetFitModel.from_pretrained(model_id)
trainer = SetFitTrainer( 
 model=model,
 train_dataset=dataset["train"],
 eval_dataset=dataset["test"],
 loss_class=CosineSimilarityLoss,
 metric="accuracy",
 batch_size=64,
 num_iteratinotallow=20,
 num_epochs=2,
 column_mapping={"Text": "text", "Label": "label"}
)

初始化完成两个模型后,我们现在便可以调用训练程序了。

trainer.train()

在完成了2个训练轮数(epoch)后,我们将在eval_dataset上,对训练好的模型进行评估。

trainer.evaluate()

经测试,我们的训练模型的最高准确率为87.5%。虽然87.5%的准确率并不算高,但是毕竟我们的模型只用了32个样本进行训练。也就是说,考虑到数据集规模的有限性,在测试数据集上取得87.5%的准确率,实际上是相当可观的。

此外,SetFit还能够将训练好的模型,保存到本地存储器中,以便后续从磁盘加载,用于将来的预测。

trainer.model._save_pretrained(save_directory="SetFit_ECommerce_Output/")
model=SetFitModel.from_pretrained("SetFit_ECommerce_Output/", local_files_notallow=True)

如下代码展示了根据新的数据进行的预测结果:

input = ["Campus Sutra Men's Sports Jersey T-Shirt Cool-Gear: Our Proprietary Moisture Management technology. Helps to absorb and evaporate sweat quickly. Keeps you Cool & Dry. Ultra-Fresh: Fabrics treated with Ultra-Fresh Antimicrobial Technology. Ultra-Fresh is a trademark of (TRA) Inc, Ontario, Canada. Keeps you odour free."]
output = model(input)

可见,其预测输出为1,而标签的LabelEncoded值为“服装与配件”。由于传统的AI模型需要大量的训练资源(包括时间和数据),才能有稳定水准的输出。而我们的模型与之相比,既准确又高效。

至此,相信您已经基本掌握了“少样本学习”的概念,以及如何使用SetFit来进行文本分类等应用。当然,为了获得更深刻的理解,我强烈建议您选择一个实际场景,创建一个数据集,编写对应的代码,并将该过程延展到零样本学习、以及单样本学习上。

译者介绍

陈峻(Julian Chen),51CTO社区编辑,具有十多年的IT项目实施经验,善于对内外部资源与风险实施管控,专注传播网络与信息安全知识与经验。

原文标题:Mastering Few-Shot Learning with SetFit for Text Classification,作者:Shyam Ganesh S)


您觉得本篇内容如何
评分

相关产品

Honeywell 霍尼韦尔智能工业 在线/便携烟气分析仪专用传感器 气体传感器

CO 传感器;SO2传感器;NO2 传感器;NO传感器;氧气传感器

南方泰科 TGM 压力传感器

TGM是一款SOP8封装的压阻式MEMS压力传感器,其压力传感器芯片封装在 SOP8 塑封壳内。在传感器压力量程内,当用固定电压供电时,传感器产生毫伏输出电压,正比于输入压力。压力传感器芯片为绝压,可提供不同的压力量程的SOP8 压力传感器。

Huba Control 富巴 525系列 压力传感器

525系列压力传感器采用集公司20多年研发经验的陶瓷压力传感器芯片技术。该系列压力传感器可选压力范围大,电气连接形式多。最小量程为50mbar。大批量使用具有很好的性价比。

Cubic 四方光电 PM3009BP 室外粉尘传感器

PM3009BP是一款专门针对餐饮油烟监测的油烟传感器,其采用旁流采样方式,自带除水雾装置,结合智能颗粒物识别算法,确保传感器能够快速准确的检测油烟浓度的变化,同时创新的镜头自清洁技术的应用,能够长效防护传感器油烟污染,大幅度延长传感器的使用寿命。

Winsen 炜盛科技 MH-410D 红外CO2气体传感器 红外传感器

MH-410D红外气体传感器是通用型、智能型、微型传感器,该红外传感器利用非色散红外(NDIR)原理对空气中存在的CO2进行探测,具有很好的选择性,无氧气依赖性,性能稳定、寿命长。内置温度补偿。该红外传感器是通过将成熟的红外吸收气体检测技术与微型机械加工、精良电路设计紧密结合而制作出的小巧型高性能红外传感器。该红外传感器可广泛应用于暖通制冷与室内空气质量监控、工业过程及安全防护监控、农业及畜牧业生产过程监控。

Alliance 莱恩&联众传感线缆 Aurora Tool Cable 医疗电线 医疗线缆

用于连接两个5DOF传感器或一个6DOF传感器的电缆。 可重复使用 用于电磁跟踪系统

RAYCOH 锐科智能 30GM系列 IO-Link输出 2EP-IO,IUEP-IO 超声波测距传感器和接近开关

RAYCOH 锐科智能30GM系列 IO-Link输出 超声波线性位置传感器和开关

评论

您需要登录才可以回复|注册

提交评论

选型助手

这家伙很懒,什么描述也没留下

关注

点击进入下一篇

2023年最新《瓦森纳协定》

提取码
复制提取码
点击跳转至百度网盘