
Conversation

le1nux (Member) commented on Mar 4, 2024:

The layer norm was originally instantiated individually for every attention block internally.

self.ln_1 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon)

For every new layer norm type, we would have had to add an if-clause to decide which layer norm to instantiate. As a workaround, we now pass the layer norm object into the GPT2 model from outside and copy it in every attention block. Note that we override the copy function in the layer norm implementations.
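
A minimal sketch of this workaround (the AttentionBlock and GPT2Model signatures below are illustrative, not the exact modalities API):

import copy

import torch.nn as nn


class AttentionBlock(nn.Module):
    # The block receives a pre-built layer norm object and keeps its own copy,
    # so the block never needs to know which layer norm type was chosen.
    # copy.copy is assumed to yield an independent instance because the
    # layer norm implementations override __copy__.
    def __init__(self, norm: nn.Module):
        super().__init__()
        self.ln_1 = copy.copy(norm)


class GPT2Model(nn.Module):
    def __init__(self, norm: nn.Module, n_layers: int):
        super().__init__()
        self.blocks = nn.ModuleList(AttentionBlock(norm) for _ in range(n_layers))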

In the future, it would make sense to support instantiating lists of components. For instance, a GPTModel would have a dependency on a list of attention blocks. We would specify a single attention block and instantiate it n times (see num_instances in the YAML below). Each attention block would then have its own layer norm dependency, so the layer norm would no longer have to be copied internally.

This is an example:

model:
  component_key: model
  variant_key: gpt2
  config:
    [...]
    attention_blocks:
      component_key: attention_block
      variant_key: gpt2_attention_block
      num_instances: 12
      config:
        n_embd: 768
        dropout: 0.0
        scaling_factor: 3
        [...]
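
One way the proposed num_instances could be resolved is sketched below; build_components and registry.build are hypothetical helpers, not part of the current registry:

import copy
from typing import Any


def build_components(entry: dict[str, Any], registry) -> list[Any]:
    # Build one component from its component_key/variant_key/config entry
    # and return num_instances independent copies of it.
    num_instances = entry.get("num_instances", 1)
    component = registry.build(
        component_key=entry["component_key"],
        variant_key=entry["variant_key"],
        config=entry["config"],
    )
    return [copy.deepcopy(component) for _ in range(num_instances)]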

@le1nux le1nux added the enhancement New feature or request label Mar 4, 2024
@le1nux le1nux self-assigned this Mar 7, 2024
@le1nux le1nux requested review from mali-git and flxst March 7, 2024 17:27
@le1nux le1nux marked this pull request as ready for review March 7, 2024 17:28
@le1nux le1nux changed the title Draft: Rms norm implementation RMS norm implementation Mar 7, 2024
@@ -31,7 +31,7 @@ train_dataset:
  component_key: dataset
  variant_key: packed_mem_map_dataset_megatron
  config:
-    raw_data_path: /raid/s3/opengptx/max_lue/LLMgym/data/redpyjama_v2_default_DE_num_docs_16777216.pbin
+    raw_data_path: /raid/s3/opengptx/max_lue/modalities/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_1050391.pbin
Member commented:
We should probably use relative paths here (and in other configs, too).



class RMSLayerNorm(LayerNormIF):
    def __init__(self, ndim: int, epsilon: float = 1e-6):
Member commented:
Should we not implement an (optional) bias for RMSLayerNorm, just like we do for ZLayerNorm? The original RMSNorm paper uses a bias by default.

le1nux (Member, Author) replied on Mar 13, 2024:
Good point! I also checked the original RMSNorm implementation, which also has it (see: https://github.com/bzhangGo/rmsnorm/blob/2e726f1a3f106bb719056422f4f9b6aca03d3ce6/rmsnorm_torch.py#L32). I added a bias to this implementation as well.
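
For reference, an RMSNorm forward pass with an optional bias could look roughly like this (a sketch, not necessarily the exact code merged here):

import torch
import torch.nn as nn


class RMSLayerNorm(nn.Module):
    def __init__(self, ndim: int, bias: bool = True, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the root mean square over the last dimension;
        # unlike LayerNorm, no mean is subtracted.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.epsilon)
        x = x / rms * self.weight
        return x + self.bias if self.bias is not None else x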


    Args:
        ndim (int): The dimension of the input tensor.
        epsilon (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Member commented:
Epsilon is 1e-6 in the LLaMa implementation by default. However, it seems that they actually used 1e-5 themselves, see here. 1e-5 is also the default value in PyTorch for LayerNorm and used elsewhere for RMSNorm (e.g. here), so it seems like a standard value that we should perhaps also use instead of 1e-6?

        return copied_instance


class ZLayerNorm(LayerNormIF):
Member commented:
Why is it called ZLayerNorm, i.e. what does the Z stand for? Is this only to differentiate it from the more generic LayerNormIF class?

le1nux (Member, Author) replied:
The layer norm implemented in PyTorch basically computes the z-score of each vector component (plus two additional, learnable affine transformation parameters):
https://en.wikipedia.org/wiki/Standard_score

I found the name LayerNorm too generic, since RMSNorm is also a layer norm. What naming would you suggest?
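
For illustration, a reference version of what PyTorch's LayerNorm computes (not project code) might be:

import torch


def layer_norm_reference(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # z-score per component over the last dimension ...
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    z = (x - mean) / torch.sqrt(var + eps)
    # ... followed by the learnable affine transformation (weight, bias).
    return z * weight + bias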

le1nux (Member, Author) replied on Mar 13, 2024:
Michael and Mehdi suggested using the original LayerNorm name as well. Having been overruled, it's LayerNorm again :-)
Also, we no longer use a custom LayerNorm wrapper. I found a way to simplify that part so that we don't have to override __copy__().
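
The simplification relies on copy.deepcopy recursively copying an nn.Module, parameters included, so no __copy__ override is needed, e.g.:

import copy

import torch.nn as nn

norm = nn.LayerNorm(768)
norm_copy = copy.deepcopy(norm)

# The copy owns its own parameter tensors; training one does not affect the other.
assert norm_copy.weight is not norm.weight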


class RMSLayerNormConfig(BaseModel):
    ndim: Annotated[int, Field(strict=True, ge=1)]
    epsilon: Annotated[float, Field(gt=0, default=1e-6)]
Member commented:
Is there a reason we do not use strict=True here (and above in ZLayerNormConfig)?

le1nux (Member, Author) replied:
fixed.
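
With strict=True applied, the config presumably looks something like this (a sketch; the bias field is an assumption based on the bias discussion above):

from typing import Annotated

from pydantic import BaseModel, Field


class RMSLayerNormConfig(BaseModel):
    ndim: Annotated[int, Field(strict=True, ge=1)]
    bias: Annotated[bool, Field(strict=True, default=True)]
    epsilon: Annotated[float, Field(strict=True, gt=0, default=1e-6)]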

flxst (Member) left a review:
Looks good to me! I added a few comments and questions.

@lllAlexanderlll lllAlexanderlll self-requested a review March 11, 2024 08:48
le1nux added 4 commits March 13, 2024 16:43
…orch implementation without the need for a wrapper. Removed __copy__ overrides, as calling deepcopy on nn.Module already is capable of recursively copying a nn.Module. Introduced bias to RMSLayerNorm
@le1nux le1nux merged commit 4f509cc into main Mar 13, 2024
@le1nux le1nux deleted the rms_norm branch March 13, 2024 17:53
@le1nux le1nux restored the rms_norm branch March 13, 2024 17:54
le1nux added 3 commits that referenced this pull request Mar 13, 2024