当前位置: 首页 > news >正文

Keras深度学习框架实战(5):KerasNLP使用GPT2进行文本生成

1、KerasNLP与GPT2概述

KerasNLP的GPT2进行文本生成是一个基于深度学习的自然语言处理任务,它利用GPT-2模型来生成自然流畅的文本。以下是关于KerasNLP的GPT2进行文本生成的概述:

  • GPT-2模型介绍

  • GPT-2(Generative Pre-trained Transformer 2)是由OpenAI开发的一种基于Transformer模型的自然语言处理(NLP)模型,旨在生成自然流畅的文本。

  • 它是一种无监督学习模型,设计目标是能够理解人类语言的复杂性并模拟出自然的语言生成。

  • GPT-2具有大量的训练数据和强大的算法,可以生成自然流畅、准确的文本。

  • KerasNLP与GPT-2

  • KerasNLP是Keras的一个扩展库,提供了对NLP任务的便捷支持,包括文本生成。

  • 通过KerasNLP,可以方便地加载预训练的GPT-2模型,并用于文本生成任务。

  • 文本生成过程

  • 使用GPT2Tokenizer将输入的文本转换为模型可以理解的格式(即token IDs)。

  • 将token IDs作为输入传递给GPT-2模型。

  • 模型根据输入的上下文生成新的token IDs。

  • 使用GPT2Tokenizer将生成的token IDs解码回文本格式。

  • 特点与优势

  • GPT-2模型使用了大量的预训练参数,使其具有强大的表现力和泛化能力。

  • 可以生成各种类型的文本,如新闻、故事、对话和代码等。

  • 与其他基于神经网络的语言模型相比,GPT-2具有许多独特的优点,如自监督学习方式和处理多种语言和任务的能力。

  • 性能与规模

  • GPT-2模型有多个版本,从小型到大型,以适应不同的计算资源和性能需求。

  • 参数数量从1.5亿到1.75亿不等,模型大小从0.5GB到1.5GB。

  • 使用示例

  • 可以通过KerasNLP提供的接口和预训练模型,轻松实现文本生成任务。

  • 可以通过修改输入文本和参数设置,生成具有不同风格和主题的文本。

  • 注意事项

  • 生成的文本可能不完全符合语法或逻辑,因为模型是基于统计语言模型进行预测的。

  • 在实际应用中,需要对生成的文本进行适当的后处理和筛选,以确保其质量和适用性。

综上所述,KerasNLP的GPT2为文本生成任务提供了强大的支持,通过利用预训练的GPT-2模型,可以轻松地生成自然流畅的文本。

在这个教程中,你将学习如何使用KerasNLP加载一个预训练的大型语言模型(LLM)——GPT-2模型(由OpenAI最初发明),将其微调到特定的文本风格,并基于用户的输入(也称为提示)生成文本。你还将学习GPT-2如何快速适应非英语语言,例如中文。

2、训练准备

运行硬件环境要求

运行GPT2模型需要较高的资源需求,请确保前往运行时 -> 更改运行环境类型并选择GPU硬件加速器运行环境(应具有>12G主机RAM和~15G GPU RAM),因为你将微调GPT-2模型。在CPU运行环境中运行此教程将需要数小时。

安装KerasNLP,选择后端并导入依赖项

这个示例使用Keras 3以便在"tensorflow"、"jax"或"torch"中任一环境中工作。KerasNLP内置了对Keras 3的支持,只需更改"KERAS_BACKEND"环境变量即可选择您所选择的后端。我们在下面选择JAX后端。

!pip install git+https://github.com/keras-team/keras-nlp.git -q
import os
os.environ["KERAS_BACKEND"] = "jax"  # 或"tensorflow"或"torch"import keras_nlp
import keras
import tensorflow as tf
import timekeras.mixed_precision.set_global_policy("mixed_float16")

生成大型语言模型(LLMs)

大型语言模型(LLMs)是一种机器学习模型,它们在大量文本数据上进行训练,以生成各种自然语言处理(NLP)任务的输出,如文本生成、问答和机器翻译。

