diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index caa95631f7..96f2472eed 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -96,6 +96,18 @@ jobs: -word_vec_size 5 -report_every 5 \ -coverage_attn true -lambda_coverage 0.1 \ -rnn_size 10 -train_steps 10 + - name: Test Transformer training with pseudo self attention + run : | + python train.py \ + -config data/align_data.yaml \ + -src_vocab /tmp/onmt.vocab.src \ + -tgt_vocab /tmp/onmt.vocab.tgt \ + -src_vocab_size 1000 \ + -tgt_vocab_size 1000 \ + -max_generator_batches 0 \ + -encoder_type transformer -decoder_type transformer_lm_psa \ + -layers 4 -word_vec_size 16 -rnn_size 16 -heads 2 -transformer_ff 64 \ + -report_every 5 -train_steps 10 - name: Test Transformer training with align run: | python train.py \ diff --git a/onmt/decoders/__init__.py b/onmt/decoders/__init__.py index 2b9a7acd34..1e50b7cd96 100644 --- a/onmt/decoders/__init__.py +++ b/onmt/decoders/__init__.py @@ -1,13 +1,32 @@ """Module defining decoders.""" -from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, \ - StdRNNDecoder -from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder from onmt.decoders.cnn_decoder import CNNDecoder +from onmt.decoders.decoder import ( + DecoderBase, + InputFeedRNNDecoder, + StdRNNDecoder, +) +from onmt.decoders.transformer import ( + TransformerDecoder, + TransformerLMDecoder, + TransformerLMPseudoSelfAttentionDecoder, +) +str2dec = { + "rnn": StdRNNDecoder, + "ifrnn": InputFeedRNNDecoder, + "cnn": CNNDecoder, + "transformer": TransformerDecoder, + "transformer_lm": TransformerLMDecoder, + "transformer_lm_psa": TransformerLMPseudoSelfAttentionDecoder, +} -str2dec = {"rnn": StdRNNDecoder, "ifrnn": InputFeedRNNDecoder, - "cnn": CNNDecoder, "transformer": TransformerDecoder, - "transformer_lm": TransformerLMDecoder} - -__all__ = ["DecoderBase", "TransformerDecoder", "StdRNNDecoder", "CNNDecoder", - "InputFeedRNNDecoder", "str2dec", "TransformerLMDecoder"] +__all__ = [ + "DecoderBase", + "TransformerDecoder", + "StdRNNDecoder", + "CNNDecoder", + "InputFeedRNNDecoder", + "str2dec", + "TransformerLMDecoder", + "TransformerLMPseudoSelfAttentionDecoder", +] diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py index a50e4a8e9c..48ba0fc0a1 100644 --- a/onmt/decoders/transformer.py +++ b/onmt/decoders/transformer.py @@ -8,6 +8,7 @@ from onmt.decoders.decoder import DecoderBase from onmt.modules import MultiHeadedAttention, AverageAttention +from onmt.modules import MultiHeadedPseudoSelfAttention from onmt.modules.position_ffn import PositionwiseFeedForward from onmt.modules.position_ffn import ActivationFunction from onmt.utils.misc import sequence_mask @@ -68,10 +69,16 @@ def __init__( self.self_attn = AverageAttention( d_model, dropout=attention_dropout, aan_useffn=aan_useffn ) - - self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, - pos_ffn_activation_fn - ) + elif self_attn_type == "pseudo-self": + self.self_attn = MultiHeadedPseudoSelfAttention( + heads, + d_model, + dropout=attention_dropout, + max_relative_positions=max_relative_positions, + ) + self.feed_forward = PositionwiseFeedForward( + d_model, d_ff, dropout, pos_ffn_activation_fn + ) self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) self.drop = nn.Dropout(dropout) self.full_context_alignment = full_context_alignment @@ -120,7 +127,8 @@ def update_dropout(self, dropout, attention_dropout): def _forward(self, *args, **kwargs): raise NotImplementedError - def 
_compute_dec_mask(self, tgt_pad_mask, future): + @staticmethod + def _compute_dec_mask(tgt_pad_mask, future): tgt_len = tgt_pad_mask.size(-1) if not future: # apply future_mask, result mask in (B, T, T) future_mask = torch.ones( @@ -253,7 +261,7 @@ def _forward( """ dec_mask = None - if inputs.size(1) > 1: + if step is None: # masking is necessary when sequence length is greater than one dec_mask = self._compute_dec_mask(tgt_pad_mask, future) @@ -548,7 +556,7 @@ def _forward( """ dec_mask = None - if inputs.size(1) > 1: + if step is None or inputs.size(1) > 1: # masking is necessary when sequence length is greater than one dec_mask = self._compute_dec_mask(tgt_pad_mask, future) @@ -693,3 +701,222 @@ def _init_cache(self, memory_bank=None): if isinstance(layer.self_attn, AverageAttention): raise NotImplementedError self.state["cache"]["layer_{}".format(i)] = layer_cache + + +class TransformerLMPseudoSelfAttentionDecoderLayer( + TransformerDecoderLayerBase +): + """Transformer Decoder only layer block in GPT style. + + .. mermaid:: + + graph LR + %% "*SubLayer" can be self-attn, src-attn or feed forward block + A(input) --> B[Norm] + B --> C["*SubLayer"] + C --> D[Drop] + D --> E((+)) + A --> E + E --> F(out) + + + Args: + See TransformerDecoderLayerBase + """ + + def _forward( + self, + inputs, + src_memory_bank, + src_pad_mask, + tgt_pad_mask, + layer_cache=None, + step=None, + future=False, + ): + """A naive forward pass for transformer decoder. + + # T: could be 1 in the case of stepwise decoding or tgt_len + + Args: + inputs (FloatTensor): ``(batch_size, T, model_dim)`` + tgt_pad_mask (bool): ``(batch_size, 1, T)`` + layer_cache (dict or None): cached layer info when stepwise decode + step (int or None): stepwise decoding counter + future (bool): If set True, do not apply future_mask. + + Returns: + (FloatTensor, FloatTensor): + + * output ``(batch_size, T, model_dim)`` + * attns ``(batch_size, head, T, T)`` + + """ + dec_mask = None + pseudo_mask = None + if step is None: + # masking is necessary when sequence length is greater than one + dec_mask = self._compute_dec_mask(tgt_pad_mask, future) + pseudo_mask = torch.cat( + [src_pad_mask.repeat(1, inputs.size(1), 1), dec_mask], axis=-1 + ) + else: + pseudo_mask = torch.cat( + ( + src_pad_mask.repeat(1, inputs.size(1), 1), + torch.zeros( + (inputs.size(0), inputs.size(1), step + 1), + dtype=torch.bool, + device=src_pad_mask.device, + ), + ), + axis=-1, + ) + inputs_norm = self.layer_norm_1(inputs) + + query, attns = self.self_attn( + src_memory_bank.transpose(0, 1), + inputs_norm, + mask=pseudo_mask, + layer_cache=layer_cache, + attn_type="self", + ) + + output = self.drop(query) + inputs + + output_feedforward = self.feed_forward(output) + + return output_feedforward, attns + + +class TransformerLMPseudoSelfAttentionDecoder(TransformerDecoderBase): + """The Transformer decoder from GPT-2 with pseudo self attention + + .. mermaid:: + + graph BT + A[input] + B[multi-head self-attn] + C[feed forward] + O[output] + A --> B + B --> C + C --> O + + + Args: + num_layers (int): number of decoder layers. 
+ d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + copy_attn (bool): if using a separate copy attention + self_attn_type (str): type of self-attention scaled-dot, average + dropout (float): dropout in residual, self-attn(dot) and feed-forward + attention_dropout (float): dropout in context_attn (and self-attn(avg)) + embeddings (onmt.modules.Embeddings): + embeddings to use, should have positional encodings + max_relative_positions (int): + Max distance between inputs in relative positions representations + aan_useffn (bool): Turn on the FFN layer in the AAN decoder + """ + + def __init__( + self, + num_layers, + d_model, + heads, + d_ff, + copy_attn, + self_attn_type, + dropout, + attention_dropout, + embeddings, + max_relative_positions, + aan_useffn, + full_context_alignment=None, + alignment_layer=None, + alignment_heads=None, + pos_ffn_activation_fn=ActivationFunction.relu, + ): + super(TransformerLMPseudoSelfAttentionDecoder, self).__init__( + d_model, copy_attn, embeddings, None + ) + self.transformer_layers = nn.ModuleList( + [ + TransformerLMPseudoSelfAttentionDecoderLayer( + d_model, + heads, + d_ff, + dropout, + attention_dropout, + self_attn_type="pseudo-self", + max_relative_positions=max_relative_positions, + aan_useffn=aan_useffn, + full_context_alignment=None, + alignment_heads=None, + pos_ffn_activation_fn=pos_ffn_activation_fn, + ) + for i in range(num_layers) + ] + ) + + def detach_state(self): + pass + + def forward(self, tgt, memory_bank=None, step=None, **kwargs): + """Decode, possibly stepwise.""" + if step == 0: + self._init_cache() + + tgt_words = tgt[:, :, 0].transpose(0, 1) + + emb = self.embeddings(tgt, step=step) + assert emb.dim() == 3 # len x batch x embedding_dim + + output = emb.transpose(0, 1).contiguous() + + pad_idx = self.embeddings.word_padding_idx + src_lens = kwargs["memory_lengths"] + src_max_len = self.state["src"].shape[0] + src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1) + tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt] + + with_align = kwargs.pop("with_align", False) + assert not with_align, "TransformerLMDecoder does not support align" + + for i, layer in enumerate(self.transformer_layers): + layer_cache = ( + self.state["cache"]["layer_{}".format(i)] + if step is not None + else None + ) + output, attn, _ = layer( + output, + memory_bank, + src_pad_mask, + tgt_pad_mask, + layer_cache=layer_cache, + step=step, + with_align=with_align, + ) + + output = self.layer_norm(output) + dec_outs = output.transpose(0, 1).contiguous() + attn = attn.transpose(0, 1).contiguous() + + attns = {"std": attn} + if self._copy: + attns["copy"] = attn + + # TODO change the way attns is returned dict => list or tuple (onnx) + return dec_outs, attns + + def _init_cache(self, memory_bank=None): + self.state["cache"] = {} + + for i, layer in enumerate(self.transformer_layers): + layer_cache = {"self_keys": None, "self_values": None, + "src_keys": None, "src_values": None} + if isinstance(layer.self_attn, AverageAttention): + raise NotImplementedError + self.state["cache"]["layer_{}".format(i)] = layer_cache diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py index 0e789e5774..bbf835d46f 100644 --- a/onmt/modules/__init__.py +++ b/onmt/modules/__init__.py @@ -3,16 +3,34 @@ from onmt.modules.gate import context_gate_factory, ContextGate from onmt.modules.global_attention import GlobalAttention from onmt.modules.conv_multi_step_attention import 
ConvMultiStepAttention -from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ - CopyGeneratorLossCompute, CopyGeneratorLMLossCompute -from onmt.modules.multi_headed_attn import MultiHeadedAttention +from onmt.modules.copy_generator import ( + CopyGenerator, + CopyGeneratorLoss, + CopyGeneratorLossCompute, + CopyGeneratorLMLossCompute, +) +from onmt.modules.multi_headed_attn import ( + MultiHeadedAttention, + MultiHeadedPseudoSelfAttention, +) from onmt.modules.embeddings import Embeddings, PositionalEncoding from onmt.modules.weight_norm import WeightNormConv2d from onmt.modules.average_attn import AverageAttention -__all__ = ["Elementwise", "context_gate_factory", "ContextGate", - "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", - "CopyGeneratorLoss", "CopyGeneratorLossCompute", - "MultiHeadedAttention", "Embeddings", "PositionalEncoding", - "WeightNormConv2d", "AverageAttention", - "CopyGeneratorLMLossCompute"] +__all__ = [ + "Elementwise", + "context_gate_factory", + "ContextGate", + "GlobalAttention", + "ConvMultiStepAttention", + "CopyGenerator", + "CopyGeneratorLoss", + "CopyGeneratorLossCompute", + "MultiHeadedAttention", + "Embeddings", + "PositionalEncoding", + "WeightNormConv2d", + "AverageAttention", + "CopyGeneratorLMLossCompute", + "MultiHeadedPseudoSelfAttention", +] diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index a9b8b487e0..0faf74ec0b 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn -from onmt.utils.misc import generate_relative_positions_matrix,\ - relative_matmul +from onmt.utils.misc import generate_relative_positions_matrix, relative_matmul + # from onmt.utils.misc import aeq @@ -48,8 +48,9 @@ class MultiHeadedAttention(nn.Module): dropout (float): dropout parameter """ - def __init__(self, head_count, model_dim, dropout=0.1, - max_relative_positions=0): + def __init__( + self, head_count, model_dim, dropout=0.1, max_relative_positions=0 + ): assert model_dim % head_count == 0 self.dim_per_head = model_dim // head_count self.model_dim = model_dim @@ -57,12 +58,13 @@ def __init__(self, head_count, model_dim, dropout=0.1, super(MultiHeadedAttention, self).__init__() self.head_count = head_count - self.linear_keys = nn.Linear(model_dim, - head_count * self.dim_per_head) - self.linear_values = nn.Linear(model_dim, - head_count * self.dim_per_head) - self.linear_query = nn.Linear(model_dim, - head_count * self.dim_per_head) + self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head) + self.linear_values = nn.Linear( + model_dim, head_count * self.dim_per_head + ) + self.linear_query = nn.Linear( + model_dim, head_count * self.dim_per_head + ) self.softmax = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) self.final_linear = nn.Linear(model_dim, model_dim) @@ -72,10 +74,12 @@ def __init__(self, head_count, model_dim, dropout=0.1, if max_relative_positions > 0: vocab_size = max_relative_positions * 2 + 1 self.relative_positions_embeddings = nn.Embedding( - vocab_size, self.dim_per_head) + vocab_size, self.dim_per_head + ) - def forward(self, key, value, query, mask=None, - layer_cache=None, attn_type=None): + def forward( + self, key, value, query, mask=None, layer_cache=None, attn_type=None + ): """ Compute the context vector and the attention vectors. 
@@ -120,42 +124,49 @@ def forward(self, key, value, query, mask=None, def shape(x): """Projection.""" - return x.view(batch_size, -1, head_count, dim_per_head) \ - .transpose(1, 2) + return x.view(batch_size, -1, head_count, dim_per_head).transpose( + 1, 2 + ) def unshape(x): """Compute context.""" - return x.transpose(1, 2).contiguous() \ - .view(batch_size, -1, head_count * dim_per_head) + return ( + x.transpose(1, 2) + .contiguous() + .view(batch_size, -1, head_count * dim_per_head) + ) # 1) Project key, value, and query. if layer_cache is not None: if attn_type == "self": - query, key, value = self.linear_query(query),\ - self.linear_keys(query),\ - self.linear_values(query) + query, key, value = ( + self.linear_query(query), + self.linear_keys(query), + self.linear_values(query), + ) key = shape(key) value = shape(value) if layer_cache["self_keys"] is not None: - key = torch.cat( - (layer_cache["self_keys"], key), - dim=2) + key = torch.cat((layer_cache["self_keys"], key), dim=2) if layer_cache["self_values"] is not None: value = torch.cat( - (layer_cache["self_values"], value), - dim=2) + (layer_cache["self_values"], value), dim=2 + ) layer_cache["self_keys"] = key layer_cache["self_values"] = value elif attn_type == "context": query = self.linear_query(query) if layer_cache["memory_keys"] is None: - key, value = self.linear_keys(key),\ - self.linear_values(value) + key, value = self.linear_keys(key), self.linear_values( + value + ) key = shape(key) value = shape(value) else: - key, value = layer_cache["memory_keys"],\ - layer_cache["memory_values"] + key, value = ( + layer_cache["memory_keys"], + layer_cache["memory_values"], + ) layer_cache["memory_keys"] = key layer_cache["memory_values"] = value else: @@ -169,14 +180,18 @@ def unshape(x): key_len = key.size(2) # 1 or key_len x key_len relative_positions_matrix = generate_relative_positions_matrix( - key_len, self.max_relative_positions, - cache=True if layer_cache is not None else False) + key_len, + self.max_relative_positions, + cache=True if layer_cache is not None else False, + ) # 1 or key_len x key_len x dim_per_head relations_keys = self.relative_positions_embeddings( - relative_positions_matrix.to(key.device)) + relative_positions_matrix.to(key.device) + ) # 1 or key_len x key_len x dim_per_head relations_values = self.relative_positions_embeddings( - relative_positions_matrix.to(key.device)) + relative_positions_matrix.to(key.device) + ) query = shape(query) @@ -201,14 +216,174 @@ def unshape(x): # 3) Apply attention dropout and compute context vectors. 
attn = self.softmax(scores).to(query.dtype) drop_attn = self.dropout(attn) + context_original = torch.matmul(drop_attn, value) + + if self.max_relative_positions > 0 and attn_type == "self": + context = unshape( + context_original + + relative_matmul(drop_attn, relations_values, False) + ) + else: + context = unshape(context_original) + + output = self.final_linear(context) + # CHECK + # batch_, q_len_, d_ = output.size() + # aeq(q_len, q_len_) + # aeq(batch, batch_) + # aeq(d, d_) + + # Return multi-head attn + attns = attn.view(batch_size, head_count, query_len, key_len) + + return output, attns + + def update_dropout(self, dropout): + self.dropout.p = dropout + +class MultiHeadedPseudoSelfAttention(nn.Module): + def __init__( + self, head_count, model_dim, dropout=0.1, max_relative_positions=0 + ): + assert model_dim % head_count == 0 + self.dim_per_head = model_dim // head_count + self.model_dim = model_dim + + super(MultiHeadedPseudoSelfAttention, self).__init__() + self.head_count = head_count + + self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head) + self.linear_values = nn.Linear( + model_dim, head_count * self.dim_per_head + ) + self.linear_keys_src = nn.Linear( + model_dim, head_count * self.dim_per_head + ) + self.linear_values_src = nn.Linear( + model_dim, head_count * self.dim_per_head + ) + self.linear_query = nn.Linear( + model_dim, head_count * self.dim_per_head + ) + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self.final_linear = nn.Linear(model_dim, model_dim) + + self.max_relative_positions = max_relative_positions + + if max_relative_positions > 0: + vocab_size = max_relative_positions * 2 + 1 + self.relative_positions_embeddings = nn.Embedding( + vocab_size, self.dim_per_head + ) + + def forward(self, src, tgt, mask=None, layer_cache=None, attn_type=None): + batch_size = tgt.size(0) + dim_per_head = self.dim_per_head + head_count = self.head_count + key_len = tgt.size(1) + query_len = tgt.size(1) + + def shape(x): + """Projection.""" + return x.view(batch_size, -1, head_count, dim_per_head).transpose( + 1, 2 + ) + + def unshape(x): + """Compute context.""" + return ( + x.transpose(1, 2) + .contiguous() + .view(batch_size, -1, head_count * dim_per_head) + ) + + if layer_cache is not None: + query, self_key, self_value = ( + self.linear_query(tgt), + self.linear_keys(tgt), + self.linear_values(tgt), + ) + self_key = shape(self_key) + self_value = shape(self_value) + if layer_cache["self_keys"] is not None: + self_key = torch.cat( + (layer_cache["self_keys"], self_key), dim=2 + ) + if layer_cache["self_values"] is not None: + self_value = torch.cat( + (layer_cache["self_values"], self_value), dim=2 + ) + if layer_cache["src_keys"] is None: + layer_cache["src_keys"] = shape(self.linear_keys_src(src)) + layer_cache["src_values"] = shape(self.linear_values_src(src)) + layer_cache["self_keys"] = self_key + layer_cache["self_values"] = self_value + key = torch.cat( + (layer_cache["src_keys"], layer_cache["self_keys"]), dim=2 + ) + value = torch.cat( + (layer_cache["src_values"], layer_cache["self_values"]), dim=2 + ) + else: + key = torch.cat( + (self.linear_keys_src(src), self.linear_keys(tgt)), dim=1 + ) + value = torch.cat( + (self.linear_values_src(src), self.linear_values(tgt)), dim=1 + ) + query = self.linear_query(tgt) + key = shape(key) + value = shape(value) + + if self.max_relative_positions > 0 and attn_type == "self": + key_len = key.size(2) + # 1 or key_len x key_len + relative_positions_matrix = 
generate_relative_positions_matrix( + key_len, + self.max_relative_positions, + cache=True if layer_cache is not None else False, + ) + # 1 or key_len x key_len x dim_per_head + relations_keys = self.relative_positions_embeddings( + relative_positions_matrix.to(key.device) + ) + # 1 or key_len x key_len x dim_per_head + relations_values = self.relative_positions_embeddings( + relative_positions_matrix.to(key.device) + ) + + query = shape(query) + + key_len = key.size(2) + query_len = query.size(2) + + # 2) Calculate and scale scores. + query = query / math.sqrt(dim_per_head) + # batch x num_heads x query_len x key_len + query_key = torch.matmul(query, key.transpose(2, 3)) + + if self.max_relative_positions > 0 and attn_type == "self": + scores = query_key + relative_matmul(query, relations_keys, True) + else: + scores = query_key + scores = scores.float() + + if mask is not None: + mask = mask.unsqueeze(1) # [B, 1, 1, T_values] + scores = scores.masked_fill(mask, -1e18) + + # 3) Apply attention dropout and compute context vectors. + attn = self.softmax(scores).to(query.dtype) + drop_attn = self.dropout(attn) context_original = torch.matmul(drop_attn, value) if self.max_relative_positions > 0 and attn_type == "self": - context = unshape(context_original - + relative_matmul(drop_attn, - relations_values, - False)) + context = unshape( + context_original + + relative_matmul(drop_attn, relations_values, False) + ) else: context = unshape(context_original) @@ -220,9 +395,7 @@ def unshape(x): # aeq(d, d_) # Return multi-head attn - attns = attn \ - .view(batch_size, head_count, - query_len, key_len) + attns = attn.view(batch_size, head_count, query_len, key_len) return output, attns diff --git a/onmt/opts.py b/onmt/opts.py index 6872e351a4..fd09ac4647 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -271,7 +271,8 @@ def model_opts(parser): "are experimental. Options are " "[rnn|brnn|ggnn|mean|transformer|cnn|transformer_lm].") group.add('--decoder_type', '-decoder_type', type=str, default='rnn', - choices=['rnn', 'transformer', 'cnn', 'transformer_lm'], + choices=['rnn', 'transformer', 'cnn', 'transformer_lm', + 'transformer_lm_psa'], help="Type of decoder layer to use. Non-RNN layers " "are experimental. Options are " "[rnn|transformer|cnn|transformer].") diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh index b282cc7f1e..83a65fd8e3 100755 --- a/onmt/tests/pull_request_chk.sh +++ b/onmt/tests/pull_request_chk.sh @@ -117,6 +117,20 @@ ${PYTHON} onmt/bin/train.py \ [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} +echo -n " [+] Testing NMT training w/ pseudo self attention..." +${PYTHON} onmt/bin/train.py \ + -config ${DATA_DIR}/align_data.yaml \ + -src_vocab $TMP_OUT_DIR/onmt.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \ + -src_vocab_size 1000 \ + -tgt_vocab_size 1000 \ + -max_generator_batches 0 \ + -encoder_type transformer -decoder_type transformer_lm_psa \ + -layers 4 -word_vec_size 16 -rnn_size 16 -heads 2 -transformer_ff 64 \ + -report_every 5 -train_steps 10 >> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} + echo -n " [+] Testing NMT training w/ align..." 
${PYTHON} onmt/bin/train.py \ -config ${DATA_DIR}/align_data.yaml \ diff --git a/onmt/tests/test_base_transformer.py b/onmt/tests/test_base_transformer.py new file mode 100644 index 0000000000..0c0ca87f73 --- /dev/null +++ b/onmt/tests/test_base_transformer.py @@ -0,0 +1,83 @@ +""" +Here come the tests for attention types and their compatibility +""" +import unittest +import torch + +from onmt.decoders.transformer import TransformerDecoder +from onmt.modules import Embeddings +from onmt.modules.position_ffn import ActivationFunction + + +class TestTransformerDecoder(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(42) + emb = Embeddings( + word_vec_size=100, + position_encoding=True, + feat_merge="concat", + feat_vec_exponent=0.7, + feat_vec_size=-1, + dropout=0, + word_padding_idx=1, + feat_padding_idx=[], + word_vocab_size=100, + feat_vocab_sizes=[], + sparse=False, + freeze_word_vecs=False, + ) + cls.transformer_decoder = TransformerDecoder( + num_layers=2, + d_model=100, + heads=2, + d_ff=100, + copy_attn=False, + self_attn_type="scaled-dot", + dropout=0, + attention_dropout=0, + embeddings=emb, + max_relative_positions=0, + aan_useffn=False, + full_context_alignment=None, + alignment_layer=None, + alignment_heads=None, + pos_ffn_activation_fn=ActivationFunction.relu, + ) + cls.memory_bank = torch.rand([58, 2, 100]) + cls.tgt = torch.randint(3, 99, [12, 2, 1]) + cls.src = torch.randint(3, 99, [58, 2, 1]) + cls.memory_lengths = torch.tensor([58, 58]) + cls.transformer_decoder.init_state( + cls.src, cls.memory_bank, cls.memory_bank + ) + + def test_transformer_caching_equals_no_caching( + self, + ): + dec_outs, _ = self.transformer_decoder( + self.tgt[1:3], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=None, + ) + dec_outs_step_0, _ = self.transformer_decoder( + self.tgt[1:2], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=0, + ) + dec_outs_step_1, _ = self.transformer_decoder( + self.tgt[2:3], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=1, + ) + # randomness might cause failing (seed is set to avoid that) + # small differences are expected due to masking with huge negative + # float but not infinite + self.assertTrue(dec_outs_step_1.allclose(dec_outs[1:])) + + +if __name__ == "__main__": + unittest.main() diff --git a/onmt/tests/test_lm_transformer_decoder.py b/onmt/tests/test_lm_transformer_decoder.py new file mode 100644 index 0000000000..a8794006cc --- /dev/null +++ b/onmt/tests/test_lm_transformer_decoder.py @@ -0,0 +1,70 @@ +""" +Here come the tests for attention types and their compatibility +""" +import unittest +import torch + +from onmt.decoders.transformer import TransformerLMDecoder +from onmt.modules import Embeddings +from onmt.modules.position_ffn import ActivationFunction + + +class TestLMTransformerDecoder(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(42) + emb = Embeddings( + word_vec_size=100, + position_encoding=True, + feat_merge="concat", + feat_vec_exponent=0.7, + feat_vec_size=-1, + dropout=0, + word_padding_idx=1, + feat_padding_idx=[], + word_vocab_size=100, + feat_vocab_sizes=[], + sparse=False, + freeze_word_vecs=False, + ) + cls.lm_transformer_decoder = TransformerLMDecoder( + num_layers=2, + d_model=100, + heads=2, + d_ff=100, + copy_attn=False, + self_attn_type="scaled-dot", + dropout=0, + attention_dropout=0, + embeddings=emb, + max_relative_positions=0, + aan_useffn=False, + full_context_alignment=None, + alignment_layer=None, + 
alignment_heads=None, + pos_ffn_activation_fn=ActivationFunction.relu, + ) + cls.tgt = torch.randint(3, 99, [12, 3, 1]) + cls.lm_transformer_decoder.init_state(None, None, None) + + def test_lm_transformer_caching_equals_no_caching( + self, + ): + dec_outs, _ = self.lm_transformer_decoder( + self.tgt[1:3], None, memory_lengths=None, step=None + ) + dec_outs_step_0, _ = self.lm_transformer_decoder( + self.tgt[1:2], None, memory_lengths=None, step=0 + ) + dec_outs_step_1, _ = self.lm_transformer_decoder( + self.tgt[2:3], None, memory_lengths=None, step=1 + ) + + # randomness might cause failing (seed is set to avoid that) + # small differences are expected due to masking with huge negative + # float but not infinite + self.assertTrue(dec_outs_step_1.allclose(dec_outs[1:])) + + +if __name__ == "__main__": + unittest.main() diff --git a/onmt/tests/test_psa_transformer_decoder.py b/onmt/tests/test_psa_transformer_decoder.py new file mode 100644 index 0000000000..10c093936a --- /dev/null +++ b/onmt/tests/test_psa_transformer_decoder.py @@ -0,0 +1,87 @@ +""" +Here come the tests for attention types and their compatibility +""" +import unittest +import torch + +from onmt.decoders.transformer import TransformerLMPseudoSelfAttentionDecoder +from onmt.modules import Embeddings +from onmt.modules.position_ffn import ActivationFunction + + +class TestPSADecoder(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(42) + emb = Embeddings( + word_vec_size=100, + position_encoding=True, + feat_merge="concat", + feat_vec_exponent=0.7, + feat_vec_size=-1, + dropout=0, + word_padding_idx=1, + feat_padding_idx=[], + word_vocab_size=100, + feat_vocab_sizes=[], + sparse=False, + freeze_word_vecs=False, + ) + cls.psa_transformer_decoder = TransformerLMPseudoSelfAttentionDecoder( + num_layers=2, + d_model=100, + heads=2, + d_ff=100, + copy_attn=False, + self_attn_type="scaled-dot", + dropout=0, + attention_dropout=0, + embeddings=emb, + max_relative_positions=0, + aan_useffn=False, + full_context_alignment=None, + alignment_layer=None, + alignment_heads=None, + pos_ffn_activation_fn=ActivationFunction.relu, + ) + batch_size = 3 + src_len = 58 + tgt_len = 12 + cls.memory_bank = torch.rand([src_len, batch_size, 100]) + cls.tgt = torch.randint(3, 99, [tgt_len, batch_size, 1]) + cls.src = torch.randint(3, 99, [src_len, batch_size, 1]) + cls.memory_lengths = torch.tensor([src_len] * batch_size) + cls.memory_lengths[0] -= 3 + cls.psa_transformer_decoder.init_state( + cls.src, cls.memory_bank, cls.memory_bank + ) + + def test_psa_transformer_caching_equals_no_caching( + self, + ): + dec_outs, _ = self.psa_transformer_decoder( + self.tgt[1:3], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=None, + ) + dec_outs_step_0, _ = self.psa_transformer_decoder( + self.tgt[1:2], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=0, + ) + dec_outs_step_1, _ = self.psa_transformer_decoder( + self.tgt[2:3], + self.memory_bank, + memory_lengths=self.memory_lengths, + step=1, + ) + # randomness might cause failing (seed is set to avoid that) + # small differences are expected due to masking with huge negative + # float but not infinite + self.assertTrue(dec_outs_step_1.allclose(dec_outs[1:])) + + +if __name__ == "__main__": + unittest.main() diff --git a/onmt/tests/test_pseudo_self_attention.py b/onmt/tests/test_pseudo_self_attention.py new file mode 100644 index 0000000000..0aab256119 --- /dev/null +++ b/onmt/tests/test_pseudo_self_attention.py @@ -0,0 +1,125 @@ +""" +Here come 
the tests for attention types and their compatibility +""" +import unittest +import torch + +from onmt.modules import ( + MultiHeadedAttention, + MultiHeadedPseudoSelfAttention, +) +from onmt.utils.misc import sequence_mask +from onmt.decoders.transformer import TransformerDecoderLayerBase + + +class TestMultiHeadedPseudoSelfAttention(unittest.TestCase): + @classmethod + def setUpClass(cls): + max_relative_positions = 0 + heads = 2 + cls.d_model = 16 + cls.pseudo_self_attention = MultiHeadedPseudoSelfAttention( + heads, + cls.d_model, + dropout=0, + max_relative_positions=max_relative_positions, + ) + cls.self_attention = MultiHeadedAttention( + heads, + cls.d_model, + dropout=0, + max_relative_positions=max_relative_positions, + ) + torch.nn.init.constant_( + cls.pseudo_self_attention.linear_keys.weight, 1 + ) + torch.nn.init.constant_( + cls.pseudo_self_attention.linear_values.weight, 1 + ) + torch.nn.init.constant_( + cls.pseudo_self_attention.linear_query.weight, 1 + ) + torch.nn.init.constant_(cls.self_attention.linear_keys.weight, 1) + torch.nn.init.constant_(cls.self_attention.linear_values.weight, 1) + torch.nn.init.constant_(cls.self_attention.linear_query.weight, 1) + + torch.nn.init.constant_(cls.pseudo_self_attention.linear_keys.bias, 0) + torch.nn.init.constant_( + cls.pseudo_self_attention.linear_values.bias, 0 + ) + torch.nn.init.constant_(cls.pseudo_self_attention.linear_query.bias, 0) + torch.nn.init.constant_(cls.self_attention.linear_keys.bias, 0) + torch.nn.init.constant_(cls.self_attention.linear_values.bias, 0) + torch.nn.init.constant_(cls.self_attention.linear_query.bias, 0) + + torch.nn.init.constant_( + cls.pseudo_self_attention.final_linear.weight, 1 + ) + torch.nn.init.constant_(cls.pseudo_self_attention.final_linear.bias, 1) + torch.nn.init.constant_(cls.self_attention.final_linear.weight, 1) + torch.nn.init.constant_(cls.self_attention.final_linear.bias, 1) + + def test_pseudo_self_attention_equals_self_attention_without_encoding( + self, + ): + X = torch.zeros( + (3, 5, self.d_model) + ) # (batch_size, seq_len, dim_model) + Y = torch.ones((3, 8, self.d_model)) + + output_self_attn, _ = self.self_attention(Y, Y, Y, attn_type="self") + output_pseudo_self_attn, _ = self.pseudo_self_attention(X, Y) + self.assertTrue(output_self_attn.equal(output_pseudo_self_attn)) + + def test_masked_pseudo_self_attention_equals_premasked_encoder(self): + X = 0.3 * torch.ones( + (4, 5, self.d_model) + ) # (batch_size, seq_len, dim_model) + X[0, 4:, :] = 1000 + X[1, 3:, :] = 1000 + + X_premasked = 0.3 * torch.ones( + (4, 5, self.d_model) + ) # (batch_size, seq_len, dim_model) + X_premasked[0, 4:, :] = 0 + X_premasked[1, 3:, :] = 0 + + Y = torch.ones((4, 8, self.d_model)) + + src_pad_mask = ~sequence_mask(torch.tensor([4, 3, 1, 5]), 5).unsqueeze( + 1 + ) + no_mask_src_pad_mask = ~sequence_mask( + torch.tensor([5, 5, 5, 5]), 5 + ).unsqueeze(1) + tgt_pad_mask = ~sequence_mask(torch.tensor([8, 3, 8, 1]), 8).unsqueeze( + 1 + ) + + dec_mask = TransformerDecoderLayerBase._compute_dec_mask( + tgt_pad_mask, future=False + ) + + pseudo_mask = torch.cat( + [src_pad_mask.repeat(1, dec_mask.size(-1), 1), dec_mask], axis=-1 + ) + no_mask_pseudo_mask = torch.cat( + [no_mask_src_pad_mask.repeat(1, dec_mask.size(-1), 1), dec_mask], + axis=-1, + ) + + output, _ = self.pseudo_self_attention( + X, + Y, + mask=pseudo_mask, + attn_type="self", + ) + + output_masked, _ = self.pseudo_self_attention( + X_premasked, + Y, + mask=no_mask_pseudo_mask, + attn_type="self", + ) + + 
self.assertTrue(output.equal(output_masked))
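

For reference, the core mechanism introduced in onmt/modules/multi_headed_attn.py above is that the source states get their own key/value projections and are concatenated in front of the target keys/values, so a single softmax attends jointly over source and target. A minimal single-head sketch of that idea (toy dimensions, hypothetical class name, no multi-head split, no mask, no cache):

import math
import torch
import torch.nn as nn

class TinyPseudoSelfAttention(nn.Module):
    """Illustrative single-head pseudo self-attention (not part of the patch)."""
    def __init__(self, dim):
        super().__init__()
        self.k_tgt, self.v_tgt = nn.Linear(dim, dim), nn.Linear(dim, dim)
        # separate projections for the source side, as in MultiHeadedPseudoSelfAttention
        self.k_src, self.v_src = nn.Linear(dim, dim), nn.Linear(dim, dim)
        self.q, self.out = nn.Linear(dim, dim), nn.Linear(dim, dim)

    def forward(self, src, tgt):
        # src: (batch, src_len, dim), tgt: (batch, tgt_len, dim)
        key = torch.cat((self.k_src(src), self.k_tgt(tgt)), dim=1)
        value = torch.cat((self.v_src(src), self.v_tgt(tgt)), dim=1)
        query = self.q(tgt) / math.sqrt(tgt.size(-1))
        scores = query @ key.transpose(1, 2)  # (batch, tgt_len, src_len + tgt_len)
        return self.out(scores.softmax(dim=-1) @ value)  # (batch, tgt_len, dim)

# e.g. TinyPseudoSelfAttention(16)(torch.rand(2, 5, 16), torch.rand(2, 7, 16))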
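
The mask fed to that attention in TransformerLMPseudoSelfAttentionDecoderLayer._forward is the source padding mask, repeated per target row, concatenated with the usual causal target mask, so each target position sees every non-padded source token plus its own target prefix. A small sketch of that construction, assuming boolean masks where True means "blocked" (same convention as sequence_mask and _compute_dec_mask above); the helper name is illustrative:

import torch

def build_pseudo_mask(src_pad_mask, tgt_pad_mask):
    # src_pad_mask: (batch, 1, src_len), True on padded source positions
    # tgt_pad_mask: (batch, 1, tgt_len), True on padded target positions
    tgt_len = tgt_pad_mask.size(-1)
    future = torch.triu(
        torch.ones((tgt_len, tgt_len), dtype=torch.uint8,
                   device=tgt_pad_mask.device),
        diagonal=1,
    ).bool()                                       # block target positions > t
    dec_mask = tgt_pad_mask | future.unsqueeze(0)  # (batch, tgt_len, tgt_len)
    # each target row may additionally attend to the whole unpadded source
    return torch.cat([src_pad_mask.repeat(1, tgt_len, 1), dec_mask], dim=-1)

# result shape: (batch, tgt_len, src_len + tgt_len)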
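
During stepwise decoding the per-layer cache projects the source keys/values once at step 0 and afterwards only appends the new target key/value, which is what the caching-vs-no-caching tests verify. A condensed restatement of that cache update (the function name, the shape argument, and the projection arguments are illustrative; the dict layout matches _init_cache above):

import torch

def cached_keys_values(cache, src, tgt_step, k_src, v_src, k_tgt, v_tgt, shape):
    # cached tensors are (batch, heads, cached_len, dim_per_head)
    if cache["src_keys"] is None:                  # step 0: project the source once
        cache["src_keys"] = shape(k_src(src))
        cache["src_values"] = shape(v_src(src))
    new_key, new_value = shape(k_tgt(tgt_step)), shape(v_tgt(tgt_step))
    if cache["self_keys"] is not None:             # later steps: extend the target part
        new_key = torch.cat((cache["self_keys"], new_key), dim=2)
        new_value = torch.cat((cache["self_values"], new_value), dim=2)
    cache["self_keys"], cache["self_values"] = new_key, new_value
    key = torch.cat((cache["src_keys"], cache["self_keys"]), dim=2)
    value = torch.cat((cache["src_values"], cache["self_values"]), dim=2)
    return key, value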