---
title: "Text-summarization"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Text-summarization}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include=FALSE}
# Show the code of every chunk but do not evaluate it when knitting.
knitr::opts_chunk$set(echo = TRUE, eval = FALSE)
```

## Intro

First, we need to install the ```blurr``` module for [Transformers](https://github.com/huggingface/transformers) integration.

```{r}
# Install the Python package through pip.
reticulate::py_install("ohmeow-blurr", pip = TRUE)
```

## Dataset

Get dataset from the link:

```{r}
library(fastai)
library(magrittr)
library(zeallot)

# Read the sample CSV straight from the URL.
cnndm_df <- data.table::fread(
  "https://raw.githubusercontent.com/ohmeow/blurr/master/nbs/cnndm_sample.csv"
)
```

## Pretrained model

Get pretrained BART model from ```transformers```:

```{r}
transformers <- transformers()
BartForConditionalGeneration <- transformers$BartForConditionalGeneration

pretrained_model_name <- "facebook/bart-large-cnn"

# `%<-%` (zeallot) unpacks the returned objects into four separate variables.
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(
  pretrained_model_name,
  model_cls = BartForConditionalGeneration
)
```

## Datablock

Create datablock and see batch:

```{r}
# NOTE(review): max_length is presumably the (input, target) token limits;
# confirm against the blurr documentation.
before_batch_tfm <- HF_SummarizationBeforeBatchTransform(
  hf_arch,
  hf_tokenizer,
  max_length = c(256, 130)
)

blocks <- list(
  HF_Text2TextBlock(
    before_batch_tfms = before_batch_tfm,
    input_return_type = HF_SummarizationInput
  ),
  noop()
)

# Inputs are read from the `article` column, targets from `highlights`.
dblock <- DataBlock(
  blocks = blocks,
  get_x = ColReader("article"),
  get_y = ColReader("highlights"),
  splitter = RandomSplitter()
)

dls <- dataloaders(dblock, cnndm_df, bs = 2)

