! [ -e /content ] && pip install -Uqq fastai # 在Colab上升级fastai
协同过滤
::: {#cell-2 .cell}
#|default_exp collab
#|default_cls_lvl 3
:::
from __future__ import annotations
from fastai.tabular.all import *
from nbdev.showdoc import *
快速获取数据并训练适合协同过滤的模型的工具
该模块包含了您在协同过滤应用中所需的所有高级功能,以便汇总数据、获取模型并使用Learner
对其进行训练。我们将依次介绍这些内容,但您也可以查看协同过滤教程。
收集数据
class TabularCollab(TabularPandas):
    "Instance of `TabularPandas` suitable for collaborative filtering (with no continuous variable)"
    # Collaborative filtering uses only the categorical user/item columns,
    # so continuous-variable handling is disabled for the whole class.
    with_cont=False
这只是为了使用表格应用的内部功能,不用担心。
class CollabDataLoaders(DataLoaders):
    "Base `DataLoaders` for collaborative filtering."
    @delegates(DataLoaders.from_dblock)
    @classmethod
    def from_df(cls, ratings, valid_pct=0.2, user_name=None, item_name=None, rating_name=None, seed=None, path='.', **kwargs):
        "Create a `DataLoaders` suitable for collaborative filtering from `ratings`."
        # Default to positional columns: user, item, rating are the first three columns.
        user_name   = ifnone(user_name,   ratings.columns[0])
        item_name   = ifnone(item_name,   ratings.columns[1])
        rating_name = ifnone(rating_name, ratings.columns[2])
        cat_names = [user_name,item_name]
        splits = RandomSplitter(valid_pct=valid_pct, seed=seed)(range_of(ratings))
        # Reuse the tabular pipeline (Categorify on user/item) with the rating as target.
        to = TabularCollab(ratings, [Categorify], cat_names, y_names=[rating_name], y_block=TransformBlock(), splits=splits)
        return to.dataloaders(path=path, **kwargs)

    @classmethod
    def from_csv(cls, csv, **kwargs):
        "Create a `DataLoaders` suitable for collaborative filtering from `csv`."
        return cls.from_df(pd.read_csv(csv), **kwargs)
= delegates(to=CollabDataLoaders.from_df)(CollabDataLoaders.from_csv) CollabDataLoaders.from_csv
这个类不应该直接使用,而应该优先考虑使用其中一个工厂方法。所有这些工厂方法都接受以下参数:
- `valid_pct`:用于验证的数据集随机百分比(可选的 `seed`)
- `user_name`:包含用户的列名(默认为第一列)
- `item_name`:包含项目的列名(默认为第二列)
- `rating_name`:包含评分的列名(默认为第三列)
- `path`:工作目录
- `bs`:批大小
- `val_bs`:验证 `DataLoader` 的批大小(默认为 `bs`)
- `shuffle_train`:是否对训练 `DataLoader` 进行洗牌
- `device`:要使用的 PyTorch 设备(默认为 `default_device()`)
show_doc(CollabDataLoaders.from_df)
CollabDataLoaders.from_df
[source]
CollabDataLoaders.from_df
(ratings
,valid_pct
=0.2
,user_name
=None
,item_name
=None
,rating_name
=None
,seed
=None
,path
='.'
,bs
:int
=64
,val_bs
:int
=None
,shuffle
:bool
=True
,device
=None
)
Create a DataLoaders
suitable for collaborative filtering from ratings
.
让我们通过一个例子看看这是怎么工作的:
# Download the MovieLens sample and load the ratings table.
path = untar_data(URLs.ML_SAMPLE)
ratings = pd.read_csv(path/'ratings.csv')
ratings.head()
userId | movieId | rating | timestamp | |
---|---|---|---|---|
0 | 73 | 1097 | 4.0 | 1255504951 |
1 | 561 | 924 | 3.5 | 1172695223 |
2 | 157 | 260 | 3.5 | 1291598691 |
3 | 358 | 1210 | 5.0 | 957481884 |
4 | 130 | 316 | 2.0 | 1138999234 |
dls = CollabDataLoaders.from_df(ratings, bs=64)
dls.show_batch()
userId | movieId | rating | |
---|---|---|---|
0 | 580 | 736 | 2.0 |
1 | 509 | 356 | 4.0 |
2 | 105 | 480 | 3.0 |
3 | 518 | 595 | 5.0 |
4 | 111 | 527 | 4.0 |
5 | 384 | 589 | 5.0 |
6 | 607 | 2918 | 3.5 |
7 | 460 | 1291 | 4.0 |
8 | 268 | 1270 | 5.0 |
9 | 56 | 586 | 4.0 |
show_doc(CollabDataLoaders.from_csv)
CollabDataLoaders.from_csv
[source]
CollabDataLoaders.from_csv
(csv
,valid_pct
=0.2
,user_name
=None
,item_name
=None
,rating_name
=None
,seed
=None
,path
='.'
,bs
:int
=64
,val_bs
:int
=None
,shuffle
:bool
=True
,device
=None
)
Create a DataLoaders
suitable for collaborative filtering from csv
.
= CollabDataLoaders.from_csv(path/'ratings.csv', bs=64) dls
模型
fastai提供两种用于协同过滤的模型:点积模型和神经网络。
class EmbeddingDotBias(Module):
    "Base dot model for collaborative filtering."
    def __init__(self, n_factors, n_users, n_items, y_range=None):
        self.y_range = y_range
        # Four embedding tables: user/item latent factors plus per-user/per-item scalar biases.
        (self.u_weight, self.i_weight, self.u_bias, self.i_bias) = [Embedding(*o) for o in [
            (n_users, n_factors), (n_items, n_factors), (n_users,1), (n_items,1)
        ]]

    def forward(self, x):
        # `x` holds (user_idx, item_idx) pairs in its first two columns.
        users,items = x[:,0],x[:,1]
        dot = self.u_weight(users)* self.i_weight(items)
        res = dot.sum(1) + self.u_bias(users).squeeze() + self.i_bias(items).squeeze()
        if self.y_range is None: return res
        # Squash the prediction into `y_range` with a scaled sigmoid.
        return torch.sigmoid(res) * (self.y_range[1]-self.y_range[0]) + self.y_range[0]

    @classmethod
    def from_classes(cls, n_factors, classes, user=None, item=None, y_range=None):
        "Build a model with `n_factors` by inferring `n_users` and `n_items` from `classes`"
        # Default to the first key for users and the second for items.
        if user is None: user = list(classes.keys())[0]
        if item is None: item = list(classes.keys())[1]
        res = cls(n_factors, len(classes[user]), len(classes[item]), y_range=y_range)
        # Keep the vocabularies around so `bias`/`weight` can look entries up by name.
        res.classes,res.user,res.item = classes,user,item
        return res

    def _get_idx(self, arr, is_item=True):
        "Fetch item or user (based on `is_item`) for all in `arr`"
        assert hasattr(self, 'classes'), "Build your model with `EmbeddingDotBias.from_classes` to use this functionality."
        classes = self.classes[self.item] if is_item else self.classes[self.user]
        c2i = {v:k for k,v in enumerate(classes)}
        try: return tensor([c2i[o] for o in arr])
        except KeyError as e:
            message = f"You're trying to access {'an item' if is_item else 'a user'} that isn't in the training data. If it was in your original data, it may have been split such that it's only in the validation set now."
            raise modify_exception(e, message, replace=True)

    def bias(self, arr, is_item=True):
        "Bias for item or user (based on `is_item`) for all in `arr`"
        idx = self._get_idx(arr, is_item)
        # Run the lookup on CPU in eval mode; detach so callers get a plain tensor.
        layer = (self.i_bias if is_item else self.u_bias).eval().cpu()
        return to_detach(layer(idx).squeeze(),gather=False)

    def weight(self, arr, is_item=True):
        "Weight for item or user (based on `is_item`) for all in `arr`"
        idx = self._get_idx(arr, is_item)
        layer = (self.i_weight if is_item else self.u_weight).eval().cpu()
        return to_detach(layer(idx),gather=False)
模型是用 n_factors
(内部向量的长度)、n_users
和 n_items
构建的。对于给定的用户和项目,它获取相应的权重和偏差,并返回
`torch.dot(user_w, item_w) + user_b + item_b`
可选地,如果传入 y_range
,则对该结果应用 SigmoidRange
。
# Sanity check: with y_range=(0,5) every prediction must land in [0, 5].
x,y = dls.one_batch()
model = EmbeddingDotBias(50, len(dls.classes['userId']), len(dls.classes['movieId']), y_range=(0,5)).to(x.device)
out = model(x)
assert (0 <= out).all() and (out <= 5).all()
show_doc(EmbeddingDotBias.from_classes)
EmbeddingDotBias.from_classes
[source]
EmbeddingDotBias.from_classes
(n_factors
,classes
,user
=None
,item
=None
,y_range
=None
)
Build a model with n_factors
by inferring n_users
and n_items
from classes
y_range
被传递给主初始化。 user
和 item
是 classes
中用户和项目的键名(分别默认为第一个和第二个键)。 classes
预期是一个字典,键对应类别列表,类似于 CollabDataLoaders
中 dls.classes
的结果:
dls.classes
{'userId': ['#na#', 15, 17, 19, 23, 30, 48, 56, 73, 77, 78, 88, 95, 102, 105, 111, 119, 128, 130, 134, 150, 157, 165, 176, 187, 195, 199, 212, 213, 220, 232, 239, 242, 243, 247, 262, 268, 285, 292, 294, 299, 306, 311, 312, 313, 346, 353, 355, 358, 380, 382, 384, 387, 388, 402, 405, 407, 423, 427, 430, 431, 439, 452, 457, 460, 461, 463, 468, 472, 475, 480, 481, 505, 509, 514, 518, 529, 534, 537, 544, 547, 561, 564, 574, 575, 577, 580, 585, 587, 596, 598, 605, 607, 608, 615, 624, 648, 652, 654, 664, 665],
'movieId': ['#na#', 1, 10, 32, 34, 39, 47, 50, 110, 150, 153, 165, 231, 253, 260, 293, 296, 316, 318, 344, 356, 357, 364, 367, 377, 380, 457, 480, 500, 527, 539, 541, 586, 587, 588, 589, 590, 592, 593, 595, 597, 608, 648, 733, 736, 778, 780, 858, 924, 1036, 1073, 1089, 1097, 1136, 1193, 1196, 1197, 1198, 1200, 1206, 1210, 1213, 1214, 1221, 1240, 1265, 1270, 1291, 1580, 1617, 1682, 1704, 1721, 1732, 1923, 2028, 2396, 2571, 2628, 2716, 2762, 2858, 2918, 2959, 2997, 3114, 3578, 3793, 4226, 4306, 4886, 4963, 4973, 4993, 5349, 5952, 6377, 6539, 7153, 8961, 58559]}
让我们看看它如何在实践中使用:
# Same range check for a model built via `from_classes`.
model = EmbeddingDotBias.from_classes(50, dls.classes, y_range=(0,5)).to(x.device)
out = model(x)
assert (0 <= out).all() and (out <= 5).all()
当使用 EmbeddingDotBias.from_classes
创建模型时,添加了两个便利方法以便于访问权重和偏差:
show_doc(EmbeddingDotBias.weight)
EmbeddingDotBias.weight
[source]
EmbeddingDotBias.weight
(arr
,is_item
=True
)
Weight for item or user (based on is_item
) for all in arr
arr
的元素应该是类名(这就是为什么模型需要使用 EmbeddingDotBias.from_classes
创建的原因)。
# `weight` looks up by class name, so it must match a direct embedding lookup at index 42.
mov = dls.classes['movieId'][42]
w = model.weight([mov])
test_eq(w, model.i_weight(tensor([42])))
show_doc(EmbeddingDotBias.bias)
EmbeddingDotBias.bias
[source]
EmbeddingDotBias.bias
(arr
,is_item
=True
)
Bias for item or user (based on is_item
) for all in arr
arr
的元素预计是类名(这就是为什么模型需要使用 EmbeddingDotBias.from_classes
创建的原因)。
# `bias` looks up by class name, so it must match a direct bias-embedding lookup at index 42.
mov = dls.classes['movieId'][42]
b = model.bias([mov])
test_eq(b, model.i_bias(tensor([42])))
::: {#cell-35 .cell}
#|export
class EmbeddingNN(TabularModel):
    "Subclass `TabularModel` to create a NN suitable for collaborative filtering."
    @delegates(TabularModel.__init__)
    def __init__(self, emb_szs, layers, **kwargs):
        # Collaborative filtering has no continuous features and predicts one value.
        super().__init__(emb_szs=emb_szs, layers=layers, n_cont=0, out_sz=1, **kwargs)
:::
show_doc(EmbeddingNN)
class
EmbeddingNN
[source]
EmbeddingNN
(emb_szs
,layers
,ps
=None
,embed_p
=0.0
,y_range
=None
,use_bn
=True
,bn_final
=False
,bn_cont
=True
) ::TabularModel
Subclass TabularModel
to create a NN suitable for collaborative filtering.
emb_szs
应该是一个包含两个元组的列表,一个用于用户,一个用于项目,每个元组包含用户/项目的数量和相应的嵌入大小(函数 get_emb_sz
可以提供一个好的默认值)。所有其他参数将传递给 TabularModel
。
# Build an EmbeddingNN with inferred embedding sizes and check its output range.
emb_szs = get_emb_sz(dls.train_ds, {})
model = EmbeddingNN(emb_szs, [50], y_range=(0,5)).to(x.device)
out = model(x)
assert (0 <= out).all() and (out <= 5).all()
创建一个Learner
以下函数使我们能够快速从数据中创建一个用于协同过滤的 Learner
。
@delegates(Learner.__init__)
def collab_learner(dls, n_factors=50, use_nn=False, emb_szs=None, layers=None, config=None, y_range=None, loss_func=None, **kwargs):
    "Create a Learner for collaborative filtering on `dls`."
    emb_szs = get_emb_sz(dls, ifnone(emb_szs, {}))
    # Ratings are continuous targets, so regression loss by default.
    if loss_func is None: loss_func = MSELossFlat()
    if config is None: config = tabular_config()
    if y_range is not None: config['y_range'] = y_range
    if layers is None: layers = [n_factors]
    # Either a neural net on embeddings, or the classic dot-product-plus-bias model.
    if use_nn: model = EmbeddingNN(emb_szs=emb_szs, layers=layers, **config)
    else:      model = EmbeddingDotBias.from_classes(n_factors, dls.classes, y_range=y_range)
    return Learner(dls, model, loss_func=loss_func, **kwargs)
如果 `use_nn=False`,则使用的模型为 `EmbeddingDotBias`,包含 `n_factors` 和 `y_range`。否则,使用 `EmbeddingNN`,您可以传递 `emb_szs`(如果您不提供,将通过 `get_emb_sz` 从 `dls` 中推断出)、`layers`(默认为 `[n_factors]`)、`y_range`,以及您可以使用 `tabular_config` 创建的 `config` 来自定义您的模型。
`loss_func` 默认为 `MSELossFlat`,所有其他参数将传递给 `Learner`。
learn = collab_learner(dls, y_range=(0,5))
learn.fit_one_cycle(1)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 2.521979 | 2.541627 | 00:00 |
导出 -
from nbdev import *
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 36_text.models.qrnn.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 74_callback.cutmix.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted index.ipynb.
Converted tutorial.ipynb.