Transformers Integration Classes

Classification

class small_text.integrations.transformers.classifiers.classification.FineTuningArguments(base_lr, layerwise_gradient_decay, gradual_unfreezing=-1, cut_fraction=0.1)

Arguments to enable and configure gradual unfreezing and discriminative learning rates as used in Universal Language Model Fine-tuning (ULMFiT) [HR18].
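The two ULMFiT techniques named above can be illustrated with a small, library-independent sketch. It assumes that layerwise_gradient_decay multiplies the learning rate once per layer going down from the top of the network, and that cut_fraction plays the role of cut_frac in ULMFiT's slanted triangular learning rate schedule (the paper also uses 0.1, together with a ratio of 32). The function names are hypothetical and not part of small-text; the library's actual implementation may differ.

```python
from __future__ import annotations

import math


def layerwise_learning_rates(base_lr: float, decay: float, num_layers: int) -> list[float]:
    """Hypothetical helper for discriminative learning rates.

    Index 0 is the bottom (embedding-side) layer, index num_layers - 1 the top.
    The top layer gets base_lr; each layer below gets the rate of the layer
    above it multiplied by `decay` (assumed semantics).
    """
    return [base_lr * decay ** (num_layers - 1 - i) for i in range(num_layers)]


def slanted_triangular_lr(t: int, total_steps: int, lr_max: float,
                          cut_fraction: float = 0.1, ratio: float = 32.0) -> float:
    """Slanted triangular learning rate as defined in the ULMFiT paper.

    The rate rises linearly during the first `cut_fraction` of training and
    then decays linearly; `ratio` is how much smaller the lowest rate is
    than lr_max.
    """
    if not 0.0 < cut_fraction < 1.0:
        raise ValueError("cut_fraction must lie in (0, 1)")
    # Step at which the schedule switches from increasing to decreasing.
    cut = max(1, math.floor(total_steps * cut_fraction))
    if t < cut:
        p = t / cut
    else:
        p = 1 - (t - cut) / (cut * (1 / cut_fraction - 1))
    return lr_max * (1 + p * (ratio - 1)) / ratio
```

For comparison, ULMFiT divides the learning rate by 2.6 from one layer to the next lower one, which corresponds to decay=1/2.6 in this sketch.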

class small_text.integrations.transformers.classifiers.classification.TransformerBasedClassification(transformer_model, num_classes, multi_label=False, num_epochs=10, lr=2e-05, mini_batch_size=12, validation_set_size=0.1, validations_per_epoch=1, early_stopping_no_improvement=5, early_stopping_acc=-1, model_selection=True, fine_tuning_arguments=None, device=None, memory_fix=1, class_weight=None, verbosity=20, cache_dir='.active_learning_lib_cache/')
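The early-stopping parameters in the signature above are not described in this section. Judging only from their names, one plausible reading is sketched below: stop after early_stopping_no_improvement consecutive validations without a decrease in validation loss, or once the training accuracy reaches early_stopping_acc, where a value below 1 (patience) or below 0 (accuracy) disables the respective criterion. The class is hypothetical and only illustrates that reading; the actual behavior in small-text may differ.

```python
class EarlyStoppingSketch:
    """Hypothetical early-stopping tracker (assumed semantics, see lead-in)."""

    def __init__(self, patience: int = 5, acc_threshold: float = -1.0):
        self.patience = patience            # cf. early_stopping_no_improvement
        self.acc_threshold = acc_threshold  # cf. early_stopping_acc
        self.best_loss = float("inf")
        self.num_no_improvement = 0

    def should_stop(self, val_loss: float, train_acc: float) -> bool:
        """Record one validation step and report whether training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.num_no_improvement = 0
        else:
            self.num_no_improvement += 1

        # Criterion 1: too many validations without improvement (if enabled).
        if self.patience >= 1 and self.num_no_improvement >= self.patience:
            return True
        # Criterion 2: training accuracy threshold reached (if enabled).
        if self.acc_threshold >= 0 and train_acc >= self.acc_threshold:
            return True
        return False
```

With validations_per_epoch greater than 1, such a check would run several times per epoch rather than once.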
fit(train_set, validation_set=None, optimizer=None, scheduler=None)

Trains the model using the given train set.

Parameters
  • train_set (TransformersDataset) – Training set.

  • validation_set (TransformersDataset, default=None) – A validation set used for validation during training, or None. If None, the fit operation splits off a subset of the training set to use for validation; its size is determined by self.validation_set_size.

  • optimizer (torch.optim.optimizer.Optimizer or None, default=None) – A PyTorch optimizer.

  • scheduler (torch.optim.lr_scheduler._LRScheduler or None, default=None) – A PyTorch learning rate scheduler.

Returns

self – Returns the current classifier with a fitted model.

Return type

TransformerBasedClassification
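When validation_set is None, part of the training set is held out for validation. The sketch below shows one way such a random split could be computed on index positions; the helper name, the seeding, and the rounding rule (round down, but keep at least one sample) are assumptions, not small-text's documented behavior.

```python
from __future__ import annotations

import numpy as np


def split_off_validation(num_samples: int, validation_set_size: float = 0.1,
                         seed: int = 0) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: randomly partition indices into (train, validation)."""
    if not 0.0 < validation_set_size < 1.0:
        raise ValueError("validation_set_size must lie in (0, 1)")
    rng = np.random.default_rng(seed)
    # Shuffle all index positions, then cut off the validation part.
    indices = rng.permutation(num_samples)
    num_valid = max(1, int(num_samples * validation_set_size))
    return indices[num_valid:], indices[:num_valid]
```

The two returned index arrays are disjoint and together cover every sample, so they could be used to subset the original training data.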

predict(data_set, return_proba=False)

Predicts the labels for each instance in the given dataset.

Parameters
  • data_set (small_text.integrations.transformers.TransformersDataset) – A dataset on whose instances predictions are made.

  • return_proba (bool, default=False) – If True, additionally returns the confidence distribution over all classes.

Returns

  • predictions (np.ndarray[np.int32] or csr_matrix[np.int32]) – Array of predictions if the classifier was fitted on single-label data, otherwise a sparse matrix of predictions.

  • probas (np.ndarray[np.float32], optional) – Array of probabilities (or confidence estimates); returned only if return_proba is True.
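The two return types of predict can be illustrated by deriving predictions from a matrix of class probabilities. This sketch assumes the most probable class (argmax) for single-label data and a fixed threshold of 0.5 for multi-label data; both the helper and the threshold are hypothetical and not taken from small-text.

```python
import numpy as np
from scipy.sparse import csr_matrix


def predictions_from_proba(proba: np.ndarray, multi_label: bool = False,
                           threshold: float = 0.5):
    """Hypothetical helper mapping probabilities to predictions."""
    if multi_label:
        # Multi-label: every class whose score exceeds the threshold is predicted,
        # stored as a sparse 0/1 matrix of shape (num_instances, num_classes).
        return csr_matrix((proba > threshold).astype(np.int32))
    # Single-label: one class index per instance.
    return np.argmax(proba, axis=1).astype(np.int32)
```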

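predict_proba(test_set)

Predicts the distribution of confidence scores over all classes for each instance in the given test set.
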
Parameters

test_set (small_text.integrations.pytorch.PytorchTextClassificationDataset) – Test set.

Returns

scores – Distribution of confidence scores over all classes.

Return type

np.ndarray
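A common way to turn a transformer's output logits into such a distribution is a softmax over classes for single-label classification and an element-wise sigmoid for multi-label classification. The sketch below assumes this convention; the helper name is hypothetical and the code is not taken from small-text.

```python
import numpy as np


def scores_from_logits(logits: np.ndarray, multi_label: bool = False) -> np.ndarray:
    """Hypothetical helper: map raw logits of shape (n, num_classes) to scores."""
    logits = np.asarray(logits, dtype=np.float64)
    if multi_label:
        # Independent per-class scores in (0, 1); rows need not sum to 1.
        scores = 1.0 / (1.0 + np.exp(-logits))
    else:
        # Softmax; subtracting the row maximum avoids overflow in np.exp.
        shifted = logits - logits.max(axis=1, keepdims=True)
        exp = np.exp(shifted)
        scores = exp / exp.sum(axis=1, keepdims=True)
    return scores.astype(np.float32)
```

With the softmax, each row sums to 1, so it can be read as a distribution over classes; with the sigmoid, each class is scored independently.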