one_batch(dls)
```

```
[[1]]
[[1]]$input_ids
tensor([[ 0, 36, 16256, 43, 480, 2193, 7, 62, 7, 158, 135, 9, 70, 684, 4707, 6, 1625, 16, 4984, 25, 65, 9, 5, 144, 39865, 32273, 3806, 15, 5, 5518, 4, 20, 9544, 3455, 9, 2147, 464, 8, 1050, 29779, 26284, 15, 1632, 11534, 32, 6, 959, 6, 5608, 5, 8066, 9, 5, 247, 18, 4066, 7892, 4, 178, 89, 16, 10, 372, 432, 7, 2217, 4, 96, 5, 315, 3076, 9356, 4928, 36, 4154, 9662, 43, 623, 12978, 23853, 2521, 18, 889, 9, 10721, 625, 32273, 749, 1625, 5546, 365, 212, 4, 20, 889, 3372, 10, 333, 9,
601, 749, 14, 19655, 5, 1647, 9, 5, 3875, 18, 4707, 8, 32, 3891, 1687, 2778, 39865, 32273, 4, 1740, 63, 23491, 28919, 11, 5, 7599, 3939, 7, 63, 10602, 41166, 1634, 11, 12718, 1115, 281, 8, 5, 854, 3964, 21785, 14113, 8, 63, 38905, 8, 21950, 19947, 11, 5, 1926, 6, 1625, 10902, 41, 5784, 4066, 3143, 9, 39456, 8, 856, 25729, 4, 993, 195, 5243, 66, 9, 262, 1360, 38950, 1848, 4707, 303, 11, 1625, 480, 5, 144, 11, 143, 247, 480, 64, 129, 28, 13590, 624, 63, 7562, 4, 85, 16, 184, 7, 39179, 3505, 9, 27772, 6, 28482, 4707, 9, 7723, 6, 112, 6, 6115, 17576, 9, 7723, 8, 973, 6, 151, 1380, 14868, 9, 3451, 4, 221, 2839, 8367, 102, 6, 10, 786, 12, 7699, 1651, 14, 1364, 7, 3720, 8360, 8, 5068, 709, 11, 1625, 6, 34, 3919, 411, 4707, 61, 24, 161, 7648, 2072, 5, 1272, 2713, 30, 5, 2], [ 0, 36, 16256, 43, 480, 6871, 24, 1655, 22049, 50, 41571, 1896, 54, 24509, 13, 244, 5, 363, 5, 601, 12, 180, 12, 279, 1896, 21, 738, 1462, 116, 280, 115, 6723, 15, 61, 985, 5, 3940, 2046, 4, 1868, 22049, 18, 8, 1896, 18, 8826, 2327, 117, 28946, 273, 11, 2559, 461, 4961, 25, 7, 1060, 28604, 2236, 16, 1317, 11347, 148, 10, 7158, 486, 31, 14, 902, 973, 6, 1125, 6, 363, 11, 23058, 6, 1261, 35, 4028, 26, 24, 21, 69, 979, 4, 280, 34920, 480, 19, 5767, 3809, 1243, 21605, 13875, 24, 21, 69, 979, 6, 41571, 6, 54, 16670, 66, 6, 150, 15151, 2459, 22049, 26, 24, 21, 69, 979, 6, 1655, 6, 54, 21, 16600, 71, 145, 4487, 30, 5, 6066, 480, 21, 1353, 7, 273, 18, 461, 7069, 6, 8, 1353, 7, 5, 200, 12, 5743, 1900, 403, 25451, 11, 1353, 1261, 4, 22049, 34, 4407, 45, 2181, 8, 1695, 37, 738, 5, 7044, 11, 1403, 12, 21409, 4, 20, 7158, 486, 702, 2330, 11, 461, 15, 273, 6, 39, 3969, 2026, 6, 124, 62, 49, 19395, 14, 24, 21, 1896, 6, 8, 45, 49, 3653, 6, 54, 21, 5, 29593, 368, 4, 4500, 4945, 628, 273, 1390, 6, 15151, 2459, 22049, 26, 79, 21, 686, 1655, 21, 5, 65, 16600, 4, 2612, 116, 41039, 10105, 37, 18, 127, 979, 39732, 264, 7173, 41039, 1250, 9, 5, 1065, 48149, 77, 553, 549, 79, 56, 655, 137, 1317, 69, 1655, 22049, 7923, 24120, 
50, 8930, 66, 13, 244, 4, 2]], device='cuda:0') [[1]]$attention_mask tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0') [[1]]$decoder_input_ids tensor([[ 0, 1625, 4452, 7, 62, 7, 158, 135, 9, 70, 684, 4707, 15, 3875, 479, 50118, 243, 16, 184, 7, 39179, 3505, 9, 27772, 6, 28482, 5103, 4707, 8, 973, 6, 151, 3505, 9, 3451, 479, 50118, 33837, 709, 8, 2147, 464, 16, 9405, 10, 380, 8793, 15, 63, 28809, 479, 50118, 133, 3274, 9587, 16, 223, 1856, 11, 14117, 9, 145, 5, 247, 18, 632, 
7648, 479, 2, 1, 1, 1, 1, 1, 1, 1], [ 0, 5178, 35, 83, 1443, 2470, 161, 97, 1283, 6, 45, 5, 7158, 486, 6, 40, 3094, 5, 403, 479, 50118, 5341, 35, 83, 2470, 13, 1896, 18, 284, 161, 37, 4265, 5, 3940, 40, 465, 22049, 2181, 479, 50118, 534, 15356, 2459, 22049, 161, 79, 2215, 5, 28604, 2236, 16, 14, 9, 69, 979, 479, 50118, 30541, 6, 41571, 1896, 18, 18631, 1843, 26, 14, 24, 69, 979, 18, 2236, 15, 5, 7158, 486, 479]], device='cuda:0') [[1]]$labels tensor([[ 1625, 4452, 7, 62, 7, 158, 135, 9, 70, 684, 4707, 15, 3875, 479, 50118, 243, 16, 184, 7, 39179, 3505, 9, 27772, 6, 28482, 5103, 4707, 8, 973, 6, 151, 3505, 9, 3451, 479, 50118, 33837, 709, 8, 2147, 464, 16, 9405, 10, 380, 8793, 15, 63, 28809, 479, 50118, 133, 3274, 9587, 16, 223, 1856, 11, 14117, 9, 145, 5, 247, 18, 632, 7648, 479, 2, -100, -100, -100, -100, -100, -100, -100, -100], [ 5178, 35, 83, 1443, 2470, 161, 97, 1283, 6, 45, 5, 7158, 486, 6, 40, 3094, 5, 403, 479, 50118, 5341, 35, 83, 2470, 13, 1896, 18, 284, 161, 37, 4265, 5, 3940, 40, 465, 22049, 2181, 479, 50118, 534, 15356, 2459, 22049, 161, 79, 2215, 5, 28604, 2236, 16, 14, 9, 69, 979, 479, 50118, 30541, 6, 41571, 1896, 18, 18631, 1843, 26, 14, 24, 69, 979, 18, 2236, 15, 5, 7158, 486, 479, 2]], device='cuda:0') [[2]] tensor([[ 1625, 4452, 7, 62, 7, 158, 135, 9, 70, 684, 4707, 15, 3875, 479, 50118, 243, 16, 184, 7, 39179, 3505, 9, 27772, 6, 28482, 5103, 4707, 8, 973, 6, 151, 3505, 9, 3451, 479, 50118, 33837, 709, 8, 2147, 464, 16, 9405, 10, 380, 8793, 15, 63, 28809, 479, 50118, 133, 3274, 9587, 16, 223, 1856, 11, 14117, 9, 145, 5, 247, 18, 632, 7648, 479, 2, -100, -100, -100, -100, -100, -100, -100, -100], [ 5178, 35, 83, 1443, 2470, 161, 97, 1283, 6, 45, 5, 7158, 486, 6, 40, 3094, 5, 403, 479, 50118, 5341, 35, 83, 2470, 13, 1896, 18, 284, 161, 37, 4265, 5, 3940, 40, 465, 22049, 2181, 479, 50118, 534, 15356, 2459, 22049, 161, 79, 2215, 5, 28604, 2236, 16, 14, 9, 69, 979, 479, 50118, 30541, 6, 41571, 1896, 18, 18631, 1843, 26, 14, 24, 69, 979, 18, 2236, 15, 
5, 7158, 486, 479, 2]], device='cuda:0')
```

## Train

Define a ```Learner``` object and train the model:

```{r}
# Start from the summarization defaults stored in the model config, then
# override the generation length bounds.
text_gen_kwargs <- hf_config$task_specific_params["summarization"][[1]]
text_gen_kwargs[["max_length"]] <- 130L
text_gen_kwargs[["min_length"]] <- 30L
text_gen_kwargs

