Transformers Integration
The Transformers Integration makes transformer-based classification and sentence transformer fine-tuning usable in small-text. It relies on the Pytorch Integration, which is a prerequisite.
Note
Some implementations make use of optional dependencies.
Overview
Installation
Before you can use the transformers integration, make sure the required dependencies have been installed. For a pip-based setup, they are typically available through the transformers extra, i.e. pip install small-text[transformers].
Contents
With the integration you will have access to the following additional components:
| Components | Resources |
|---|---|
| Datasets | |
| Classifiers | |
| Query Strategies | (See Query Strategies) |
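For instance, once the dependencies are installed, the dataset and classifier classes listed above become importable from the top-level package. A minimal check (the names below follow the small-text API at the time of writing; the exact set of exports may vary across versions):

from small_text import (
    TransformersDataset,
    TransformerBasedClassification,
    TransformerBasedClassificationFactory
)

# If the optional transformers dependencies are missing, these imports
# should fail, which makes for a quick installation check.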
Compatible Models
While this integration is tailored to the transformers library, models (and their corresponding tokenizers) can vary considerably, so not all models are usable with small-text classifiers. To help you find a suitable model, we list a subset of compatible models in the following, which you can use as a starting point:
| Size | Models |
|---|---|
| < 1B Parameters | BERT, T5, DistilRoBERTa, DistilBERT, ELECTRA, BioGPT |
English Models
- BERT models: bert-base-uncased, bert-large-uncased
- DistilRoBERTa: distilroberta-base
- DistilBERT: distilbert-base-uncased
- ELECTRA: google/electra-base-discriminator, google/electra-small-discriminator
- BioGPT: microsoft/biogpt
This list is not exhaustive. Let us know if you have tested other models that might belong on these lists.
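Each of these checkpoint names can be passed directly when constructing the model arguments, for example (a minimal sketch using distilroberta-base):

from small_text import TransformerModelArguments

# Any compatible Hugging Face model name or local path can be used here.
transformer_model = TransformerModelArguments('distilroberta-base')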
TransformerBasedClassification: Extended Functionality
Layer-specific Fine-tuning
Layer-specific fine-tuning can be enabled by passing FineTuningArguments during the construction of TransformerBasedClassification. With this, you can enable layerwise gradient decay and gradual unfreezing:
- Layerwise gradient decay: the learning rate decreases with each lower layer.
- Gradual unfreezing: lower layers are frozen at the start of the training and become gradually unfrozen with each epoch.
See [HR18] for more details on these methods.
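A minimal sketch of how this could look (the keyword names follow the small-text API at the time of writing, and the hyperparameter values are illustrative assumptions, not recommendations):

from small_text import (
    FineTuningArguments,
    TransformerBasedClassification,
    TransformerModelArguments
)

# Layerwise gradient decay: a factor below 1.0 shrinks the learning rate
# for each layer further down the network. Gradual unfreezing can be
# configured via the gradual_unfreezing parameter.
fine_tuning_args = FineTuningArguments(base_lr=2e-5,
                                       layerwise_gradient_decay=0.975)

clf = TransformerBasedClassification(TransformerModelArguments('distilroberta-base'),
                                     num_classes=3,
                                     fine_tuning_arguments=fine_tuning_args)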
Examples
Transformer-based Classification
An example is provided in examples/examplecode/transformers_multiclass_classification.py:
"""Example of a transformer-based active learning multi-class text classification.
"""
import numpy as np
from transformers import AutoTokenizer
from small_text import (
EmptyPoolException,
PoolBasedActiveLearner,
PoolExhaustedException,
RandomSampling,
TransformerBasedClassificationFactory,
TransformerModelArguments,
random_initialization_balanced
)
from examplecode.data.corpus_twenty_news import get_twenty_newsgroups_corpus
from examplecode.data.example_data_transformers import preprocess_data
from examplecode.shared import evaluate
TRANSFORMER_MODEL = TransformerModelArguments('distilroberta-base')
TWENTY_NEWS_SUBCATEGORIES = ['rec.sport.baseball', 'sci.med', 'rec.autos']
def main(num_iterations=10):
# Active learning parameters
num_classes = len(TWENTY_NEWS_SUBCATEGORIES)
clf_factory = TransformerBasedClassificationFactory(TRANSFORMER_MODEL,
num_classes,
kwargs=dict({
'device': 'cuda'
}))
query_strategy = RandomSampling()
# Prepare some data
train, test = get_twenty_newsgroups_corpus(categories=TWENTY_NEWS_SUBCATEGORIES)
tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL.model, cache_dir='.cache/')
train = preprocess_data(tokenizer, train.data, train.target)
test = preprocess_data(tokenizer, test.data, test.target)
# Active learner
active_learner = PoolBasedActiveLearner(clf_factory, query_strategy, train)
indices_labeled = initialize_active_learner(active_learner, train.y)
try:
perform_active_learning(active_learner, train, indices_labeled, test, num_iterations)
except PoolExhaustedException:
print('Error! Not enough samples left to handle the query.')
except EmptyPoolException:
print('Error! No more samples left. (Unlabeled pool is empty)')
def perform_active_learning(active_learner, train, indices_labeled, test, num_iterations):
# Perform 10 iterations of active learning...
for i in range(num_iterations):
# ...where each iteration consists of labelling 20 samples
indices_queried = active_learner.query(num_samples=20)
# Simulate user interaction here. Replace this for real-world usage.
y = train.y[indices_queried]
# Return the labels for the current query to the active learner.
active_learner.update(y)
indices_labeled = np.concatenate([indices_queried, indices_labeled])
print('Iteration #{:d} ({} samples)'.format(i, len(indices_labeled)))
evaluate(active_learner, train[indices_labeled], test)
def initialize_active_learner(active_learner, y_train):
indices_initial = random_initialization_balanced(y_train)
y_initial = np.array([y_train[i] for i in indices_initial])
active_learner.initialize_data(indices_initial, y_initial)
return indices_initial
if __name__ == '__main__':
import argparse
import logging
logging.getLogger('small_text').setLevel(logging.INFO)
parser = argparse.ArgumentParser(description='An example that shows active learning '
'for multi-class text classification '
'using transformers.')
parser.add_argument('--num_iterations', type=int, default=10,
help='number of active learning iterations')
args = parser.parse_args()
main(num_iterations=args.num_iterations)
Sentence Transformer Fine-tuning
An example is provided in examples/examplecode/setfit_multiclass_classification.py:
"""Example of a setfit-based active learning multi-class text classification.
"""
import numpy as np
from small_text import (
EmptyPoolException,
PoolBasedActiveLearner,
PoolExhaustedException,
BreakingTies,
SetFitClassificationFactory,
SetFitModelArguments,
TextDataset,
random_initialization_balanced
)
from examplecode.data.corpus_twenty_news import get_twenty_newsgroups_corpus
from examplecode.shared import evaluate
TWENTY_NEWS_SUBCATEGORIES = ['rec.sport.baseball', 'sci.med', 'rec.autos']
def main(num_iterations=10):
# Active learning parameters
num_classes = len(TWENTY_NEWS_SUBCATEGORIES)
model_args = SetFitModelArguments('sentence-transformers/paraphrase-mpnet-base-v2')
# If GPU memory is a problem:
# model_args = SetFitModelArguments('sentence-transformers/all-MiniLM-L6-v2')
clf_factory = SetFitClassificationFactory(model_args,
num_classes,
classification_kwargs=dict({
'device': 'cuda',
'max_seq_len': 64,
'mini_batch_size': 8
}))
query_strategy = BreakingTies()
# Prepare some data
train, test = get_twenty_newsgroups_corpus(categories=TWENTY_NEWS_SUBCATEGORIES)
train.data = [txt for txt in train.data]
train = TextDataset.from_arrays(train.data, train.target, target_labels=np.arange(num_classes))
test = TextDataset(test.data, test.target, target_labels=np.arange(num_classes))
# Active learner
setfit_train_kwargs = {'show_progress_bar': False}
active_learner = PoolBasedActiveLearner(clf_factory, query_strategy, train,
fit_kwargs={'setfit_train_kwargs': setfit_train_kwargs})
indices_labeled = initialize_active_learner(active_learner, train.y)
try:
perform_active_learning(active_learner, train, indices_labeled, test, num_iterations)
except PoolExhaustedException:
print('Error! Not enough samples left to handle the query.')
except EmptyPoolException:
print('Error! No more samples left. (Unlabeled pool is empty)')
def perform_active_learning(active_learner, train, indices_labeled, test, num_iterations):
# Perform 10 iterations of active learning...
for i in range(num_iterations):
# ...where each iteration consists of labelling 20 samples
indices_queried = active_learner.query(num_samples=20)
# Simulate user interaction here. Replace this for real-world usage.
y = train.y[indices_queried]
# Return the labels for the current query to the active learner.
active_learner.update(y)
indices_labeled = np.concatenate([indices_queried, indices_labeled])
print('Iteration #{:d} ({} samples)'.format(i, len(indices_labeled)))
evaluate(active_learner, train[indices_labeled], test)
def initialize_active_learner(active_learner, y_train):
indices_initial = random_initialization_balanced(y_train)
y_initial = np.array([y_train[i] for i in indices_initial])
active_learner.initialize_data(indices_initial, y_initial)
return indices_initial
if __name__ == '__main__':
import argparse
import logging
logging.getLogger('small_text').setLevel(logging.INFO)
for logger_name in ['setfit.modeling', 'setfit.trainer']:
logger = logging.getLogger(logger_name)
logger.setLevel(logging.ERROR)
parser = argparse.ArgumentParser(description='An example that shows active learning '
'for multi-class text classification '
'using a setfit classifier.')
parser.add_argument('--num_iterations', type=int, default=10,
help='number of active learning iterations')
args = parser.parse_args()
main(num_iterations=args.num_iterations)