Source code for mambular.base_models.basemodel
import torch
import torch.nn as nn
import os
import logging
class BaseModel(nn.Module):
    def __init__(self, **kwargs):
        """
        Initializes the BaseModel with the given hyperparameters.

        Parameters
        ----------
        **kwargs : dict
            Hyperparameters to be saved and used in the model.
        """
        super().__init__()
        self.hparams = kwargs
    def save_hyperparameters(self, ignore=[]):
        """
        Saves the hyperparameters while ignoring specified keys.

        Parameters
        ----------
        ignore : list, optional
            List of keys to ignore while saving hyperparameters, by default [].
        """
        self.hparams = {k: v for k, v in self.hparams.items() if k not in ignore}
        for key, value in self.hparams.items():
            setattr(self, key, value)
    def save_model(self, path):
        """
        Save the model parameters to the given path.

        Parameters
        ----------
        path : str
            Path to save the model parameters.
        """
        torch.save(self.state_dict(), path)
        print(f"Model parameters saved to {path}")
    def load_model(self, path, device="cpu"):
        """
        Load the model parameters from the given path.

        Parameters
        ----------
        path : str
            Path to load the model parameters from.
        device : str, optional
            Device to map the model parameters to, by default 'cpu'.
        """
        self.load_state_dict(torch.load(path, map_location=device))
        self.to(device)
        print(f"Model parameters loaded from {path}")
    def count_parameters(self):
        """
        Count the number of trainable parameters in the model.

        Returns
        -------
        int
            Total number of trainable parameters.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    def freeze_parameters(self):
        """
        Freeze the model parameters by setting `requires_grad` to False.
        """
        for param in self.parameters():
            param.requires_grad = False
        print("All model parameters have been frozen.")
    def unfreeze_parameters(self):
        """
        Unfreeze the model parameters by setting `requires_grad` to True.
        """
        for param in self.parameters():
            param.requires_grad = True
        print("All model parameters have been unfrozen.")
    def log_parameters(self, logger=None):
        """
        Log the hyperparameters and model parameters.

        Parameters
        ----------
        logger : logging.Logger, optional
            Logger instance to log the parameters, by default None.
        """
        if logger is None:
            logger = logging.getLogger(__name__)
        logger.info("Hyperparameters:")
        for key, value in self.hparams.items():
            logger.info(f"  {key}: {value}")
        logger.info(f"Total number of trainable parameters: {self.count_parameters()}")
    def parameter_count(self):
        """
        Get a dictionary of parameter counts for each named parameter in the model.

        Returns
        -------
        dict
            Dictionary where keys are parameter names and values are parameter counts.
        """
        param_count = {}
        for name, param in self.named_parameters():
            param_count[name] = param.numel()
        return param_count
    def get_device(self):
        """
        Get the device on which the model is located.

        Returns
        -------
        torch.device
            Device on which the model is located.
        """
        return next(self.parameters()).device
    def to_device(self, device):
        """
        Move the model to the specified device.

        Parameters
        ----------
        device : torch.device or str
            Device to move the model to.
        """
        self.to(device)
        print(f"Model moved to {device}")
    def print_summary(self):
        """
        Print a summary of the model, including the architecture and parameter counts.
        """
        print(self)
        print(f"\nTotal number of trainable parameters: {self.count_parameters()}")
        print("\nParameter counts by layer:")
        for name, count in self.parameter_count().items():
            print(f"  {name}: {count}")