model <- HF_BaseModelWrapper(hf_model)
model_cb <- HF_SummarizationModelCallback(text_gen_kwargs = text_gen_kwargs)

# The trailing comments must stay on their own physical lines; otherwise they
# would comment out the arguments that follow them.
learn <- Learner(
  dls,
  model,
  opt_func = partial(Adam),
  loss_func = CrossEntropyLossFlat(), # HF_PreCalculatedLoss()
  cbs = model_cb,
  splitter = partial(summarization_splitter, arch = hf_arch)
) # .to_native_fp16() # .to_fp16()

learn$create_opt()
learn$freeze()

learn %>% fit_one_cycle(1, lr_max = 4e-5)
```

## Conclusion

Predict a new text:

```{r}
# Sample news article used as the summarization input.
test_article <- "About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said. The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat her, and took off for the French border. The other gunmen followed into France, which is only about 100 meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery. There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities, Gill said. The robbers spoke French and drove vehicles with French license plates. CNN's Andreena Narayan contributed to this report."
```

```{r}
# Ask for three candidate summaries of the article.
outputs <- learn$blurr_summarize(test_article, num_return_sequences = 3L)

# Print each candidate as its own paragraph; the default separator (a single
# space) would run all three summaries together.
cat(outputs, sep = "\n\n")
```

```
Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say . The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel . There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says . A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death . Swiss authorities are working closely with French authorities, police chief says .

Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say . The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel . There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says . A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death . Swiss authorities are working closely with French authorities, an officer says.

Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say . The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel . There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says . A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .
```