Transformers Integration Classes
Classification
- class small_text.integrations.transformers.classifiers.classification.FineTuningArguments(base_lr, layerwise_gradient_decay, gradual_unfreezing=-1, cut_fraction=0.1)
Arguments to enable and configure gradual unfreezing and discriminative learning rates as used in Universal Language Model Fine-tuning (ULMFiT) [HR18].
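The ULMFiT techniques that FineTuningArguments refers to can be sketched in plain Python. This illustrates the ideas from [HR18] and is not small-text's implementation. The helper names are hypothetical. Two further points are assumptions: that layerwise_gradient_decay acts as the per-layer multiplier applied to base_lr, and that cut_fraction plays the role of ULMFiT's cut_frac (the fraction of steps during which the learning rate increases). The ratio of 32 is the value used in the ULMFiT paper. This reference does not state how the integer gradual_unfreezing maps onto a schedule, so the last function only shows the ULMFiT scheme of unfreezing one more layer per epoch, starting from the top.

```python
import math


def layerwise_learning_rates(base_lr, decay, num_layers):
    """Hypothetical helper: discriminative learning rates.

    The top layer gets base_lr; each layer below it gets the rate of the layer
    above multiplied by `decay` (assumed meaning of layerwise_gradient_decay).
    Index 0 is the lowest (input-side) layer.
    """
    return [base_lr * decay ** (num_layers - 1 - layer) for layer in range(num_layers)]


def slanted_triangular_lr(step, total_steps, max_lr, cut_fraction=0.1, ratio=32):
    """Hypothetical helper: ULMFiT's slanted triangular learning rate.

    The rate rises linearly from max_lr / ratio to max_lr during the first
    `cut_fraction` of all steps, then decays linearly towards max_lr / ratio.
    """
    cut = max(1, math.floor(total_steps * cut_fraction))
    if step < cut:
        p = step / cut  # warm-up phase
    else:
        p = 1 - (step - cut) / (cut * (1 / cut_fraction - 1))  # decay phase
    return max_lr * (1 + p * (ratio - 1)) / ratio


def trainable_layers(epoch, num_layers):
    """Hypothetical helper: ULMFiT-style gradual unfreezing.

    Only the top layer is trainable in epoch 0; one more layer is unfrozen in
    each following epoch until all layers are trainable.
    """
    return list(range(max(0, num_layers - 1 - epoch), num_layers))
```

For example, layerwise_learning_rates(2e-5, 0.5, 4) assigns 2e-5 to the top layer and halves the rate for each layer below it. In PyTorch, such per-layer rates would typically be passed to an optimizer as separate parameter groups.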
- class small_text.integrations.transformers.classifiers.classification.TransformerBasedClassification(transformer_model, num_classes, multi_label=False, num_epochs=10, lr=2e-05, mini_batch_size=12, validation_set_size=0.1, validations_per_epoch=1, early_stopping_no_improvement=5, early_stopping_acc=-1, model_selection=True, fine_tuning_arguments=None, device=None, memory_fix=1, class_weight=None, verbosity=20, cache_dir='.active_learning_lib_cache/')
- fit(train_set, validation_set=None, optimizer=None, scheduler=None)
Trains the model using the given train set.
- Parameters
train_set (TransformersDataset) – Training set.
validation_set (TransformersDataset, default=None) – A validation set used during training, or None. If None, fit() holds out a subset of train_set as the validation set; its size is set by self.validation_set_size.
optimizer (torch.optim.optimizer.Optimizer or None, default=None) – A PyTorch optimizer.
scheduler (torch.optim.lr_scheduler._LRScheduler or None, default=None) – A PyTorch learning rate scheduler.
- Returns
self – Returns the current classifier with a fitted model.
- Return type
TransformerBasedClassification
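When validation_set is None, the held-out split described for fit() can be pictured as follows. This sketch assumes a plain random split of instance indices. split_train_validation is a hypothetical helper, and this reference does not say whether small-text shuffles, stratifies, or seeds the split.

```python
import numpy as np


def split_train_validation(num_samples, validation_set_size=0.1, seed=None):
    """Hypothetical helper: random hold-out split of instance indices.

    Returns (train_indices, validation_indices). Requires num_samples >= 2 so
    that both parts are non-empty.
    """
    rng = np.random.default_rng(seed)
    indices = rng.permutation(num_samples)
    # keep at least one instance on each side of the split
    num_valid = min(max(1, round(num_samples * validation_set_size)), num_samples - 1)
    return indices[num_valid:], indices[:num_valid]
```

With 100 training instances and the default validation_set_size=0.1, the sketch holds out 10 instances for validation and trains on the remaining 90.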
- predict(data_set, return_proba=False)
- Parameters
data_set (small_text.integrations.transformers.TransformersDataset) – A dataset on whose instances predictions are made.
return_proba (bool, default=False) – If True, additionally returns the confidence distribution over all classes.
- Returns
predictions (np.ndarray[np.int32] or csr_matrix[np.int32]) – An array of predictions if the classifier was fitted on single-label data, otherwise (multi-label) a sparse matrix of predictions.
probas (np.ndarray[np.float32], optional) – An array of probabilities (or confidence estimates); only returned if return_proba is True.
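The two prediction formats listed above can be illustrated by deriving predictions from a matrix of class probabilities. This is a sketch, not the library's code. probas_to_predictions is hypothetical, and the 0.5 decision threshold for the multi-label case is an assumption, since this reference does not say how multi-label predictions are obtained.

```python
import numpy as np
from scipy.sparse import csr_matrix


def probas_to_predictions(probas, multi_label=False, threshold=0.5):
    """Hypothetical helper mirroring the return formats of predict().

    Single-label: an int32 array holding one class index per instance.
    Multi-label: an int32 csr_matrix where entry (i, j) is 1 if class j is
    predicted for instance i (assumed rule: probability above `threshold`).
    """
    probas = np.asarray(probas, dtype=np.float32)
    if not multi_label:
        return probas.argmax(axis=1).astype(np.int32)
    return csr_matrix((probas > threshold).astype(np.int32))
```

Following the return_proba description, a call such as predictions, probas = clf.predict(test_set, return_proba=True) (where clf is a fitted TransformerBasedClassification) would presumably yield both values as a pair.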
- predict_proba(test_set)
- Parameters
test_set (small_text.integrations.pytorch.PytorchTextClassificationDataset) – Test set.
- Returns
scores – Distribution of confidence scores over all classes.
- Return type
np.ndarray
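One common way to turn a model's raw outputs (logits) into per-class confidence scores like those predict_proba returns is a softmax for single-label data and an element-wise sigmoid for multi-label data. The sketch below shows both in NumPy. It is an assumption that TransformerBasedClassification computes its scores this way, since the reference only describes them as a distribution of confidence scores over all classes.

```python
import numpy as np


def softmax(logits):
    """Single-label case: per-row probabilities that sum to 1."""
    logits = np.asarray(logits, dtype=np.float64)
    # subtracting the row-wise maximum avoids overflow in np.exp
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def sigmoid(logits):
    """Multi-label case: an independent score between 0 and 1 for every class."""
    return 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
```

Sigmoid scores do not sum to 1 across classes, which is why a per-class threshold, rather than an argmax, is the natural way to read them in the multi-label case.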