生成性LLMs通常基于深度学习的神经网络,例如2017年由Google研究人员发明的Transformer架构,并且它们在大量文本数据上进行训练,通常涉及数十亿个单词。这些模型,如Google LaMDA和PaLM,是使用来自各种数据源的大型数据集进行训练的,这使它们能够为许多任务生成输出。生成性LLMs的核心是预测句子中的下一个词,通常称为因果语言模型预训练。通过这种方式,LLMs可以根据用户提示生成连贯的文本。有关语言模型的更多教学性讨论,可以参考斯坦福CS324 LLM课程。

KerasNLP

构建大型语言模型复杂且从头开始训练成本高昂。幸运的是,有预训练的LLMs可供立即使用。KerasNLP提供了大量的预训练检查点,让你可以无需自己训练即可尝试SOTA模型。

KerasNLP是一个自然语言处理库,它支持用户完成整个开发周期。KerasNLP提供了预训练模型和模块化的构建块,因此开发者可以轻松地重用预训练模型或堆叠自己的LLM。

简单来说,对于生成性LLM,KerasNLP提供了:

  • 带有generate()方法的预训练模型,例如keras_nlp.models.GPT2CausalLMkeras_nlp.models.OPTCausalLM
  • 实现生成算法(如Top-K、Beam和对比搜索)的Sampler类。这些samplers可用于使用自定义模型生成文本。

3 加载模型

3.1 加载预训练的GPT-2模型并生成一些文本

KerasNLP提供了许多预训练模型,如Google Bert和GPT-2。程序员可以在KerasNLP仓库中看到可用模型的列表。

加载GPT-2模型非常简单,如下所示:

# 为了加快训练和生成速度,我们使用长度为128的预处理器
# 而不是完整的长度1024。
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset("gpt2_base_en",sequence_length=128,
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en", preprocessor=preprocessor
)

一旦模型加载完成,程序员就可以立即使用它来生成一些文本。运行下面的单元格来尝试一下。这就像调用一个单一的函数generate()一样简单:

start = time.time()output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:")
print(output)end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
GPT-2 output:
My trip to Yosemite was pretty awesome. The first time I went I didn't know how to go and it was pretty hard to get around. It was a bit like going on an adventure with a friend. The only things I could do were hike and climb the mountain. It's really cool to know you're not alone in this world. It's a lot of fun. I'm a little worried that I might not get to the top of the mountain in time to see the sunrise and sunset of the day. I think the weather is going to get a little warmer in the coming years.
This post is a little more in-depth on how to go on the trail. It covers how to hike on the Sierra Nevada, how to hike with the Sierra Nevada, how to hike in the Sierra Nevada, how to get to the top of the mountain, and how to get to the top with your own gear.
The Sierra Nevada is a very popular trail in Yosemite
TOTAL TIME ELAPSED: 25.36s

再试一个:

start = time.time()output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
print("\nGPT-2 output:")
print(output)end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
GPT-2 output:
That Italian restaurant is known for its delicious food, and the best part is that it has a full bar, with seating for a whole host of guests. And that's only because it's located at the heart of the neighborhood.
The menu at the Italian restaurant is pretty straightforward:The menu consists of three main dishes:Italian sausage with cheeseAnd the main menu consists of a few other things.There are two tables: the one that serves a menu of sausage and bolognese with cheese (the one that serves the menu of sausage and bolognese with cheese) and the one that serves the menu of sausage and bolognese with cheese. The two tables are also open 24 hours a day, 7 days a week.
TOTAL TIME ELAPSED: 1.55s

注意第二次调用的速度有多快。这是因为计算图在第一次运行中被XLA编译,并在第二次运行中在后台被重用。

生成的文本质量看起来还可以,但我们可以通过微调来改进它。


3.2 KerasNLP中的GPT-2模型的工具

接下来,我们将实际微调模型以更新其参数,但在此之前,让我们看看我们拥有的用于GPT2的全部工具。

GPT2的代码可以在这里找到。从概念上讲,GPT2CausalLM可以被分层分解为KerasNLP中的几个模块,所有这些模块都有一个from_preset()函数来加载预训练模型:

  • keras_nlp.models.GPT2Tokenizer: GPT2模型使用的分词器,它是一个字节对编码器。
  • keras_nlp.models.GPT2CausalLMPreprocessor: GPT2因果语言模型训练使用的预处理器。它进行分词以及其他预处理工作,如创建标签和附加结束标记。
  • keras_nlp.models.GPT2Backbone: GPT2模型,它是keras_nlp.layers.TransformerDecoder的堆叠。这通常只被称为GPT2
  • keras_nlp.models.GPT2CausalLM: 包装GPT2Backbone,它将GPT2Backbone的输出乘以嵌入矩阵以在词汇表标记上生成logits。

