---
title: "GPT2"
output:
rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{GPT2}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE, eval = FALSE)
```
## Intro
The [fastai](https://github.com/fastai/fastai) library simplifies training fast and accurate neural nets using modern best practices. See the fastai website to get started. The library is based on research into deep learning best practices undertaken at `fast.ai`, and includes "out of the box" support for `vision`, `text`, `tabular`, and `collab` (collaborative filtering) models.
## Dataset
Download and read data:
```{r}
library(fastai)
library(magrittr)

# download the WikiText-2 dataset
URLs_WIKITEXT()

path = 'wikitext-2'
train = data.table::fread(paste(path, 'train.csv', sep = '/'), header = FALSE, fill = TRUE)
test = data.table::fread(paste(path, 'test.csv', sep = '/'), header = FALSE, fill = TRUE)

# combine both files into one table for language modelling
df = rbind(train, test)
rm(train, test)
```
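A quick sanity check on what was loaded (`fread` names the single text column `V1`):

```{r}
nrow(df)        # number of raw text lines
head(df$V1, 3)  # first few lines of WikiText
```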
## Transformers
Import the library, then fetch the pretrained tokenizer and model weights:
```{r}
# import HuggingFace transformers through reticulate
tr = reticulate::import('transformers')

pretrained_weights = 'gpt2'
tokenizer = tr$GPT2TokenizerFast$from_pretrained(pretrained_weights)
model = tr$GPT2LMHeadModel$from_pretrained(pretrained_weights)
```
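As a quick check, the tokenizer splits a string into sub-word pieces and maps each piece to an integer id (the string below is just an illustrative example):

```{r}
toks = tokenizer$tokenize("A unicorn is a magical creature.")
toks
tokenizer$convert_tokens_to_ids(toks)
```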
## Tokenize
Tokenize every line and collect the resulting tensors in a list:
```{r}
tokenize = function(text) {
  toks = tokenizer$tokenize(text)
  tensor(tokenizer$convert_tokens_to_ids(toks))
}

tokenized = list()

for (i in seq_along(df$V1)) {
  tokeniz = tokenize(df$V1[i])
  tokenized = tokenized %>% append(tokeniz)
  if (i %% 100 == 0) {
    print(i)  # progress indicator
  }
}
```
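Equivalently, the loop can be written more idiomatically with `lapply` (a sketch; it produces the same list of tensors, just without the progress printing):

```{r}
tokenized = lapply(df$V1, tokenize)
```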
Then split the data into training and validation indices:
```{r}
tot = 1:nrow(df)
tr_idx = sample(nrow(df), 0.8 * nrow(df))  # random 80% for training
ts_idx = tot[!tot %in% tr_idx]             # remaining 20% for validation
splits = list(tr_idx, ts_idx)
```
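A quick check that the two index sets are disjoint and together cover every row:

```{r}
length(intersect(tr_idx, ts_idx)) == 0       # TRUE
length(tr_idx) + length(ts_idx) == nrow(df)  # TRUE
```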
## Dataloader
Prepare the dataloaders and train the model:

> Note: The HuggingFace model returns a tuple of outputs: the actual predictions plus some additional activations (in case we want to use them in a regularization scheme). To work inside the fastai training loop, we need to drop the extras using a `Callback`; callbacks are how fastai lets you alter the behavior of the training loop.
>
> Here we need to implement the `after_pred` event and replace `self$learn$pred` (which contains the predictions that will be passed to the loss function) with just its first element. In callbacks, there is a shortcut that lets you access any of the underlying `Learner` attributes, so we can write `self$pred[0]` instead of `self$learn$pred[0]`. That shortcut only works for read access, not write, so we have to spell out `self$learn$pred` as the assignment target (otherwise we would set a `pred` attribute on the `Callback` instead).
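The package provides this callback as `TransformersDropOutput()`, so you don't need to define it yourself. Purely for illustration, here is a minimal sketch of an equivalent Python callback, run through `reticulate` (the class name `DropOutput` is a hypothetical label, not part of the package):

```{r}
reticulate::py_run_string("
from fastai.text.all import Callback

class DropOutput(Callback):
    def after_pred(self):
        # keep only the logits; drop the extra activations in the tuple
        self.learn.pred = self.pred[0]
")
```

With the extra outputs handled, build the `TfmdLists` and dataloaders: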
```{r}
tls = TfmdLists(tokenized, TransformersTokenizer(tokenizer),
                splits = splits,
                dl_type = LMDataLoader())

bs = 8    # batch size
sl = 100  # sequence length

dls = tls %>% dataloaders(bs = bs, seq_len = sl)
```

Now we are ready to create our `Learner`: a fastai object that groups data, model, and loss function, and handles model training and inference. Since we are in a language-model setting, we pass perplexity as a metric, and we use the callback described above. Lastly, we train in mixed precision to save every bit of memory we can (and, if you have a modern GPU, it will also make training faster):

```{r}
learn = Learner(dls, model, loss_func = CrossEntropyLossFlat(),
                cbs = list(TransformersDropOutput()),
                metrics = Perplexity())$to_fp16()

learn %>% fit_one_cycle(1, 1e-4)
```
```
epoch train_loss valid_loss perplexity time
------ ----------- ----------- ----------- ------
0 3.245887 3.301065 27.141541 07:40
1 3.065197 3.234682 25.398289 07:43
```
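At this point you may want to persist the fine-tuned weights. A minimal sketch using the standard fastai `Learner$save()` method (the file name `'gpt2-wikitext'` is arbitrary):

```{r}
# saves gpt2-wikitext.pth under the learner's model directory
learn$save('gpt2-wikitext')
```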
## Conclusion
Finally, generate text from a prompt:
```{r}
prompt = "\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn"
prompt_ids = tokenizer$encode(prompt)

# add a batch dimension and move the tensor to the GPU
inp = tensor(prompt_ids)[NULL]$cuda()

preds = learn$model$generate(inp, max_length = 80L, num_beams = 5L, temperature = 1.5)

tokenizer$decode(as.integer(preds[0]$cpu()$numpy()))
```
Result:
```
"\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn
@-@ like head. The unicorn is a member of the family, a group of .
The unicorn is a member of the family, a group of . The unicorn is a
member of the family, a group of"
```
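Beam search tends to repeat itself, as the looping "member of the family" output above shows. Sampling-based decoding usually yields more varied text; a sketch using HuggingFace's standard sampling arguments to `generate()`:

```{r}
preds = learn$model$generate(inp,
                             max_length = 80L,
                             do_sample = TRUE,  # sample instead of beam search
                             top_k = 50L,       # consider only the 50 most likely tokens
                             top_p = 0.95)      # nucleus sampling
tokenizer$decode(as.integer(preds[0]$cpu()$numpy()))
```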