---
title: "Multilabel classification"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Multilabel classification}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE, eval = FALSE)
```

## Intro

First, we need to install the [`blurr`](https://github.com/ohmeow/blurr) module for [Transformers](https://github.com/huggingface/transformers) integration:

```
reticulate::py_install('https://github.com/ohmeow/blurr', pip = TRUE)
```

## Multilabel

Grab the data and take 1% of it for fast training:

```{r}
library(fastai)
library(magrittr)
library(zeallot)

df = HF_load_dataset('civil_comments', split = 'train[:1%]')
```

## Preprocess

Select the label columns. They hold continuous toxicity scores in `[0, 1]`, so round them and cast to integers to obtain binary targets:

```{r}
df = data.table::as.data.table(df)

lbl_cols = c('severe_toxicity', 'obscene', 'threat', 'insult',
             'identity_attack', 'sexual_explicit')

df = df[, (lbl_cols) := round(.SD, 0), .SDcols = lbl_cols]
df = df[, (lbl_cols) := lapply(.SD, as.integer), .SDcols = lbl_cols]
```
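Most comments are not toxic, so positives are rare after rounding. If you want a quick sanity check of how many positive examples each label keeps, here is a minimal sketch using the `df` and `lbl_cols` defined above:

```{r}
# count the positive (1) examples per label column
df[, lapply(.SD, sum), .SDcols = lbl_cols]
```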
## Pretrained model

Load DistilRoBERTa and set the number of output labels in its config:

```{r}
task = HF_TASKS_ALL()$SequenceClassification

pretrained_model_name = "distilroberta-base"
config = AutoConfig()$from_pretrained(pretrained_model_name)
config$num_labels = length(lbl_cols)

c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-%
  get_hf_objects(pretrained_model_name, task = task, config = config)
```

```
Downloading: 100%|██████████| 899k/899k [00:00<00:00, 961kB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 597kB/s]
Downloading: 100%|██████████| 331M/331M [03:26<00:00, 1.61MB/s]
```

## Datablock

Create the data blocks, build the dataloaders, and inspect one batch:

```{r}
blocks = list(
  HF_TextBlock(hf_arch = hf_arch, hf_tokenizer = hf_tokenizer),
  MultiCategoryBlock(encoded = TRUE, vocab = lbl_cols)
)

dblock = DataBlock(blocks = blocks,
                   get_x = ColReader('text'),
                   get_y = ColReader(lbl_cols),
                   splitter = RandomSplitter())

dls = dblock %>% dataloaders(df, bs = 8)

dls %>% one_batch()
```

```
[[1]]
[[1]]$input_ids
tensor([[    0, 24268,  5257,  ...,     1,     1,     1],
        [    0,   287,  4505,  ...,     1,     1,     1],
        [    0,    38,   437,  ...,     1,     1,     1],
        ...,
        [    0,   152,  1129,  ...,     1,     1,     1],
        [    0,    85,    18,  ...,     1,     1,     1],
        [    0, 22014,    31,  ...,     1,     1,     1]], device='cuda:0')

[[1]]$attention_mask
tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')

[[2]]
TensorMultiCategory([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], device='cuda:0')
```
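The `input_ids` above come from the model's byte-level BPE tokenizer. To see how a raw string is split before it gets numericalized, you can query the `hf_tokenizer` object returned by `get_hf_objects()` directly; a small illustrative sketch (the sentence is made up):

```{r}
# inspect the subword pieces the tokenizer produces for a sample sentence
hf_tokenizer$tokenize("Multilabel classification with fastai and blurr")
```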
## Model

```{r}
model = HF_BaseModelWrapper(hf_model)

learn = Learner(dls,
                model,
                opt_func = partial(Adam),
                loss_func = BCEWithLogitsLossFlat(),
                metrics = partial(accuracy_multi(), thresh = 0.2),
                cbs = HF_BaseModelCallback(),
                splitter = hf_splitter())

learn$loss_func$thresh = 0.2
learn$create_opt()  # -> will create your layer groups based on your "splitter" function
learn$freeze()

learn %>% summary()
```

See summary:

```
epoch   train_loss   valid_loss   accuracy_multi   time
-----   ----------   ----------   --------------   -----
HF_BaseModelWrapper (Input shape: 8 x 391)
================================================================
Layer (type)         Output Shape         Param #    Trainable
================================================================
Embedding            8 x 391 x 768        38,603,520 False
________________________________________________________________
Embedding            8 x 391 x 768        394,752    False
________________________________________________________________
Embedding            8 x 391 x 768        768        False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False
________________________________________________________________
Linear               8 x 391 x 768        590,592    False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True
________________________________________________________________
Dropout              8 x 391 x 768        0          False
________________________________________________________________
Linear               8 x 768              590,592    True
________________________________________________________________
Dropout              8 x 768              0          False
________________________________________________________________
Linear               8 x 6                4,614      True
________________________________________________________________

Total params: 82,123,014
Total trainable params: 615,174
Total non-trainable params: 81,507,840

Optimizer used: functools.partial(.python_function at 0x7fee7e8166a8>)
Loss function: FlattenedLoss of BCEWithLogitsLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - HF_BaseModelCallback
```
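`accuracy_multi` turns the model's logits into probabilities with a sigmoid and counts a label as predicted when its probability exceeds `thresh`; the metric is the fraction of correct per-label decisions. A minimal base-R sketch of that logic, with made-up values:

```{r}
# thresholded multilabel accuracy, spelled out by hand (illustrative values)
probs  = matrix(c(0.9, 0.1, 0.3,
                  0.1, 0.7, 0.6), nrow = 2, byrow = TRUE)
truth  = matrix(c(1, 0, 1,
                  0, 1, 0), nrow = 2, byrow = TRUE)
thresh = 0.2

mean((probs > thresh) == (truth == 1))  # 5 of 6 label decisions are correct
```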
## Conclusion

Finally, run the learning-rate finder and fit the model:

```{r}
lrs = learn %>% lr_find(suggestions = TRUE)

learn %>% fit_one_cycle(1, lr_max = 1e-2)
```

```
epoch   train_loss   valid_loss   accuracy_multi   time
-----   ----------   ----------   --------------   -----
0       0.040617     0.034286     0.993257         01:21
```

Lower the decision threshold and predict:

```{r}
learn$loss_func$thresh = 0.02

learn %>% predict("Those damned affluent white people should only eat their own food, like cod cakes and boiled potatoes. No enchiladas for them!")
```

```
$probabilities
  severe_toxicity     obscene       threat     insult identity_attack sexual_explicit
1    9.302437e-07 0.004268706 0.0007849637 0.02687055     0.003282947      0.00232468

$labels
[1] "insult"
```
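`$labels` is simply the set of columns whose probability clears `learn$loss_func$thresh`. A minimal sketch reproducing that last step by hand from the probabilities printed above:

```{r}
# recover the predicted labels from the probabilities by hand
probs = c(severe_toxicity = 9.302437e-07, obscene         = 0.004268706,
          threat          = 0.0007849637, insult          = 0.02687055,
          identity_attack = 0.003282947,  sexual_explicit = 0.00232468)

names(probs)[probs > 0.02]
#> [1] "insult"
```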