.. _sec_word2vec_pretraining:

Pretraining word2vec
====================

We go on to implement the skip-gram model defined in
:numref:`sec_word2vec`. Then we will pretrain word2vec using negative
sampling on the PTB dataset. First of all, let us obtain the data
iterator and the vocabulary for this dataset by calling the
``d2l.load_data_ptb`` function, which was described in
:numref:`sec_word2vec_data`.
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import math
    from mxnet import autograd, gluon, np, npx
    from mxnet.gluon import nn
    from d2l import mxnet as d2l

    npx.set_np()

    batch_size, max_window_size, num_noise_words = 512, 5, 5
    data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
                                         num_noise_words)
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import math
    import torch
    from torch import nn
    from d2l import torch as d2l

    batch_size, max_window_size, num_noise_words = 512, 5, 5
    data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
                                         num_noise_words)
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import warnings
    from d2l import paddle as d2l

    warnings.filterwarnings("ignore")
    import math
    import paddle
    from paddle import nn

    batch_size, max_window_size, num_noise_words = 512, 5, 5
    data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
                                         num_noise_words)
The Skip-Gram Model
-------------------

We implement the skip-gram model by using embedding layers and batch
matrix multiplications. First, let us review how embedding layers work.

Embedding Layer
~~~~~~~~~~~~~~~

As described in :numref:`sec_seq2seq`, an embedding layer maps a
token's index to its feature vector. The weight of this layer is a
matrix whose number of rows equals the dictionary size
(``input_dim``) and whose number of columns equals the vector
dimension of each token (``output_dim``). After a word embedding model
is trained, this weight is what we need.
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    embed = nn.Embedding(input_dim=20, output_dim=4)
    embed.initialize()
    embed.weight

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [07:35:41] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Parameter embedding0_weight (shape=(20, 4), dtype=float32)
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
    print(f'Parameter embedding_weight ({embed.weight.shape}, '
          f'dtype={embed.weight.dtype})')

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Parameter embedding_weight (torch.Size([20, 4]), dtype=torch.float32)
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
    print(f'Parameter embedding_weight ({embed.weight.shape}, '
          f'dtype={embed.weight.dtype})')

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    W0818 09:42:26.754757 14561 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.8, Runtime API Version: 11.8
    W0818 09:42:26.785637 14561 gpu_resources.cc:91] device: 0, cuDNN Version: 8.7.
    Parameter embedding_weight ([20, 4], dtype=paddle.float32)
The input of an embedding layer is the index of a token (word). For any
token index :math:`i`, its vector representation can be obtained from
the :math:`i^\mathrm{th}` row of the weight matrix in the embedding
layer. Since the vector dimension (``output_dim``) was set to 4, the
embedding layer returns vectors with shape (2, 3, 4) for a minibatch of
token indices with shape (2, 3).
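To make the row-lookup interpretation concrete, here is a minimal
sanity check (a sketch using the PyTorch ``embed`` defined above; the
index 7 is an arbitrary choice for illustration): looking up an index
through the layer returns exactly the corresponding row of its weight
matrix.

.. code:: python

    i = 7  # an arbitrary token index (must be smaller than num_embeddings)
    row = embed.weight.data[i]         # the i-th row of the weight matrix
    vec = embed(torch.tensor([i]))[0]  # embedding lookup for the same index
    print(torch.equal(row, vec))       # True: the lookup is just row indexing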
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    x = np.array([[1, 2, 3], [4, 5, 6]])
    embed(x)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([[[ 0.01438687,  0.05011239,  0.00628365,  0.04861524],
            [-0.01068833,  0.01729892,  0.02042518, -0.01618656],
            [-0.00873779, -0.02834515,  0.05484822, -0.06206018]],

           [[ 0.06491279, -0.03182812, -0.01631819, -0.00312688],
            [ 0.0408415 ,  0.04370362,  0.00404529, -0.0028032 ],
            [ 0.00952624, -0.01501013,  0.05958354,  0.04705103]]])
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    embed(x)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([[[-1.4754, -0.3612, -0.4246,  0.5805],
             [-0.3160,  0.8830,  0.5328,  0.2179],
             [-0.0378, -0.5559,  1.4525,  0.6230]],

            [[ 0.0829, -1.0549,  0.6381,  0.7886],
             [-0.3862, -0.1291,  0.4160, -0.6710],
             [-0.4056,  0.0370, -0.6308, -0.2865]]],
           grad_fn=<EmbeddingBackward0>)
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
    embed(x)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Tensor(shape=[2, 3, 4], dtype=float32, place=Place(gpu:0), stop_gradient=False,
           [[[-0.44344363, -0.46687275, -0.41477701,  0.29907674],
             [-0.26555616,  0.33917296,  0.21654493, -0.15929097],
             [-0.13394782,  0.20542145, -0.10213378, -0.09951779]],

            [[ 0.17691875,  0.04455477, -0.29605603,  0.07824123],
             [-0.21901280,  0.03270376, -0.35716733,  0.17594659],
             [ 0.11768246, -0.34967864, -0.33694351,  0.40039498]]])
Defining the Forward Propagation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In the forward propagation, the input of the skip-gram model includes
the center word indices ``center`` with shape (batch size, 1) and the
concatenated context and noise word indices ``contexts_and_negatives``
with shape (batch size, ``max_len``), where ``max_len`` is defined in
:numref:`subsec_word2vec-minibatch-loading`. These two variables are
first transformed from token indices into vectors via the embedding
layer, then their batch matrix multiplication (described in
:numref:`subsec_batch_dot`) returns an output with shape (batch size,
1, ``max_len``). Each element in the output is the dot product of a
center word vector and a context or noise word vector.
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
        v = embed_v(center)
        u = embed_u(contexts_and_negatives)
        pred = npx.batch_dot(v, u.swapaxes(1, 2))
        return pred
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
        v = embed_v(center)
        u = embed_u(contexts_and_negatives)
        pred = torch.bmm(v, u.permute(0, 2, 1))
        return pred
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
        v = embed_v(center)
        u = embed_u(contexts_and_negatives)
        pred = paddle.bmm(v, u.transpose(perm=[0, 2, 1]))
        return pred
Let us print the output shape of this ``skip_gram`` function for some
example inputs.
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    skip_gram(np.ones((2, 1)), np.ones((2, 4)), embed, embed).shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (2, 1, 4)
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    skip_gram(torch.ones((2, 1), dtype=torch.long),
              torch.ones((2, 4), dtype=torch.long), embed, embed).shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    torch.Size([2, 1, 4])
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    skip_gram(paddle.ones((2, 1), dtype='int64'),
              paddle.ones((2, 4), dtype='int64'), embed, embed).shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [2, 1, 4]
Training
--------

Before training the skip-gram model with negative sampling, let us
first define its loss function.

Binary Cross-Entropy Loss
~~~~~~~~~~~~~~~~~~~~~~~~~

According to the definition of the loss function for negative sampling
in :numref:`subsec_negative-sampling`, we will use the binary
cross-entropy loss.
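As a reminder (stated here in the standard sigmoid cross-entropy form;
the symbols below are generic and not tied to any particular
framework), for a prediction (logit) :math:`x` and a binary label
:math:`y`, the per-element loss is

.. math:: \ell(x, y) = -y \log \sigma(x) - (1 - y) \log\left(1 - \sigma(x)\right),

where :math:`\sigma` is the sigmoid function. The implementations below
additionally take a mask so that padded positions do not contribute to
the loss.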
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    loss = gluon.loss.SigmoidBCELoss()
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class SigmoidBCELoss(nn.Module):
        # Binary cross-entropy loss with masking
        def __init__(self):
            super().__init__()

        def forward(self, inputs, target, mask=None):
            out = nn.functional.binary_cross_entropy_with_logits(
                inputs, target, weight=mask, reduction="none")
            return out.mean(dim=1)

    loss = SigmoidBCELoss()
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class SigmoidBCELoss(nn.Layer):
        # Binary cross-entropy loss with masking
        def __init__(self):
            super().__init__()

        def forward(self, inputs, target, mask=None):
            out = nn.functional.binary_cross_entropy_with_logits(
                logit=inputs, label=target, weight=mask, reduction="none")
            return out.mean(axis=1)

    loss = SigmoidBCELoss()
Recall our descriptions of the mask variable and the label variable in
:numref:`subsec_word2vec-minibatch-loading`. The following computes the
binary cross-entropy loss for the given variables.
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    pred = np.array([[1.1, -2.2, 3.3, -4.4]] * 2)
    label = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
    mask = np.array([[1, 1, 1, 1], [1, 1, 0, 0]])
    loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([0.9352101, 1.8462093])
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
    label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
    mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
    loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([0.9352, 1.8462])
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    pred = paddle.to_tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
    label = paddle.to_tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
    mask = paddle.to_tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype='float32')
    loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
           [0.93521005, 1.84620929])
The following shows how the above results are computed (in a less
efficient way) using the sigmoid activation function in the binary
cross-entropy loss. We can consider the two outputs as two normalized
losses that are averaged over non-masked predictions.
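In symbols (an illustrative calculation rather than part of the
original code, using :math:`1 - \sigma(x) = \sigma(-x)`), the first
value is

.. math:: \frac{-\log\sigma(1.1) - \log\sigma(2.2) - \log\sigma(-3.3) - \log\sigma(4.4)}{4} \approx 0.9352,

and the second value, 1.8462, averages
:math:`-\log\sigma(-1.1) - \log\sigma(-2.2)` over only the two unmasked
positions.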
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def sigmd(x):
        return -math.log(1 / (1 + math.exp(-x)))

    print(f'{(sigmd(1.1) + sigmd(2.2) + sigmd(-3.3) + sigmd(4.4)) / 4:.4f}')
    print(f'{(sigmd(-1.1) + sigmd(-2.2)) / 2:.4f}')

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    0.9352
    1.8462
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def sigmd(x):
        return -math.log(1 / (1 + math.exp(-x)))

    print(f'{(sigmd(1.1) + sigmd(2.2) + sigmd(-3.3) + sigmd(4.4)) / 4:.4f}')
    print(f'{(sigmd(-1.1) + sigmd(-2.2)) / 2:.4f}')

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    0.9352
    1.8462
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def sigmd(x):
        return -math.log(1 / (1 + math.exp(-x)))

    print(f'{(sigmd(1.1) + sigmd(2.2) + sigmd(-3.3) + sigmd(4.4)) / 4:.4f}')
    print(f'{(sigmd(-1.1) + sigmd(-2.2)) / 2:.4f}')

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    0.9352
    1.8462
Initializing Model Parameters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We define two embedding layers for all the words in the vocabulary when
they are used as center words and as context words, respectively. The
word vector dimension ``embed_size`` is set to 100.
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    embed_size = 100
    net = nn.Sequential()
    net.add(nn.Embedding(input_dim=len(vocab), output_dim=embed_size),
            nn.Embedding(input_dim=len(vocab), output_dim=embed_size))
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    embed_size = 100
    net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab),
                                     embedding_dim=embed_size),
                        nn.Embedding(num_embeddings=len(vocab),
                                     embedding_dim=embed_size))
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    embed_size = 100
    net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab),
                                     embedding_dim=embed_size),
                        nn.Embedding(num_embeddings=len(vocab),
                                     embedding_dim=embed_size))
Defining the Training Loop
~~~~~~~~~~~~~~~~~~~~~~~~~~

The training loop is defined below. Because of the existence of
padding, the calculation of the loss function is slightly different
from that in the previous training functions: the loss of each example
is rescaled by ``mask.shape[1] / mask.sum(axis=1)`` so that padded
positions are ignored and examples with different numbers of valid
positions are weighted equally.
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def train(net, data_iter, lr, num_epochs, device=d2l.try_gpu()):
        net.initialize(ctx=device, force_reinit=True)
        trainer = gluon.Trainer(net.collect_params(), 'adam',
                                {'learning_rate': lr})
        animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                                xlim=[1, num_epochs])
        # Sum of normalized losses, no. of normalized losses
        metric = d2l.Accumulator(2)
        for epoch in range(num_epochs):
            timer, num_batches = d2l.Timer(), len(data_iter)
            for i, batch in enumerate(data_iter):
                center, context_negative, mask, label = [
                    data.as_in_ctx(device) for data in batch]
                with autograd.record():
                    pred = skip_gram(center, context_negative, net[0], net[1])
                    l = (loss(pred.reshape(label.shape), label, mask) *
                         mask.shape[1] / mask.sum(axis=1))
                l.backward()
                trainer.step(batch_size)
                metric.add(l.sum(), l.size)
                if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                    animator.add(epoch + (i + 1) / num_batches,
                                 (metric[0] / metric[1],))
        print(f'loss {metric[0] / metric[1]:.3f}, '
              f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def train(net, data_iter, lr, num_epochs, device=d2l.try_gpu()):
        def init_weights(m):
            if type(m) == nn.Embedding:
                nn.init.xavier_uniform_(m.weight)
        net.apply(init_weights)
        net = net.to(device)
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                                xlim=[1, num_epochs])
        # Sum of normalized losses, no. of normalized losses
        metric = d2l.Accumulator(2)
        for epoch in range(num_epochs):
            timer, num_batches = d2l.Timer(), len(data_iter)
            for i, batch in enumerate(data_iter):
                optimizer.zero_grad()
                center, context_negative, mask, label = [
                    data.to(device) for data in batch]

                pred = skip_gram(center, context_negative, net[0], net[1])
                l = (loss(pred.reshape(label.shape).float(), label.float(),
                          mask) / mask.sum(axis=1) * mask.shape[1])
                l.sum().backward()
                optimizer.step()
                metric.add(l.sum(), l.numel())
                if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                    animator.add(epoch + (i + 1) / num_batches,
                                 (metric[0] / metric[1],))
        print(f'loss {metric[0] / metric[1]:.3f}, '
              f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def train(net, data_iter, lr, num_epochs, device=d2l.try_gpu()):
        def init_weights(m):
            if type(m) == nn.Embedding:
                # Apply Xavier initialization to the embedding weight
                nn.initializer.XavierUniform()(m.weight)
        net.apply(init_weights)
        optimizer = paddle.optimizer.Adam(learning_rate=lr,
                                          parameters=net.parameters())
        animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                                xlim=[1, num_epochs])
        # Sum of normalized losses, no. of normalized losses
        metric = d2l.Accumulator(2)
        for epoch in range(num_epochs):
            timer, num_batches = d2l.Timer(), len(data_iter)
            for i, batch in enumerate(data_iter):
                optimizer.clear_grad()
                center, context_negative, mask, label = [
                    paddle.to_tensor(data, place=device) for data in batch]

                pred = skip_gram(center, context_negative, net[0], net[1])
                l = (loss(pred.reshape(label.shape),
                          paddle.to_tensor(label, dtype='float32'),
                          paddle.to_tensor(mask, dtype='float32'))
                     / mask.sum(axis=1) * mask.shape[1])
                l.sum().backward()
                optimizer.step()
                metric.add(l.sum(), l.numel())
                if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                    animator.add(epoch + (i + 1) / num_batches,
                                 (metric[0] / metric[1],))
        print(f'loss {metric[0] / metric[1]:.3f}, '
              f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')
Now we can train a skip-gram model using negative sampling.
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    lr, num_epochs = 0.002, 5
    train(net, data_iter, lr, num_epochs)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    loss 0.408, 103331.7 tokens/sec on gpu(0)

.. figure:: output_word2vec-pretraining_d81279_123_1.svg
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    lr, num_epochs = 0.002, 5
    train(net, data_iter, lr, num_epochs)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    loss 0.410, 377799.5 tokens/sec on cuda:0

.. figure:: output_word2vec-pretraining_d81279_126_1.svg
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    lr, num_epochs = 0.002, 5
    train(net, data_iter, lr, num_epochs)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    loss 0.410, 366982.2 tokens/sec on Place(gpu:0)

.. figure:: output_word2vec-pretraining_d81279_129_1.svg
.. _subsec_apply-word-embed:

Applying Word Embeddings
------------------------

After training the word2vec model, we can use the cosine similarity
between word vectors from the trained model to find words from the
vocabulary that are most semantically similar to an input word.
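For two word vectors :math:`\mathbf{x}` and :math:`\mathbf{y}`, the
cosine similarity used below is

.. math:: \cos(\mathbf{x}, \mathbf{y}) = \frac{\mathbf{x}^\top \mathbf{y}}{\|\mathbf{x}\|\,\|\mathbf{y}\|},

and the code adds a small constant (1e-9) under the square root for
numerical stability.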
**mxnet**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def get_similar_tokens(query_token, k, embed):
        W = embed.weight.data()
        x = W[vocab[query_token]]
        # Compute the cosine similarity. Add 1e-9 for numerical stability
        cos = np.dot(W, x) / np.sqrt(np.sum(W * W, axis=1) * np.sum(x * x) +
                                     1e-9)
        topk = npx.topk(cos, k=k+1, ret_typ='indices').asnumpy().astype('int32')
        for i in topk[1:]:  # Remove the input word
            print(f'cosine sim={float(cos[i]):.3f}: {vocab.to_tokens(i)}')

    get_similar_tokens('chip', 3, net[0])

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    cosine sim=0.582: microprocessor
    cosine sim=0.533: intel
    cosine sim=0.531: dell
**pytorch**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def get_similar_tokens(query_token, k, embed):
        W = embed.weight.data
        x = W[vocab[query_token]]
        # Compute the cosine similarity. Add 1e-9 for numerical stability
        cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) *
                                          torch.sum(x * x) + 1e-9)
        topk = torch.topk(cos, k=k+1)[1].cpu().numpy().astype('int32')
        for i in topk[1:]:  # Remove the input word
            print(f'cosine sim={float(cos[i]):.3f}: {vocab.to_tokens(i)}')

    get_similar_tokens('chip', 3, net[0])

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    cosine sim=0.773: microprocessor
    cosine sim=0.589: hitachi
    cosine sim=0.582: computers
**paddle**

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def get_similar_tokens(query_token, k, embed):
        W = embed.weight
        x = W[vocab[query_token]]
        # Compute the cosine similarity. Add 1e-9 for numerical stability
        cos = paddle.mv(W, x) / paddle.sqrt(paddle.sum(W * W, axis=1) *
                                            paddle.sum(x * x) + 1e-9)
        topk = paddle.topk(cos, k=k+1)[1].numpy().astype('int32')
        for i in topk[1:]:  # Remove the input word
            print(f'cosine sim={float(cos[i]):.3f}: {vocab.to_tokens(i)}')

    get_similar_tokens('chip', 3, net[0])

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    cosine sim=0.770: intel
    cosine sim=0.693: microprocessor
    cosine sim=0.670: desktop
Summary
-------

-  We can train a skip-gram model with negative sampling using
   embedding layers and the binary cross-entropy loss.
-  Applications of word embeddings include finding semantically similar
   words for a given word based on the cosine similarity of word
   vectors.

Exercises
---------

1. Using the trained model, find semantically similar words for other
   input words. Can you improve the results by tuning the
   hyperparameters?
2. When a training corpus is huge, we often sample context words and
   noise words for the *center words* in the current minibatch when
   updating model parameters. In other words, the same center word may
   have different context words or noise words in different training
   epochs. What are the benefits of this method? Try to implement this
   training method.
**mxnet**

`Discussions `__

**pytorch**

`Discussions `__

**paddle**

`Discussions `__