---
title: "RoBERTa"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{RoBERTa}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE, eval = FALSE)
```

## Intro

First, install the `blurr` module, which provides the [Transformers](https://github.com/huggingface/transformers) integration:

```{r}
reticulate::py_install('ohmeow-blurr', pip = TRUE)
```

## Binary task

Load the libraries and download the IMDB sample dataset for binary classification:

```{r}
library(fastai)
library(magrittr)
library(zeallot)

URLs_IMDB_SAMPLE()
```

Define the task and download the pretrained objects (architecture, config, tokenizer, and model):

```{r}
HF_TASKS_AUTO = HF_TASKS_AUTO()
task = HF_TASKS_AUTO$SequenceClassification

pretrained_model_name = "roberta-base" # alternatives: "distilbert-base-uncased", "bert-base-uncased"

# zeallot's %<-% unpacks the returned list into four variables
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-%
  get_hf_objects(pretrained_model_name, task = task)
```

```
Downloading: 100%|██████████| 481/481 [00:00<00:00, 277kB/s]
Downloading: 100%|██████████| 899k/899k [00:01<00:00, 580kB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 471kB/s]
Downloading: 100%|██████████| 501M/501M [03:11<00:00, 2.62MB/s]
```

## Dataloader

Create the `DataLoaders` [with Hugging Face data blocks](https://github.com/ohmeow/blurr):

```{r}
imdb_df = data.table::fread('imdb_sample/texts.csv')

# the text block tokenizes with the pretrained tokenizer; labels are categorical
blocks = list(HF_TextBlock(hf_arch = hf_arch, hf_tokenizer = hf_tokenizer), CategoryBlock())

dblock = DataBlock(blocks = blocks,
                   get_x = ColReader('text'),
                   get_y = ColReader('label'),
                   splitter = ColSplitter(col = 'is_valid'))

dls = dblock %>% dataloaders(imdb_df, bs = 4)

dls %>% one_batch()
```

```
[[1]]
[[1]]$input_ids
tensor([[    0,  4833,  3009,  ...,  1916,     6,     2],
        [    0,  1876, 13856,  ...,     7,    47,     2],
        [    0,  2647,     6,  ...,     6,    61,     2],
        [    0,    20,  2091,  ...,  5779,    30,     2]], device='cuda:0')

[[1]]$attention_mask
tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')


[[2]]
TensorCategory([0, 1, 0, 0], device='cuda:0')
```

## RoBERTa model

Wrap the model and create the `Learner`:

```{r}
model = HF_BaseModelWrapper(hf_model)

learn = Learner(dls, model,
                opt_func = partial(Adam, decouple_wd = TRUE),
                loss_func = CrossEntropyLossFlat(),
                metrics = accuracy,
                cbs = HF_BaseModelCallback(),
                splitter = hf_splitter())  # parameter groups used by freeze()

learn$create_opt()
learn$freeze()  # freeze all but the last parameter group

learn %>% summary()
```

```
HF_BaseModelWrapper (Input shape: 4 x 512)
================================================================
Layer (type)         Output Shape         Param #    Trainable
================================================================
Embedding            4 x 512 x 768        38,603,520 False
________________________________________________________________
Embedding            4 x 512 x 768        394,752    False
________________________________________________________________
Embedding            4 x 512 x 768        768        False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
Dropout              4 x 12 x 512 x 512   0          False
________________________________________________________________
Linear               4 x 512 x 768        590,592    False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 512 x 3072       2,362,368  False
________________________________________________________________
Linear               4 x 512 x 768        2,360,064  False
________________________________________________________________
LayerNorm            4 x 512 x 768        1,536      True
________________________________________________________________
Dropout              4 x 512 x 768        0          False
________________________________________________________________
Linear               4 x 768              590,592    True
________________________________________________________________
Dropout              4 x 768              0          False
________________________________________________________________
Linear               4 x 2                1,538      True
________________________________________________________________

Total params: 124,647,170
Total trainable params: 630,530
Total non-trainable params: 124,016,640

Optimizer used: functools.partial(.python_function at 0x7fd850db18c8>, decouple_wd=True)
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - HF_BaseModelCallback
```

## Conclusion

Train for three epochs with the one-cycle policy, then predict the labels of the first four reviews:

```{r}
result = learn %>% fit_one_cycle(3, lr_max = 1e-3)

learn %>% predict(imdb_df$text[1:4])
```
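`fit_one_cycle()` first warms the learning rate up to `lr_max` and then anneals it towards zero. The base-R sketch below approximates that curve for `lr_max = 1e-3`. It assumes fastai's default settings (`div = 25`, `div_final = 1e5`, `pct_start = 0.25`) and cosine interpolation in both phases; the helpers `cos_interp()` and `one_cycle_lr()` are hypothetical and not part of fastai.

```{r}
# Hypothetical helpers (not part of fastai) approximating the learning-rate
# schedule of fit_one_cycle(), assuming fastai's defaults
# div = 25, div_final = 1e5 and pct_start = 0.25.

# Cosine interpolation: returns `start` at pos = 0 and `end` at pos = 1
cos_interp <- function(start, end, pos) {
  start + (1 + cos(pi * (1 - pos))) * (end - start) / 2
}

# pct: fraction of all training iterations completed, in [0, 1]
one_cycle_lr <- function(pct, lr_max, div = 25, div_final = 1e5, pct_start = 0.25) {
  # warm-up: lr_max / div -> lr_max over the first pct_start of training
  warmup <- cos_interp(lr_max / div, lr_max, pct / pct_start)
  # annealing: lr_max -> lr_max / div_final over the remainder
  anneal <- cos_interp(lr_max, lr_max / div_final,
                       (pct - pct_start) / (1 - pct_start))
  ifelse(pct <= pct_start, warmup, anneal)
}

pct <- seq(0, 1, length.out = 101)
lrs <- one_cycle_lr(pct, lr_max = 1e-3)

range(lrs)            # lowest value is lr_max / div_final, highest is lr_max
pct[which.max(lrs)]   # the peak sits at pct_start
```

Plotting `lrs` against `pct`, for example with `plot(pct, lrs, type = "l")`, shows the two phases. fastai also cycles momentum in the opposite direction during training, which this sketch leaves out.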
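The parameter totals printed by `summary()` in the RoBERTa model section can be checked by hand. The base-R sketch below recomputes them from the sizes the table reports: hidden size 768, feed-forward size 3,072, and 12 encoder layers. The vocabulary of 50,265 tokens, the 514 positions, and the single token type are inferred by dividing each `Embedding` parameter count by 768. The sketch also shows that the 630,530 trainable parameters after `learn$freeze()` are exactly the rows marked `True`: the 25 `LayerNorm` layers plus the two-layer classification head.

```{r}
# Layer sizes read off (or inferred from) the summary table
hidden    <- 768     # hidden size
ffn       <- 3072    # feed-forward inner size
vocab     <- 50265   # 38,603,520 / 768 word-embedding rows
n_pos     <- 514     # 394,752 / 768 position-embedding rows
n_types   <- 1       # 768 / 768 token-type-embedding rows
n_layers  <- 12
n_classes <- 2

linear     <- function(n_in, n_out) n_in * n_out + n_out  # weights + biases
layer_norm <- function(n) 2 * n                           # scale + shift

embeddings <- (vocab + n_pos + n_types) * hidden + layer_norm(hidden)
per_layer  <- 4 * linear(hidden, hidden) +     # query, key, value, attention output
  linear(hidden, ffn) + linear(ffn, hidden) +  # feed-forward block
  2 * layer_norm(hidden)
head       <- linear(hidden, hidden) + linear(hidden, n_classes)

total <- embeddings + n_layers * per_layer + head
# Rows marked True: every LayerNorm (1 in the embeddings + 2 per layer) plus the head
trainable <- (1 + 2 * n_layers) * layer_norm(hidden) + head

format(c(total = total, trainable = trainable, frozen = total - trainable),
       big.mark = ",")
```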