Explaining Image Captioning (Image to Text) using Azure Cognitive Services and Partition Explainer

This notebook demonstrates how to use SHAP to explain the output of an image captioning model, i.e. given an image, the model produces a caption for that image.

Here, we use the Azure Cognitive Services Computer Vision (COGS CV) image understanding (Analyze Image) feature https://azure.microsoft.com/en-us/services/cognitive-services/computer-vision/#features to get the image captions.

Limitations

  1. To explain image captions, we partition the image along its axes (i.e. superpixels/partitions into halves, quarters, eighths, ...); an alternative approach/future improvement could be to segment the image semantically instead of using axis-aligned partitions, and to generate the SHAP explanations with segments instead of superpixels. https://github.com/shap/shap/issues/1738

  2. We are using a transformer language model (e.g. distilbart) to score the alignment between the caption of the given image and the captions of the masked images, under the assumption that this external model is a good surrogate for the captioning model's own language head. This assumption, and the dependency, could be removed by using the captioning model's own language head instead (e.g., refer to the text2text notebook examples). For more details, see the 'Load Language Model and Tokenizer' section below. https://github.com/shap/shap/issues/1739

  3. Azure Cognitive Services is used here to get the image captions. To get explanations faster, use the paid service, since API calls are then not rate limited. Pricing details can be found here: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/computer-vision/

  4. Azure Cognitive Services places limits on image size and file format. API details can be found here: https://westcentralus.dev.cognitive.microsoft.com/docs/services/computer-vision-v3-1-ga/operations/56f91f2e778daf14a499f21b. See the 'Load Data' section below for more details.

  5. Large images slow down the generation of SHAP explanations. Hence, images with either dimension larger than 500 pixels are resized. See the 'Load Data' section below for more details.

  6. The more evaluations used to generate an explanation, the longer SHAP takes to run. However, increasing the number of evaluations increases the granularity of the explanation (300-500 evaluations usually produce detailed maps, but fewer or more evaluations are often also reasonable). See the 'Create an explainer object using wrapped model and image masker' section below for more details.

[9]:
import json
import os
from collections import defaultdict

import numpy as np
import requests
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import shap
from shap.utils.image import (
    add_sample_images,
    check_valid_image,
    display_grid_plot,
    is_empty,
    load_image,
    make_dir,
    resize_image,
    save_image,
)

API details

To use Azure COGS CV and run this notebook, get the API key and endpoint associated with your Azure COGS CV subscription and replace the <> placeholders in the code below. Using the paid service rather than the free one is recommended, so that API calls are not rate limited and explanations can be obtained quickly.

[10]:
# place your Azure COGS CV subscription API Key and Endpoint below
API_KEY = "<your COGS API access key>"
ENDPOINT = "<endpoint specific to your subscription>"

ANALYZE_URL = ENDPOINT + "/vision/v3.1/analyze"
[11]:
def get_caption(path_to_image):
    """Function to get image caption when path to image file is given.
    Note: API_KEY and ANALYZE_URL need to be defined before calling this function.

    Parameters
    ----------
    path_to_image   : path of image file to be analyzed

    Output
    -------
    image caption

    """
    headers = {
        "Ocp-Apim-Subscription-Key": API_KEY,
        "Content-Type": "application/octet-stream",
    }
    params = {
        "visualFeatures": "Description",
        "language": "en",
    }

    # read the image file as bytes for the request payload
    with open(path_to_image, "rb") as image_file:
        payload = image_file.read()

    # get image caption using requests
    response = requests.post(ANALYZE_URL, headers=headers, params=params, data=payload)
    results = json.loads(response.content)

    # return the first caption's text in results description
    caption = results["description"]["captions"][0]["text"]

    return caption
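For reference, here is an abridged, illustrative example (the values are made up) of the JSON shape returned by the Analyze Image call with visualFeatures=Description, showing the path that get_caption reads the caption from:

# abridged, illustrative Analyze Image response (values are made up)
example_results = {
    "description": {
        "tags": ["person", "glasses"],
        "captions": [{"text": "a woman wearing glasses", "confidence": 0.92}],
    },
    "requestId": "<request id>",
    "metadata": {"height": 224, "width": 224, "format": "Jpeg"},
}
# get_caption returns the first caption's text:
print(example_results["description"]["captions"][0]["text"])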

Load Data

'./test_images/' is the folder of images that will be explained. The './test_images/' directory has already been created for you, and the sample images needed for the examples shown in this notebook have been placed in it.

Azure COGS CV image requirements:

  1. To get explanations for image captions, place the images to be explained in a folder named 'test_images' in the current notebook working directory.

  2. Azure COGS CV accepts images in the following file formats: JPEG (JPG), PNG, GIF, BMP, JFIF

  3. Azure COGS CV has size limits for images: files must be smaller than 4MB and at least 50x50 pixels. Hence, large image files are reshaped in the code below, both to speed up the SHAP explanations and to run Azure COGS for image captioning. If either dimension (in pixels) of an image is greater than 500, that dimension is resized to at most 500 pixels and the other dimension is adjusted to keep the original aspect ratio (a rough sketch of this rule is shown after the note below).

    Note: The caption of a reshaped image may differ from the caption of the original image. If explanations of the original image are needed, set the 'reshape' variable below to 'False'. Be aware, however, that this may significantly slow down the explanation process, or Azure COGS CV may fail to produce a caption (in which case SHAP will not be able to generate an explanation for that image).
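A rough sketch of the resizing rule described above, for illustration only (the notebook itself uses shap.utils.image.resize_image; the helper name below is hypothetical and not part of the pipeline):

from PIL import Image

def resize_if_large(path_in, path_out, max_dim=500):
    # hypothetical helper: shrink the larger side to max_dim and keep the aspect ratio
    img = Image.open(path_in)
    width, height = img.size
    if max(width, height) <= max_dim:
        return path_in  # already small enough; keep the original file
    scale = max_dim / max(width, height)
    img.resize((round(width * scale), round(height * scale))).save(path_out)
    return path_out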

[12]:
# directory of images to be explained
DIR = "./test_images/"
# creates or empties directory if it already exists
make_dir(DIR)
add_sample_images(DIR)

# directory for saving resized images
DIR_RESHAPED = "./reshaped_images/"
make_dir(DIR_RESHAPED)

# directory for saving masked images
DIR_MASKED = "./masked_images/"
make_dir(DIR_MASKED)

Note: Replace or add the images you would like to be explained (tested) in the './test_images/' folder.

[13]:
# check if 'test_images' folder exists and if it has any files
if not is_empty(DIR):
    X = []
    reshape = True
    files = [f for f in os.listdir(DIR) if os.path.isfile(os.path.join(DIR, f))]

    for file in files:
        path_to_image = os.path.join(DIR, file)

        # check if file has any of the following acceptable extensions: JPEG (JPG), PNG, GIF, BMP, JFIF
        if check_valid_image(file):
            print("\nLoading image:", file)
            print("Image caption:", get_caption(path_to_image))
            image = load_image(path_to_image)
            print("Image size:", image.shape)

            # reshaping large image files
            if reshape:
                image, reshaped_file = resize_image(path_to_image, DIR_RESHAPED)
                if reshaped_file:
                    print("Reshaped image caption:", get_caption(reshaped_file))

            X.append(image)
        else:
            print("\nSkipping image due to invalid file extension:", file)

    print("\nNumber of images in test dataset:", len(X))

# delete DIR_RESHAPED if empty
if not os.listdir(DIR_RESHAPED):
    os.rmdir(DIR_RESHAPED)

Loading image: 1.jpg
Image caption: a woman wearing glasses
Image size: (224, 224, 3)

Loading image: 2.jpg
Image caption: a bird on a branch
Image size: (224, 224, 3)

Loading image: 3.jpg
Image caption: a group of horses standing on grass
Image size: (224, 224, 3)

Loading image: 4.jpg
Image caption: a basketball player in a uniform
Image size: (224, 224, 3)

Number of images in test dataset: 4

Load Language Model and Tokenizer

The transformer language model 'distilbart' and its tokenizer are used here to tokenize the image captions. This makes the image-to-text scenario similar to a multi-class problem. 'distilbart' is used to score the alignment between the original image caption and the captions generated for the masked images, i.e. how does the probability of the original image caption change when conditioned on the caption of a masked image? (In other words, we force 'distilbart' to always generate the original image caption for the masked images and, as part of that process, obtain the change in log odds for each tokenized word in the caption; a rough sketch of this scoring idea follows the next code cell.)

Note: We use 'distilbart' here because, during experimentation, we found that it gives the most meaningful explanations for the images. We compared it with other language models such as 'openaigpt' and 'distilgpt2'. Feel free to explore other language models of your choice and compare the results.

[14]:
# load transformer language model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-12-6").cuda()
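As a rough illustration of the alignment-scoring idea described above (a minimal sketch for intuition only; the two captions are made up, and the actual per-token scoring in this notebook is handled by shap.models.TeacherForcingLogits in the next section):

import torch

original_caption = "a group of horses standing on grass"  # hypothetical original caption
masked_caption = "a field of grass"                       # hypothetical caption of a masked image

inputs = tokenizer([masked_caption], return_tensors="pt").to(model.device)
labels = tokenizer([original_caption], return_tensors="pt").input_ids.to(model.device)

with torch.no_grad():
    # teacher forcing: the decoder is fed the original caption and we read off the
    # log-probability assigned to each of its tokens given the masked-image caption
    logits = model(**inputs, labels=labels).logits
    log_probs = torch.log_softmax(logits, dim=-1)
    token_log_probs = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

print(token_log_probs)  # per-token log-likelihood of the original caption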

Create an explainer object using wrapped model and image masker

Various options to experiment with for the explainer object:

  1. mask_value : By default the image masker uses an inpainting technique for masking (i.e. mask_value = "inpaint_ns"). Other masking options are available for blurring/inpainting, such as "inpaint_telea" and "blur(kernel_xsize, kernel_ysize)"; a short sketch of these variants follows this list. Note: different masking options can produce different explanations.

  2. max_evals : Number of evaluations of the underlying model used to obtain the SHAP values. The recommended number of evaluations is 300-500 to get explanations with a meaningful granularity of superpixels. More evaluations give finer granularity but also increase the run time. The default is set to 300 evaluations.

  3. batch_size : Number of masked images evaluated at once. The default size is set to 50.

  4. fixed_context : Masking technique used to build the partition tree, with options of '0', '1' or 'None'. 'fixed_context = None' is the best option for generating meaningful results, but it is slower than fixed_context = 0 or 1 because it builds the full partition tree. The default option is set to 'None'.
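An illustrative sketch of the mask_value variants listed above (for illustration only; the run_masker helper defined in the next cell constructs its own masker from the mask_value argument):

# illustrative masker variants; the shape is taken from the first test image
inpaint_ns_masker = shap.maskers.Image("inpaint_ns", X[0].shape)        # default: inpainting
inpaint_telea_masker = shap.maskers.Image("inpaint_telea", X[0].shape)  # alternative inpainting
blur_masker = shap.maskers.Image("blur(128, 128)", X[0].shape)          # blurring with a 128x128 kernel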

[15]:
# setting values for logging/tracking variables
make_dir(DIR_MASKED)
image_counter = 0
mask_counter = 0
masked_captions = defaultdict(list)
masked_files = defaultdict(list)


# define function f which takes input (masked image) and returns caption for it
def f(x):
    """ "
    Function to return caption for masked image(x).
    """
    global mask_counter

    # saving masked array of RGB values as an image in masked_images directory
    path_to_image = os.path.join(DIR_MASKED, f"{image_counter}_{mask_counter}.png")
    save_image(x, path_to_image)

    # get caption for masked image
    caption = get_caption(path_to_image)
    masked_captions[image_counter].append(caption)
    masked_files[image_counter].append(path_to_image)
    mask_counter += 1

    return caption


# function to take a list of images and parameters such as masking option, max evals etc. and return shap_values objects
def run_masker(
    X,
    mask_value="inpaint_ns",
    max_evals=300,
    batch_size=50,
    fixed_context=None,
    show_grid_plot=False,
    limit_grid=20,
):
    """Function to take a list of images and parameters such max evals etc. and return shap explanations (shap_values) for test images(X).
    Paramaters
    ----------
    X               : list of images which need to be explained
    mask_value      : various masking options for blurring/inpainting such as "inpaint_ns", "inpaint_telea" and "blur(pixel_size, pixel_size)"
    max_evals       : number of evaluations done of the underlying model to get SHAP values
    batch_size      : number of masked images to be evaluated at once
    fixed_context   : masking technqiue used to build partition tree with options of '0', '1' or 'None'
    show_grid_plot  : if set to True, shows grid plot of all masked images and their captions used to generate SHAP values (default: False)
    limit_grid      : limit number of masked images shown (default:20). Change to "all" to show all masked_images.
    Output
    ------
    shap_values_list: list of shap_values objects generated for the images
    """
    global image_counter
    global mask_counter
    shap_values_list = []

    for index in range(len(X)):
        # define a masker that is used to mask out partitions of the input image based on mask_value option
        masker = shap.maskers.Image(mask_value, X[index].shape)

        # wrap model with TeacherForcingLogits class
        wrapped_model = shap.models.TeacherForcingLogits(
            f, similarity_model=model, similarity_tokenizer=tokenizer
        )

        # build a partition explainer with wrapped_model and image masker
        explainer = shap.Explainer(wrapped_model, masker)

        # compute SHAP values - here we use max_evals no. of evaluations of the underlying model to estimate SHAP values
        shap_values = explainer(
            np.array(X[index : index + 1]),
            max_evals=max_evals,
            batch_size=batch_size,
            fixed_context=fixed_context,
        )
        shap_values_list.append(shap_values)

        # output plot
        shap_values.output_names[0] = [
            word.replace("Ġ", "") for word in shap_values.output_names[0]
        ]
        shap.image_plot(shap_values)

        # show grid plot of masked images and their captions
        if show_grid_plot:
            if limit_grid == "all":
                display_grid_plot(
                    masked_captions[image_counter], masked_files[image_counter]
                )
            elif isinstance(limit_grid, int) and limit_grid < len(
                masked_captions[image_counter]
            ):
                display_grid_plot(
                    masked_captions[image_counter][0:limit_grid],
                    masked_files[image_counter][0:limit_grid],
                )
            else:
                print("Enter a valid number for limit_grid parameter.")

        # setting values for next iterations
        mask_counter = 0
        image_counter += 1

    return shap_values_list

SHAP explanations for the test images

[22]:
# run masker with test images dataset (X) and get SHAP explanations for their captions
shap_values = run_masker(X)
Partition explainer: 2it [03:40, 110.24s/it]
[SHAP image plot for test image 1.jpg]
Partition explainer: 2it [02:56, 88.21s/it]
[SHAP image plot for test image 2.jpg]
Partition explainer: 2it [03:22, 101.31s/it]
[SHAP image plot for test image 3.jpg]
Partition explainer: 2it [03:18, 99.35s/it]
[SHAP image plot for test image 4.jpg]
[28]:
# SHAP explanation using alternate masking option for inpainting "inpaint_telea"
# displays grid plot of masked images and their captions
# change limit_grid = "all" to show all masked images instead of limiting to 24 masked images
shap_values = run_masker(
    X[2:3], mask_value="inpaint_telea", show_grid_plot=True, limit_grid=24
)
Partition explainer: 2it [03:51, 115.99s/it]
[SHAP image plot and grid plot of the masked images with their captions for test image 3.jpg]