MatCha
概述
MatCha 已在论文 MatCha: Enhancing Visual Language Pretraining with Math Reasoning and Chart Derendering 中提出,作者包括 Fangyu Liu, Francesco Piccinno, Syrine Krichene, Chenxi Pang, Kenton Lee, Mandar Joshi, Yasemin Altun, Nigel Collier, Julian Martin Eisenschlos。
论文的摘要陈述如下:
视觉语言数据,如图表、图表和信息图表,在人类世界中无处不在。然而,最先进的视觉语言模型在这些数据上表现不佳。我们提出了MatCha(数学推理和图表解析预训练)来增强视觉语言模型在联合建模图表/图表和语言数据方面的能力。具体来说,我们提出了几个预训练任务,涵盖了图表解构和数值推理,这些是视觉语言建模中的关键能力。我们从Pix2Struct开始进行MatCha预训练,Pix2Struct是最近提出的图像到文本视觉语言模型。在PlotQA和ChartQA等标准基准测试中,MatCha模型的性能比最先进的方法高出近20%。我们还研究了MatCha预训练在截图、教科书图表和文档图表等领域的迁移效果,并观察到整体改进,验证了MatCha预训练在更广泛的视觉语言任务中的有用性。
模型描述
MatCha 是一个使用 Pix2Struct
架构训练的模型。你可以在 Pix2Struct 文档 中找到更多关于 Pix2Struct
的信息。
MatCha 是 Pix2Struct
架构的视觉问答子集。它将输入的问题渲染在图像上并预测答案。
用法
目前MatCha有6个检查点可用:
google/matcha
: 基础的MatCha模型,用于在下游任务上微调MatChagoogle/matcha-chartqa
: 在ChartQA数据集上微调的MatCha模型。它可以用于回答关于图表的问题。google/matcha-plotqa-v1
: 在PlotQA数据集上微调的MatCha模型。它可以用于回答关于图表的问题。google/matcha-plotqa-v2
: 在PlotQA数据集上微调的MatCha模型。它可以用于回答关于图表的问题。google/matcha-chart2text-statista
: MatCha 模型在 Statista 数据集上进行了微调。google/matcha-chart2text-pew
: MatCha 模型在 Pew 数据集上进行了微调。
在chart2text-pew
和chart2text-statista
上微调的模型更适合于摘要生成,而在plotqa
和chartqa
上微调的模型更适合于问答。
你可以如下使用这些模型(以ChatQA数据集为例):
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
import requests
from PIL import Image
model = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-chartqa").to(0)
processor = AutoProcessor.from_pretrained("google/matcha-chartqa")
url = "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/20294671002019.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="Is the sum of all 4 places greater than Laos?", return_tensors="pt").to(0)
predictions = model.generate(**inputs, max_new_tokens=512)
print(processor.decode(predictions[0], skip_special_tokens=True))
微调
要微调MatCha,请参考pix2struct的微调笔记本。对于Pix2Struct
模型,我们发现使用Adafactor和余弦学习率调度器微调模型可以更快地收敛:
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup
optimizer = Adafactor(self.parameters(), scale_parameter=False, relative_step=False, lr=0.01, weight_decay=1e-05)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=40000)
MatCha 是一个使用 Pix2Struct
架构训练的模型。你可以在 Pix2Struct 文档 中找到更多关于 Pix2Struct
的信息。