Training
Quickstart
Installation
To install GLiClass, run the following command:
pip install gliclass
Base Training Script
Load pretrained model
import torch
from gliclass import GLiClassModel
from transformers import AutoTokenizer
device = torch.device('cuda:0') if torch.cuda.is_available else torch.device('cpu')
model_name = "knowledgator/gliclass-small-v1.0"
model = GLiClassModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
Define Dataset
The dataset for training GLiClass must contain text field, which represents text for classification, all_labels field, whic stands for all labels to classify from and true_labels which will represent the correct labels for given text. For more info about datasets please visit our datasets page
from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding
data = [
{
"text": "A new machine learning platform automates complex data workflows but faces integration issues.",
"all_labels": ["AI", "automation", "data_analysis", "usability", "integration"],
"true_labels": ["AI", "integration", "automation"]
}
]
train_dataset = GLiClassDataset(
data,
tokenizer,
max_length= 1024,
problem_type= 'multi_label_classification',
)
# Data collator
data_collator = DataCollatorWithPadding(device=device)
Expected Output
Total labels: 5
Define functions for metrics computation
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
def compute_metrics_multi_label(p):
predictions, labels = p
labels = labels.reshape(-1)
predictions = predictions.reshape(-1)
preds = (predictions > 0.5).astype(int)
labels = np.where(labels>0.5, 1, 0)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
accuracy = accuracy_score(labels, preds)
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
}
Train the model
from gliclass.training import TrainingArguments, Trainer
# Training arguments
training_args = TrainingArguments(
output_dir="my-awesome-gliclass-model",
learning_rate=1e-5,
weight_decay=0.01,
others_lr=1e-5,
others_weight_decay=0.01,
lr_scheduler_type="cosine",
per_device_eval_batch_size=8,
num_train_epochs=1,
logging_steps=100,
use_cpu = False,
report_to="none",
fp16=False, # Set to True if you want to use mixed precision training
)
# Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics_multi_label
)
# Run training
trainer.train()
Full Base Training Script [source]
The following script could be used both for training from scratch and fine-tuning pretrained model:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import numpy as np
import argparse
import json
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import AutoTokenizer, AutoConfig
import random
import torch
from gliclass import GLiClassModelConfig, GLiClassModel
from gliclass.training import TrainingArguments, Trainer
from gliclass.data_processing import DataCollatorWithPadding, GLiClassDataset
def compute_metrics(p):
predictions, labels = p
labels = labels.reshape(-1)
if args.problem_type == 'single_label_classification':
preds = np.argmax(predictions, axis=1)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
accuracy = accuracy_score(labels, preds)
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
}
elif args.problem_type == 'multi_label_classification':
predictions = predictions.reshape(-1)
preds = (predictions > 0.5).astype(int)
labels = np.where(labels>0.5, 1, 0)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
accuracy = accuracy_score(labels, preds)
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
}
else:
raise NotImplementedError(f"{args.problem_type} is not implemented.")
def main(args):
device = torch.device('cuda:0') if torch.cuda.is_available else torch.device('cpu')
if args.model_name is not None:
model = GLiClassModel.from_pretrained(args.model_name, focal_loss_alpha=args.focal_loss_alpha,
focal_loss_gamma=args.focal_loss_gamma)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
else:
tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_name)
encoder_config = AutoConfig.from_pretrained(args.encoder_model_name)
if args.label_model_name is not None:
label_model_config = AutoConfig.from_pretrained(args.label_model_name)
glicalss_config = GLiClassModelConfig(
encoder_config=encoder_config,
encoder_model=args.encoder_model_name,
label_model_name=args.label_model_name,
label_model_config=label_model_config,
class_token_index=len(tokenizer),
text_token_index=len(tokenizer)+1,
pooling_strategy=args.pooler_type,
scorer_type=args.scorer_type,
use_lstm=args.use_lstm,
focal_loss_alpha=args.focal_loss_alpha,
focal_loss_gamma=args.focal_loss_gamma,
contrastive_loss_coef=args.contrastive_loss_coef,
normalize_features=args.normalize_features,
extract_text_features=args.extract_text_features,
architecture_type=args.architecture_type,
prompt_first=args.prompt_first,
squeeze_layers=args.squeeze_layers,
shuffle_labels=args.shuffle_labels
)
model = GLiClassModel(glicalss_config, from_pretrained=True)
if args.architecture_type in {'uni-encoder', 'bi-encoder-fused', 'encoder-decoder'}:
new_words = ["<<LABEL>>", "<<SEP>>"]
tokenizer.add_tokens(new_words, special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
model.to(device)
if model.config.label_model_name is not None:
labels_tokenizer = AutoTokenizer.from_pretrained(model.config.label_model_name)
else:
labels_tokenizer = None
model.config.problem_type = args.problem_type
with open(args.data_path, 'r') as f:
data = json.load(f)
print('Dataset size:', len(data))
random.shuffle(data)
print('Dataset is shuffled...')
train_data = data[:int(len(data)*0.9)]
test_data = data[int(len(data)*0.9):]
print('Dataset is splitted...')
train_dataset = GLiClassDataset(train_data, tokenizer, args.max_length,
args.problem_type, args.architecture_type,
args.prompt_first, labels_tokenizer=labels_tokenizer)
test_dataset = GLiClassDataset(test_data, tokenizer, args.max_length, args.problem_type,
args.architecture_type, args.prompt_first,
labels_tokenizer = labels_tokenizer)
data_collator = DataCollatorWithPadding(device=device)
training_args = TrainingArguments(
output_dir=args.save_path,
learning_rate=args.encoder_lr,
weight_decay=args.encoder_weight_decay,
others_lr=args.others_lr,
others_weight_decay=args.others_weight_decay,
lr_scheduler_type=args.lr_scheduler_type,
warmup_ratio=args.warmup_ratio,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
num_train_epochs=args.num_epochs,
evaluation_strategy="epoch",
save_steps = args.save_steps,
save_total_limit=args.save_total_limit,
dataloader_num_workers = args.num_workers,
logging_steps=100,
use_cpu = False,
report_to="none",
fp16=args.fp16,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
trainer.train()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default= None)
parser.add_argument('--encoder_model_name', type=str, default = 'microsoft/deberta-v3-small')
parser.add_argument('--label_model_name', type=str, default = "BAAI/bge-small-en-v1.5")
parser.add_argument('--save_path', type=str, default = 'models/')
parser.add_argument('--data_path', type=str, default = 'data/zero-cats.json')
parser.add_argument('--problem_type', type=str, default='multi_label_classification')
parser.add_argument('--pooler_type', type=str, default='avg')
parser.add_argument('--scorer_type', type=str, default='simple')
parser.add_argument('--architecture_type', type=str, default='uni-encoder')
parser.add_argument('--normalize_features', type=bool, default=False)
parser.add_argument('--extract_text_features', type=bool, default=False)
parser.add_argument('--prompt_first', type=bool, default=True)
parser.add_argument('--use_lstm', type=bool, default=False)
parser.add_argument('--squeeze_layers', type=bool, default=False)
parser.add_argument('--shuffle_labels', type=bool, default=True)
parser.add_argument('--num_epochs', type=int, default=3)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--encoder_lr', type=float, default=1e-5)
parser.add_argument('--others_lr', type=float, default=3e-5)
parser.add_argument('--encoder_weight_decay', type=float, default=0.01)
parser.add_argument('--others_weight_decay', type=float, default=0.01)
parser.add_argument('--warmup_ratio', type=float, default=0.05)
parser.add_argument('--lr_scheduler_type', type=str, default='linear')
parser.add_argument('--focal_loss_alpha', type=float, default=-1)
parser.add_argument('--focal_loss_gamma', type=float, default=-1)
parser.add_argument('--contrastive_loss_coef', type=float, default=0.)
parser.add_argument('--max_length', type=int, default=1024)
parser.add_argument('--save_steps', type=int, default=1000)
parser.add_argument('--save_total_limit', type=int, default=3)
parser.add_argument('--num_workers', type=int, default=12)
parser.add_argument('--fp16', type=bool, default=False)
args = parser.parse_args()
main(args)
RL Training Script
The GLiClass framework also supports Reinforcement learning, you can start training models using it with just a couple of changes to your training script.
Load pretrained model
This step leaves unchanged
import torch
from gliclass import GLiClassModel
from transformers import AutoTokenizer
device = torch.device('cuda:0') if torch.cuda.is_available else torch.device('cpu')
model_name = "knowledgator/gliclass-small-v1.0"
model = GLiClassModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
Initialize RL training components
from transformers import AutoModelForSequenceClassification
from gliclass.pipeline import ZeroShotClassificationPipeline
# Value model for advantage estimation
value_model = AutoModelForSequenceClassification.from_pretrained(model.config.encoder_model_name, num_labels=1)
value_model.resize_token_embeddings(len(tokenizer))
# Reference model for baseline comparisons
refrence_model = GLiClassModel.from_pretrained(model_name) # for most cases you may use the same model as reference model
reference_tokenizer = AutoTokenizer.from_pretrained(model_name)
reference_pipe = ZeroShotClassificationPipeline(refrence_model, reference_tokenizer,
classification_type='multi-label',
progress_bar=False, device=device)
Define Dataset
This step leaves unchanged
from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding
data = [
{
"text": "A new machine learning platform automates complex data workflows but faces integration issues.",
"all_labels": ["AI", "automation", "data_analysis", "usability", "integration"],
"true_labels": ["AI", "integration", "automation"]
}
]
train_dataset = GLiClassDataset(
data,
tokenizer,
max_length= 1024,
problem_type= 'multi_label_classification',
)
# Data collator
data_collator = DataCollatorWithPadding(device=device)
Expected Output
Total labels: 5
Define functions for metrics computation
This step leaves unchanged
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
def compute_metrics_multi_label(p):
predictions, labels = p
labels = labels.reshape(-1)
predictions = predictions.reshape(-1)
preds = (predictions > 0.5).astype(int)
labels = np.where(labels>0.5, 1, 0)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
accuracy = accuracy_score(labels, preds)
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
}
Define reward function
def default_f1_reward(
probs: torch.Tensor,
actions: torch.Tensor,
original_targets: torch.Tensor,
valid_mask: torch.Tensor
) -> torch.Tensor:
"""
A variant that extracts list-of-indices sets and then calculates
the F1 score in a classical manner. Returns shape (N, 1).
Args:
probs: (N, T) Tensor of probabilities (not used here but left for interface consistency).
actions: (N, T) Tensor of predicted labels in {0, 1}.
original_targets: (N, T) Tensor of ground-truth labels in {0, 1}.
valid_mask: (N, T) Tensor indicating which positions are valid (1) vs. invalid (0).
Returns:
f1_scores: (N, 1) Tensor containing the F1 score for each row.
"""
N = actions.shape[0]
f1_scores = []
for i in range(N):
# Filter valid positions
valid_preds_i = actions[i] * valid_mask[i]
valid_targets_i = original_targets[i] * valid_mask[i]
# Get the set of indices where we predicted 1
predicted_set = set((valid_preds_i == 1).nonzero(as_tuple=True)[0].tolist())
# Get the set of indices where the ground truth is 1
target_set = set((valid_targets_i == 1).nonzero(as_tuple=True)[0].tolist())
# Compute intersection
intersection = predicted_set.intersection(target_set)
# Precision
if len(predicted_set) > 0:
precision = len(intersection) / len(predicted_set)
else:
precision = 0.0
# Recall
if len(target_set) > 0:
recall = len(intersection) / len(target_set)
else:
recall = 0.0
# F1 score
if (precision + recall) > 0:
f1 = 2 * precision * recall / (precision + recall)
else:
f1 = 0.0
f1_scores.append(f1)
# Convert list to tensor shape (N, 1)
f1_scores = torch.tensor(f1_scores, dtype=torch.float).unsqueeze(-1)
return f1_scores.detach().to(probs.device)
Train the model with RLTrainer
from gliclass.training import RLTrainerConfig, RLTrainer
training_args = RLTrainerConfig(
output_dir="my-awesome-rl-gliclass-model",
learning_rate=1e-5,
weight_decay=0.01,
others_lr=1e-5,
others_weight_decay=0.01,
lr_scheduler_type="cosine",
per_device_eval_batch_size=8,
num_train_epochs=1,
logging_steps=100,
use_cpu = False,
report_to="none",
fp16=False,
cliprange=0.2,
num_rl_iters=2
)
trainer = RLTrainer(
model=model,
value_model=value_model,
reference_model=reference_pipe,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics_multi_label,
reward_components={
'micro_f1': default_f1_reward,
},
)
trainer.train()
To avoid AttributeError during run in notebooks add following lines after initializing trainer:
trainer = RLTrainer(
model=model,
...
)
from transformers.utils.notebook import NotebookProgressCallback
trainer.remove_callback(NotebookProgressCallback)
trainer.train()
Full RL Training Script [source]
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import numpy as np
import argparse
import json
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
import random
import torch
from gliclass import GLiClassModelConfig, GLiClassModel, ZeroShotClassificationPipeline
from gliclass.training import TrainingArguments, Trainer, RLTrainerConfig, RLTrainer
from gliclass.data_processing import DataCollatorWithPadding, GLiClassDataset
from gliclass.utils import default_f1_reward
def accuracy_reward(probs, actions, targets, valid_mask):
probs = probs * valid_mask
predicts = torch.argmax(probs, dim=-1)
true_labels = torch.argmax(targets, dim=-1)
correct = (predicts == true_labels).float().unsqueeze(1)
return correct
def recall_reward(
probs: torch.Tensor,
actions: torch.Tensor,
original_targets: torch.Tensor,
valid_mask: torch.Tensor
) -> torch.Tensor:
valid_preds = actions * valid_mask
valid_targets = original_targets * valid_mask
TP = torch.sum((valid_preds * valid_targets), dim=-1)
FN = torch.sum(((1 - valid_preds) * valid_targets), dim=-1)
eps = 1e-8
recall = TP / (TP + FN + eps)
return recall.detach().unsqueeze(1)
def main(args):
device = torch.device('cuda:0') if torch.cuda.is_available else torch.device('cpu')
if args.model_name is not None:
model = GLiClassModel.from_pretrained(args.model_name, focal_loss_alpha=args.focal_loss_alpha,
focal_loss_gamma=args.focal_loss_gamma)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
else:
tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_name)
encoder_config = AutoConfig.from_pretrained(args.encoder_model_name)
if args.label_model_name is not None:
label_model_config = AutoConfig.from_pretrained(args.label_model_name)
glicalss_config = GLiClassModelConfig(
encoder_config=encoder_config,
encoder_model=args.encoder_model_name,
label_model_name=args.label_model_name,
label_model_config=label_model_config,
class_token_index=len(tokenizer),
text_token_index=len(tokenizer)+1,
pooling_strategy=args.pooler_type,
scorer_type=args.scorer_type,
use_lstm=args.use_lstm,
focal_loss_alpha=args.focal_loss_alpha,
focal_loss_gamma=args.focal_loss_gamma,
labels_smoothing=args.labels_smoothing,
entropy_beta=args.entropy_beta,
kl_beta=args.kl_beta,
contrastive_loss_coef=args.contrastive_loss_coef,
normalize_features=args.normalize_features,
extract_text_features=args.extract_text_features,
architecture_type=args.architecture_type,
prompt_first=args.prompt_first,
squeeze_layers=args.squeeze_layers
)
glicalss_config.problem_type = args.problem_type
model = GLiClassModel(glicalss_config, from_pretrained=True)
if args.architecture_type in {'uni-encoder', 'bi-encoder-fused', 'encoder-decoder'}:
new_words = ["<<LABEL>>", "<<SEP>>"]
tokenizer.add_tokens(new_words, special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
if args.set_value_model:
value_model = AutoModelForSequenceClassification.from_pretrained(model.config.encoder_model_name, num_labels=1)
value_model.resize_token_embeddings(len(tokenizer))
else:
value_model = None
if args.reference_model is not None:
refrence_model = GLiClassModel.from_pretrained(args.reference_model)
reference_tokenizer = AutoTokenizer.from_pretrained(args.reference_model)
reference_pipe = ZeroShotClassificationPipeline(refrence_model, reference_tokenizer,
classification_type='multi-label',
progress_bar=False, device=device)
else:
reference_pipe = None
if args.label_model_name is not None:
labels_tokenizer = AutoTokenizer.from_pretrained(args.label_model_name)
else:
labels_tokenizer = None
model.to(device)
with open(args.data_path, 'r') as f:
data = json.load(f)[:]
init_ld = len(data)*1
print('Dataset size:', len(data))
random.shuffle(data)
print('Dataset is shuffled...')
train_data = data[:int(len(data)*0.9)]
test_data = data[int(len(data)*0.9):]
print('Dataset is splitted...')
train_dataset = GLiClassDataset(train_data, tokenizer, args.max_length,
args.problem_type, args.architecture_type,
args.prompt_first, labels_tokenizer=labels_tokenizer)
test_dataset = GLiClassDataset(test_data, tokenizer, args.max_length, args.problem_type,
args.architecture_type, args.prompt_first,
labels_tokenizer = labels_tokenizer)
data_collator = DataCollatorWithPadding(device=device)
training_args = RLTrainerConfig(
output_dir=args.save_path,
learning_rate=args.encoder_lr,
weight_decay=args.encoder_weight_decay,
others_lr=args.others_lr,
others_weight_decay=args.others_weight_decay,
lr_scheduler_type=args.lr_scheduler_type,
warmup_ratio=args.warmup_ratio,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
num_train_epochs=args.num_epochs,
evaluation_strategy="epoch",
save_steps = args.save_steps,
save_total_limit=args.save_total_limit,
dataloader_num_workers = args.num_workers,
logging_steps=100,
use_cpu = False,
report_to="none",
fp16=args.fp16,
cliprange=args.clip_range,
num_rl_iters=args.num_rl_iters
)
trainer = RLTrainer(
model=model,
value_model=value_model,
reference_model=reference_pipe,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
reward_components={
'micro_f1': default_f1_reward,
},
)
trainer.train()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default= "knowledgator/gliclass-modern-base-v2.0-init")
parser.add_argument('--encoder_model_name', type=str, default = 'microsoft/deberta-v3-small')
parser.add_argument('--label_model_name', type=str, default = "BAAI/bge-small-en-v1.5")
parser.add_argument('--reference_model', type=str, default = None)
parser.add_argument('--set_value_model', type=bool, default = True)
parser.add_argument('--save_path', type=str, default = 'models/')
parser.add_argument('--data_path', type=str, default = 'data/zero-cats.json')
parser.add_argument('--problem_type', type=str, default='multi_label_classification')
parser.add_argument('--pooler_type', type=str, default='avg')
parser.add_argument('--scorer_type', type=str, default='simple')
parser.add_argument('--architecture_type', type=str, default='uni-encoder')
parser.add_argument('--normalize_features', type=bool, default=False)
parser.add_argument('--extract_text_features', type=bool, default=False)
parser.add_argument('--prompt_first', type=bool, default=True)
parser.add_argument('--use_lstm', type=bool, default=False)
parser.add_argument('--squeeze_layers', type=bool, default=False)
parser.add_argument('--num_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--encoder_lr', type=float, default=2e-6)
parser.add_argument('--others_lr', type=float, default=3e-6)
parser.add_argument('--encoder_weight_decay', type=float, default=0.01)
parser.add_argument('--others_weight_decay', type=float, default=0.01)
parser.add_argument('--warmup_ratio', type=float, default=0.05)
parser.add_argument('--lr_scheduler_type', type=str, default='linear')
parser.add_argument('--focal_loss_alpha', type=float, default=-1)
parser.add_argument('--focal_loss_gamma', type=float, default=-1)
parser.add_argument('--labels_smoothing', type=float, default=-1)
parser.add_argument('--entropy_beta', type=float, default=-1)
parser.add_argument('--kl_beta', type=float, default=0.1)
parser.add_argument('--clip_range', type=float, default=0.2)
parser.add_argument('--num_rl_iters', type=int, default=2)
parser.add_argument('--contrastive_loss_coef', type=float, default=0.)
parser.add_argument('--max_length', type=int, default=2048)
parser.add_argument('--save_steps', type=int, default=300)
parser.add_argument('--save_total_limit', type=int, default=3)
parser.add_argument('--num_workers', type=int, default=12)
parser.add_argument('--fp16', type=bool, default=False)
args = parser.parse_args()
main(args)
Evaluation
Once you have trained your model, you will most likely want to evaluate it. We have already prepared a test_gliclass.py[source] script for you that will help you to evaluate the model on 13 different zero-shot datasets.
Enter the repo and activate yout env
cd GLiClass
source venv/bin/activate
If you don't have gliclass framework installed, please check out our installation guide first.
Run evaluation script
python test_gliclass.py --model knowledgator/gliclass-base-v1.0 --api_key YOR_KEY_IF_REQUIERED