import types
from ..utils._exceptions import InvalidMaskerError
from ._masker import Masker
class Composite(Masker):
    """Merge several maskers for different inputs into a single composite masker.

    Each submasker consumes its own contiguous slice of the positional data
    arguments (in the order the maskers were given) and its own contiguous
    slice of the mask; their outputs are concatenated into one flat tuple.

    Support is partial: a ``clustering`` attribute is only provided when every
    submasker has one, and merging more than one non-empty clustering raises
    ``NotImplementedError``.
    """

    def __init__(self, *maskers):
        """Combine ``maskers`` into one composite masker.

        Parameters
        ----------
        *maskers
            The submaskers to combine. How many data arguments each one
            consumes is inferred from the signature of its ``__call__``.
        """
        self.maskers = maskers
        self.arg_counts = []  # number of data args consumed by each submasker
        self.total_args = 0
        self.text_data = False
        self.image_data = False

        all_have_clustering = True
        for masker in self.maskers:
            # Positional parameters of __call__ that have no default value;
            # the -2 removes ``self`` and ``mask``.
            all_args = masker.__call__.__code__.co_argcount
            defaults = masker.__call__.__defaults__
            kwargs = 0 if defaults is None else len(defaults)
            num_args = all_args - kwargs - 2
            self.arg_counts.append(num_args)
            self.total_args += num_args

            if not hasattr(masker, "clustering"):
                all_have_clustering = False
            # The composite handles text/image data if any submasker does.
            self.text_data = self.text_data or getattr(masker, "text_data", False)
            self.image_data = self.image_data or getattr(masker, "image_data", False)

        if all_have_clustering:
            # joint_clustering is a module-level function; bind it so it is
            # called like a method with this composite as ``self``.
            self.clustering = types.MethodType(joint_clustering, self)

    def _iter_submasker_args(self, args):
        """Yield ``(masker, masker_args)``, slicing ``args`` by each submasker's arg count."""
        pos = 0
        for masker, count in zip(self.maskers, self.arg_counts):
            yield masker, args[pos:pos + count]
            pos += count

    def _submasker_shapes(self, args):
        """Return ``(shapes, num_rows)`` for the submaskers given the data ``args``.

        ``shapes`` holds each submasker's ``(rows, cols)`` shape. ``num_rows`` is
        the joint row count: a submasker reporting ``1`` or ``None`` rows can be
        broadcast against the others, and all other row counts must match.
        ``num_rows`` is ``None`` only when every submasker reports ``None`` rows
        (or there are no submaskers).

        Raises
        ------
        InvalidMaskerError
            If two submaskers report different row counts that are not 1/None.
        """
        shapes = []
        num_rows = None
        for masker, masker_args in self._iter_submasker_args(args):
            shape = masker.shape(*masker_args) if callable(masker.shape) else masker.shape
            shapes.append(shape)
            rows = shape[0]
            # Take this submasker's rows if no count is known yet, or if the
            # current count is 1 and this one is concrete.
            if num_rows is None or (num_rows == 1 and rows is not None):
                num_rows = rows
            if rows != num_rows and rows != 1 and rows is not None:
                raise InvalidMaskerError("The composite masker can only join together maskers with a compatible number of background rows!")
        return shapes, num_rows

    def shape(self, *args):
        """Compute the shape of this masker from the shapes of all the submaskers.

        The column count is the sum of the submaskers' column counts. The row
        count is the same joint row count that ``__call__`` broadcasts its
        outputs to, so submaskers reporting ``1`` or ``None`` rows can be
        combined with a submasker that has many rows.

        Raises
        ------
        InvalidMaskerError
            If the submaskers' row counts cannot be reconciled.
        """
        assert len(args) == self.total_args, "The number of passed args is incorrect!"
        shapes, num_rows = self._submasker_shapes(args)
        return num_rows, sum(shape[1] for shape in shapes)

    def mask_shapes(self, *args):
        """Return the mask shapes expected by the submaskers, concatenated in order."""
        out = []
        for masker, masker_args in self._iter_submasker_args(args):
            # Each submasker only sees its own slice of ``args``.
            out.extend(masker.mask_shapes(*masker_args))
        return out

    def __call__(self, mask, *args):
        """Apply ``mask`` to ``args`` with every submasker and concatenate the outputs.

        The mask is split into consecutive slices, one per submasker, each sized
        by that submasker's column count. When the joint row count exceeds 1,
        the first row of each output of a submasker reporting ``1`` or ``None``
        rows is repeated to that count.

        Raises
        ------
        InvalidMaskerError
            If the submaskers' row counts cannot be reconciled.
        """
        mask = self._standardize_mask(mask, *args)
        assert len(args) == self.total_args, "The number of passed args is incorrect!"

        # compute all the shapes and confirm they align
        shapes, num_rows = self._submasker_shapes(args)

        # call all the submaskers and combine their outputs
        mask_pos = 0
        masked = []
        for (masker, masker_args), shape in zip(self._iter_submasker_args(args), shapes):
            masked_out = masker(mask[mask_pos:mask_pos + shape[1]], *masker_args)
            # num_rows is None when every submasker reports None rows; nothing
            # needs broadcasting then (and ``None > 1`` would raise TypeError).
            if num_rows is not None and num_rows > 1 and (shape[0] == 1 or shape[0] is None):
                masked_out = tuple([m[0] for _ in range(num_rows)] for m in masked_out)
            masked.extend(masked_out)
            mask_pos += shape[1]
        return tuple(masked)
def joint_clustering(self, *args):
    """Return a joint clustering that merges the clusterings of all the submaskers.

    Meant to be bound to a ``Composite`` instance with ``types.MethodType``, so
    ``self`` must provide ``maskers`` and ``arg_counts``. Each submasker's
    ``clustering`` is used directly or, if callable, called with that
    submasker's own slice of ``args``.

    At most one submasker may contribute a non-empty clustering, which is
    returned unchanged. If every clustering is empty, an empty clustering is
    returned.

    Raises
    ------
    NotImplementedError
        If more than one submasker has a non-empty clustering.
    """
    # NOTE(review): the non-empty clustering is returned as-is, without
    # offsetting its feature indices by earlier submaskers' columns; presumably
    # only correct when the other submaskers contribute no columns. Also a
    # ``None`` clustering makes ``len()`` below raise TypeError. TODO confirm.
    single_clustering = []
    arg_pos = 0
    for masker, count in zip(self.maskers, self.arg_counts):
        masker_args = args[arg_pos:arg_pos + count]
        arg_pos += count  # advance so the next submasker receives its own args
        if callable(masker.clustering):
            clustering = masker.clustering(*masker_args)
        else:
            clustering = masker.clustering

        if len(single_clustering) == 0:
            single_clustering = clustering
        elif len(clustering) != 0:
            raise NotImplementedError("Joining two non-trivial clusterings is not yet implemented in the Composite masker!")
    return single_clustering