3.3 在Reddit数据集上微调

现在程序员已经了解了KerasNLP中的GPT-2模型,你可以更进一步,微调模型,以便它以特定的风格生成文本,短或长,严格或随意。在本文中,我们将使用Reddit数据集作为示例。

import tensorflow_datasets as tfdsreddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)

让我们看看Reddit TensorFlow数据集中的样本数据。有两个特征:

  • document:帖子的文本。
  • title:标题。
for document, title in reddit_ds:print(document.numpy())print(title.numpy())break
b"me and a friend decided to
go to the beach last sunday. we loaded up and headed out. we were about half way there when i decided that i was not leaving till i had seafood. now i'm not talking about red lobster. no friends i'm talking about a low country boil. i found the restaurant and got directions. i don't know if any of you have heard about the crab shack on tybee island but let me tell you it's worth it. we arrived and was seated quickly. we decided to get a seafood sampler for two and split it. the waitress bought it out on separate platters for us. the amount of food was staggering. two types of crab, shrimp, mussels, crawfish, andouille sausage, red potatoes, and corn on the cob. i managed to finish it and some of my friends crawfish and mussels. it was a day to be a fat ass. we finished paid for our food and headed to the beach. funny thing about seafood. it runs through me faster than a kenyan we arrived and walked around a bit. it was about 45min since we arrived at the beach when i felt a rumble from the depths of my stomach. i ignored it i didn't want my stomach to ruin our fun. i pushed down the feeling and continued. about 15min later the feeling was back and stronger than before. again i ignored it and continued. 5min later it felt like a nuclear reactor had just exploded in my stomach. i started running. i yelled to my friend to hurry the fuck up. running in sand is extremely hard if you did not know this. we got in his car and i yelled at him to floor it. my stomach was screaming and if he didn't hurry i was gonna have this baby in his car and it wasn't gonna be pretty. after a few red lights and me screaming like a woman in labor we made it to the store.i practically tore his car door open and ran inside. i ran to the bathroom opened the door and barely got my pants down before the dam burst and a flood of shit poured from my ass.i finished up when i felt something wet on my ass. i rubbed it thinking it was back splash. no, mass was covered in the after math of me abusing the toilet. i grabbed all the paper towels i could and gave my self a whores bath right there.i sprayed the bathroom down with the air freshener and left. an elderly lady walked in quickly and closed the door. i was just about to walk away when i heard gag. instead of walking i ran. i got to the car and told him to get the hell out of there."b'liking seafood'

在我们的例子中,我们正在对语言模型进行下一个词的预测,所以我们只需要’document’特征。

train_ds = (reddit_ds.map(lambda document, _: document).batch(32).cache().prefetch(tf.data.AUTOTUNE)
)

现在,你可以使用熟悉的fit()函数来微调模型。注意,preprocessor将在fit方法中自动调用,因为GPT2CausalLM是一个keras_nlp.models.Task实例。

这一步需要相当多的GPU内存,并且如果我们要将其训练到完全训练状态需要很长时间。在这里,我们只使用数据集的一部分进行演示。

train_ds = train_ds.take(500)
num_epochs = 1# 线性衰减的学习率。
learning_rate = keras.optimizers.schedules.PolynomialDecay(5e-5,decay_steps=train_ds.cardinality() * num_epochs,end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(optimizer=keras.optimizers.Adam(learning_rate),loss=loss,weighted_metrics=["accuracy"],
)
gpt2_lm.fit(train_ds, epochs=num_epochs)
 500/500 ██████████████████████████████████| 75s 120ms/step - accuracy: 0.3189 - loss: 3.3653

微调完成后,你可以再次使用相同的generate()函数生成文本。这一次,文本将更接近Reddit的写作风格,并且生成的长度将接近我们在训练集中预设的长度。

start = time.time()output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
GPT-2 output:
I like basketball. it has the greatest shot of all time and the best shot of all time. i have to play a little bit more and get some practice time.today i got the opportunity to play in a tournament in a city that is very close to my school so i was excited to see how it would go. i had just been playing with a few other guys, so i thought i would go and play a couple games with them.after a few games i was pretty confident and confident in myself. i had just gotten the opportunity and had to get some practice time.so i go to the
TOTAL TIME ELAPSED: 21.13s

4、采样方法

在KerasNLP中,我们提供了几种采样方法,例如对比搜索、Top-K和束搜索。默认情况下,我们的GPT2CausalLM使用Top-k搜索,但您可以选择自己的采样方法。

就像优化器和激活函数一样,有两种方式来指定自定义的采样器:

  • 使用字符串标识符,如"greedy",您通过这种方式使用默认配置。
  • 传递一个keras_nlp.samplers.Sampler实例,您可以通过这种方式使用自定义配置。
# 使用字符串标识符。
gpt2_lm.compile(sampler="top_k")
output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)# 使用`Sampler`实例。`GreedySampler`往往会重复自身。
greedy_sampler = keras_nlp.samplers.GreedySampler()
gpt2_lm.compile(sampler=greedy_sampler)output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)
GPT-2 output:
I like basketball, and this is a pretty good one.first off, my wife is pretty good, she is a very good basketball player and she is really, really good at playing basketball.she has an amazing game called basketball, it is a pretty fun game.i play it on the couch.  i'm sitting there, watching the game on the couch.  my wife is playing with her phone.  she's playing on the phone with a bunch of people.my wife is sitting there and watching basketball.  she's sitting there watching
GPT-2 output:
I like basketball, but i don't like to play it.so i was playing basketball at my local high school, and i was playing with my friends.i was playing with my friends, and i was playing with my brother, who was playing basketball with his brother.so i was playing with my brother, and he was playing with his brother's brother.so i was playing with my brother, and he was playing with his brother's brother.so i was playing with my brother, and he was playing with his brother's brother.so i was playing with my brother, and he was playing with his brother's brother.so i was playing with my brother, and he was playing with his brother

5 在中文诗歌数据集上微调

我们也可以在非英语数据集上微调GPT2,接下来的部分说明了如何在中文诗歌数据集上微调GPT2,以教我们的模型成为诗人!

因为GPT2使用字节对编码器,而原始预训练数据集包含一些中文字符,我们可以使用原始词汇表在中文数据集上进行微调。

!# 加载中文诗歌数据集。
!git clone https://github.com/chinese-poetry/chinese-poetry.git
Cloning into 'chinese-poetry'...

从json文件中加载文本。我们仅出于演示目的使用《全唐诗》。

import os
import jsonpoem_collection = []
for file in os.listdir("chinese-poetry/全唐诗"):if ".json" not in file or "poet" not in file:continuefull_filename = "%s/%s" % ("chinese-poetry/全唐诗", file)with open(full_filename, "r") as f:content = json.load(f)poem_collection.extend(content)paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]

让我们看看样本数据。

与Reddit示例类似,我们将其转换为TF数据集,并且只使用部分数据进行训练。

train_ds = (tf.data.Dataset.from_tensor_slices(paragraphs).batch(16).cache().prefetch(tf.data.AUTOTUNE)
)# 运行整个数据集需要很长时间,只取500条并运行1个epoch用于演示目的。
train_ds = train_ds.take(500)
num_epochs = 1learning_rate = keras.optimizers.schedules.PolynomialDecay(5e-4,decay_steps=train_ds.cardinality() * num_epochs,end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(optimizer=keras.optimizers.Adam(learning_rate),loss=loss,weighted_metrics=["accuracy"],
)
gpt2_lm.fit(train_ds, epochs=num_epochs)
 500/500 ██████████████████████████████████| 49s 71ms/step - accuracy: 0.2357 - loss: 2.8196

让我们检查结果!

output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
print(output)
昨夜雨疏风骤,爲臨江山院短靜。石淡山陵長爲羣,臨石山非處臨羣。美陪河埃聲爲羣,漏漏漏邊陵塘

6、源代码

"""shell
pip install git+https://github.com/keras-team/keras-nlp.git -q
"""import osos.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"import keras_nlp
import keras
import tensorflow as tf
import timekeras.mixed_precision.set_global_policy("mixed_float16")"""
## Introduction to Generative Large Language Models (LLMs)Large language models (LLMs) are a type of machine learning models that are
trained on a large corpus of text data to generate outputs for various natural
language processing (NLP) tasks, such as text generation, question answering,
and machine translation.Generative LLMs are typically based on deep learning neural networks, such as
the [Transformer architecture](https://arxiv.org/abs/1706.03762) invented by
Google researchers in 2017, and are trained on massive amounts of text data,
often involving billions of words. These models, such as Google [LaMDA](https://blog.google/technology/ai/lamda/)
and [PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html),
are trained with a large dataset from various data sources which allows them to
generate output for many tasks. The core of Generative LLMs is predicting the
next word in a sentence, often referred as **Causal LM Pretraining**. In this
way LLMs can generate coherent text based on user prompts. For a more
pedagogical discussion on language models, you can refer to the
[Stanford CS324 LLM class](https://stanford-cs324.github.io/winter2022/lectures/introduction/).
""""""
## Introduction to KerasNLPLarge Language Models are complex to build and expensive to train from scratch.
Luckily there are pretrained LLMs available for use right away. [KerasNLP](https://keras.io/keras_nlp/)
provides a large number of pre-trained checkpoints that allow you to experiment
with SOTA models without needing to train them yourself.KerasNLP is a natural language processing library that supports users through
their entire development cycle. KerasNLP offers both pretrained models and
modularized building blocks, so developers could easily reuse pretrained models
or stack their own LLM.In a nutshell, for generative LLM, KerasNLP offers:- Pretrained models with `generate()` method, e.g.,`keras_nlp.models.GPT2CausalLM` and `keras_nlp.models.OPTCausalLM`.
- Sampler class that implements generation algorithms such as Top-K, Beam andcontrastive search. These samplers can be used to generate text withcustom models.
""""""
## Load a pre-trained GPT-2 model and generate some textKerasNLP provides a number of pre-trained models, such as [Google
Bert](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html)
and [GPT-2](https://openai.com/research/better-language-models). You can see
the list of models available in the [KerasNLP repository](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/models).It's very easy to load the GPT-2 model as you can see below:
"""# To speed up training and generation, we use preprocessor of length 128
# instead of full length 1024.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset("gpt2_base_en",sequence_length=128,
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en", preprocessor=preprocessor
)"""
Once the model is loaded, you can use it to generate some text right away. Run
the cells below to give it a try. It's as simple as calling a single function
*generate()*:
"""start = time.time()output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:")
print(output)end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")"""
Try another one:
"""start = time.time()output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
print("\nGPT-2 output:")
print(output)end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")"""
Notice how much faster the second call is. This is because the computational
graph is [XLA compiled](https://www.tensorflow.org/xla) in the 1st run and
re-used in the 2nd behind the scenes.The quality of the generated text looks OK, but we can improve it via
fine-tuning.
""""""
## More on the GPT-2 model from KerasNLPNext up, we will actually fine-tune the model to update its parameters, but
before we do, let's take a look at the full set of tools we have to for working
with for GPT2.The code of GPT2 can be found
[here](https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gpt2/).
Conceptually the `GPT2CausalLM` can be hierarchically broken down into several
modules in KerasNLP, all of which have a *from_preset()* function that loads a
pretrained model:- `keras_nlp.models.GPT2Tokenizer`: The tokenizer used by GPT2 model, which is a[byte-pair encoder](https://huggingface.co/course/chapter6/5?fw=pt).
- `keras_nlp.models.GPT2CausalLMPreprocessor`: the preprocessor used by GPT2causal LM training. It does the tokenization along with other preprocessingworks such as creating the label and appending the end token.
- `keras_nlp.models.GPT2Backbone`: the GPT2 model, which is a stack of`keras_nlp.layers.TransformerDecoder`. This is usually just referred as`GPT2`.
- `keras_nlp.models.GPT2CausalLM`: wraps `GPT2Backbone`, it multiplies theoutput of `GPT2Backbone` by embedding matrix to generate logits overvocab tokens.
""""""
## Finetune on Reddit datasetNow you have the knowledge of the GPT-2 model from KerasNLP, you can take one
step further to finetune the model so that it generates text in a specific
style, short or long, strict or casual. In this tutorial, we will use reddit
dataset for example.
"""import tensorflow_datasets as tfdsreddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)"""
Let's take a look inside sample data from the reddit TensorFlow Dataset. There
are two features:- **__document__**: text of the post.
- **__title__**: the title."""for document, title in reddit_ds:print(document.numpy())print(title.numpy())break"""
In our case, we are performing next word prediction in a language model, so we
only need the 'document' feature.
"""train_ds = (reddit_ds.map(lambda document, _: document).batch(32).cache().prefetch(tf.data.AUTOTUNE)
)"""
Now you can finetune the model using the familiar *fit()* function. Note that
`preprocessor` will be automatically called inside `fit` method since
`GPT2CausalLM` is a `keras_nlp.models.Task` instance.This step takes quite a bit of GPU memory and a long time if we were to train
it all the way to a fully trained state. Here we just use part of the dataset
for demo purposes.
"""train_ds = train_ds.take(500)
num_epochs = 1# Linearly decaying learning rate.
learning_rate = keras.optimizers.schedules.PolynomialDecay(5e-5,decay_steps=train_ds.cardinality() * num_epochs,end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(optimizer=keras.optimizers.Adam(learning_rate),loss=loss,weighted_metrics=["accuracy"],
)gpt2_lm.fit(train_ds, epochs=num_epochs)"""
After fine-tuning is finished, you can again generate text using the same
*generate()* function. This time, the text will be closer to Reddit writing
style, and the generated length will be close to our preset length in the
training set.
"""start = time.time()output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")"""
## Into the Sampling MethodIn KerasNLP, we offer a few sampling methods, e.g., contrastive search,
Top-K and beam sampling. By default, our `GPT2CausalLM` uses Top-k search, but
you can choose your own sampling method.Much like optimizer and activations, there are two ways to specify your custom
sampler:- Use a string identifier, such as "greedy", you are using the default
configuration via this way.
- Pass a `keras_nlp.samplers.Sampler` instance, you can use custom configuration
via this way.
"""# Use a string identifier.
gpt2_lm.compile(sampler="top_k")
output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)# Use a `Sampler` instance. `GreedySampler` tends to repeat itself,
greedy_sampler = keras_nlp.samplers.GreedySampler()
gpt2_lm.compile(sampler=greedy_sampler)output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)"""
For more details on KerasNLP `Sampler` class, you can check the code
[here](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/samplers).
""""""
## Finetune on Chinese Poem DatasetWe can also finetune GPT2 on non-English datasets. For readers knowing Chinese,
this part illustrates how to fine-tune GPT2 on Chinese poem dataset to teach our
model to become a poet!Because GPT2 uses byte-pair encoder, and the original pretraining dataset
contains some Chinese characters, we can use the original vocab to finetune on
Chinese dataset.
""""""shell
# Load chinese poetry dataset.
git clone https://github.com/chinese-poetry/chinese-poetry.git
""""""
Load text from the json file. We only use《全唐诗》for demo purposes.
"""import os
import jsonpoem_collection = []
for file in os.listdir("chinese-poetry/全唐诗"):if ".json" not in file or "poet" not in file:continuefull_filename = "%s/%s" % ("chinese-poetry/全唐诗", file)with open(full_filename, "r") as f:content = json.load(f)poem_collection.extend(content)paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]"""
Let's take a look at sample data.
"""print(paragraphs[0])"""
Similar as Reddit example, we convert to TF dataset, and only use partial data
to train.
"""train_ds = (tf.data.Dataset.from_tensor_slices(paragraphs).batch(16).cache().prefetch(tf.data.AUTOTUNE)
)# Running through the whole dataset takes long, only take `500` and run 1
# epochs for demo purposes.
train_ds = train_ds.take(500)
num_epochs = 1learning_rate = keras.optimizers.schedules.PolynomialDecay(5e-4,decay_steps=train_ds.cardinality() * num_epochs,end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(optimizer=keras.optimizers.Adam(learning_rate),loss=loss,weighted_metrics=["accuracy"],
)gpt2_lm.fit(train_ds, epochs=num_epochs)"""
Let's check the result!
"""output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
print(output)

7、总结

本文讨论了关于如何使用KerasNLP库来加载、微调和使用GPT-2模型进行文本生成。

  • 环境设置:首先介绍了如何在Colab上选择GPU硬件加速器运行环境,以便于进行GPT-2模型的微调。

  • 安装与配置:然后指导用户安装KerasNLP库,并根据需要选择后端(tensorflow、jax或torch)。

  • 大型语言模型(GPT-2)介绍:解释了大型语言模型的概念,以及GPT-2是如何在大量文本数据上进行预训练的。

  • KerasNLP库介绍:介绍了KerasNLP库的功能,包括提供预训练模型和模块化的构建块,以便开发者可以重用或堆叠自己的LLM。

  • 加载预训练的GPT-2模型:展示了如何加载预训练的GPT-2模型,并使用它生成文本。

  • 微调模型:教程接下来介绍了如何使用Reddit数据集对GPT-2模型进行微调,以生成特定风格的文本。

  • 采样方法:讨论了KerasNLP中提供的几种采样方法,如Top-K、Beam和对比搜索,并展示了如何使用这些采样方法。

  • 在中文诗歌数据集上微调:最后,教程还介绍了如何在非英语数据集(中文诗歌)上微调GPT-2模型,以教模型成为诗人。

整个文章提供了详细的代码示例和说明,旨在帮助用户了解如何使用KerasNLP库来使用和微调GPT-2模型,并展示了模型在不同领域的应用潜力。

相关文章:

  • 【2024年5月备考新增】】 考前篇(34)《必备资料(17) - 论文串讲-项目采购管理》
  • 单例模式(C语言)
  • B端数据看板,其实数据可以更美的。
  • 【人工智能】第六部分:ChatGPT的进一步发展和研究方向
  • 【C++ | 析构函数】类的析构函数详解
  • SQL语句练习每日5题(二)
  • JVM内存分析之JVM分区与介绍
  • Python使用trule库画小猪佩奇
  • JavaSE基础语法合集
  • 字符串转换为字节数组、16进制转换为base64、base64转换为字符串数组、base64转换为16进制(微信小程序)
  • 个人投资伦敦银应该学会辨别的回撤形态
  • 洛谷P3214 [HNOI2011] 卡农
  • 力扣283. 移动零
  • 数组和指针的联系(C语言)
  • 区块链学习记录01
  • 【知识碎片】第三方登录弹窗效果
  • 2018天猫双11|这就是阿里云!不止有新技术,更有温暖的社会力量
  • EventListener原理
  • Gradle 5.0 正式版发布
  • nodejs调试方法
  • Octave 入门
  • PyCharm搭建GO开发环境(GO语言学习第1课)
  • React as a UI Runtime(五、列表)
  • spring cloud gateway 源码解析(4)跨域问题处理
  • SQLServer之索引简介
  • vue脚手架vue-cli
  • 从零开始的webpack生活-0x009:FilesLoader装载文件
  • 前端每日实战:70# 视频演示如何用纯 CSS 创作一只徘徊的果冻怪兽
  • 微信开源mars源码分析1—上层samples分析
  • 为物联网而生:高性能时间序列数据库HiTSDB商业化首发!
  • 再谈express与koa的对比
  • ​2021半年盘点,不想你错过的重磅新书
  • ​什么是bug?bug的源头在哪里?
  • ​总结MySQL 的一些知识点:MySQL 选择数据库​
  • !$boo在php中什么意思,php前戏
  • #Linux(帮助手册)
  • (二)Pytorch快速搭建神经网络模型实现气温预测回归(代码+详细注解)
  • (二)基于wpr_simulation 的Ros机器人运动控制,gazebo仿真
  • (翻译)Quartz官方教程——第一课:Quartz入门
  • (三) prometheus + grafana + alertmanager 配置Redis监控
  • (四)Controller接口控制器详解(三)
  • .net core 微服务_.NET Core 3.0中用 Code-First 方式创建 gRPC 服务与客户端
  • .NET命令行(CLI)常用命令
  • .NET上SQLite的连接
  • .Net语言中的StringBuilder:入门到精通
  • /boot 内存空间不够
  • @FeignClient 调用另一个服务的test环境,实际上却调用了另一个环境testone的接口,这其中牵扯到k8s容器外容器内的问题,注册到eureka上的是容器外的旧版本...
  • @value 静态变量_Python彻底搞懂:变量、对象、赋值、引用、拷贝
  • [15] 使用Opencv_CUDA 模块实现基本计算机视觉程序
  • [C#]科学计数法(scientific notation)显示为正常数字
  • [C/C++]关于C++11中的std::move和std::forward
  • [C++]C++类基本语法
  • [CareerCup] 2.1 Remove Duplicates from Unsorted List 移除无序链表中的重复项
  • [CTO札记]盛大文学公司名称对联
  • [HTTP]HTTP协议的状态码