
How Do Position Encodings Affect Length Generalization? Case Studies on In-Context Function Learning

This repository contains the code for the AAAI 2025 paper: How Do Position Encodings Affect Length Generalization? Case Studies on In-Context Function Learning

Our codebase is based on the following repositories:

Repo structure:

┌── requirements.txt
├── configs
│   ├── base_model/     # Backbone models of different sizes
│   ├── hyperparams/    # Settings for curriculum learning and training
│   ├── logging/        # Wandb logging settings
│   ├── patch/          # Settings for selecting the position encoding
│   ├── tasks/          # Per-task settings
│   └── toy_task.yaml   # For testing the environment
│
├── models              # Will be generated after training
│   └── task_name
│       └── input_dim
│           └── run_date
│               └── uid
│
├── eval_atk.py
├── eval_atk.sh         # Refer to README.md §3.3
├── eval_attn.py
├── eval_attn.sh        # Refer to README.md §3.2
├── eval.sh             # Refer to README.md §3.1
├── train.py
├── train.sh            # Refer to README.md §2
├── .gitignore
└── README.md

0. Environment

Hardware and Driver

  • Ubuntu 22.04
  • RTX 3060 12GB
  • CUDA 12.1

Dependency

  • Python 3.11.*
  • PyTorch 2.3.1 (built with CUDA 12.1)
  • We generated requirements.txt with pipreqs. Please create a fresh environment and install all packages from it, as shown below.
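
For example, with a virtual environment (the environment name icl-env is only a placeholder):

python3.11 -m venv icl-env
source icl-env/bin/activate
pip install -r requirements.txt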

1. Data Generation

SOURCE CODE: src/samplers.py

Since our method needs fresh data at every training step, there is no fixed dataset; data are generated on the fly during training.

See train.py for how data generation is invoked:

# Generate a fresh batch: xs are sampled inputs, ys their labels
xs = task.sample_xs(
    self.curriculum.n_points,          # in-context examples per prompt
    self.batch_size,                   # prompts per batch
    self.curriculum.n_dims_truncated,  # input dims active under the curriculum
)
ys = task.evaluate(xs)                 # labels from the task's target function
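
For intuition, here is a minimal, self-contained sketch of what a sampler/task pair with this interface might look like. The names mirror the calls above, but the real implementation lives in src/samplers.py and differs in detail (e.g., it can draw a fresh weight vector per prompt):

import torch

class GaussianLinearRegression:
    """Toy stand-in for the gaussian sampler + linear_regression task."""

    def __init__(self, n_dims: int = 20):
        self.n_dims = n_dims
        self.w = torch.randn(n_dims, 1)  # one ground-truth weight vector (toy)

    def sample_xs(self, n_points, batch_size, n_dims_truncated):
        xs = torch.randn(batch_size, n_points, self.n_dims)
        xs[:, :, n_dims_truncated:] = 0.0  # curriculum: mask inactive dims
        return xs

    def evaluate(self, xs):
        return (xs @ self.w).squeeze(-1)  # y = x·w, shape (batch, n_points)

task = GaussianLinearRegression()
xs = task.sample_xs(n_points=50, batch_size=8, n_dims_truncated=10)
ys = task.evaluate(xs)
print(xs.shape, ys.shape)  # torch.Size([8, 50, 20]) torch.Size([8, 50])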

2. Training

Edit Config Files

All of our config files are under the configs/ folder. Here is an example configuration for training linear regression with ALiBi.

model:
  !include ../../base_model/llama/standard.yaml

patch:
  !include ../../patch/alibi/alibi.yaml

task:
  data: gaussian
  name: linear_regression
  curriculum: 
    !include ../../hyperparams/curriculum/dim-20_pts-50.yaml

training:
  !include ../../hyperparams/training_args.yaml

wandb:
  !include ../../logging/wandb.yaml

Details of each field:

  • model: chooses the backbone model. We only report results for Llama in the paper.
  • patch: chooses the position encoding.
  • task: the task and its data distribution; please refer to the other config files for different tasks.
  • training: our training hyperparameters.
  • wandb: logging configuration. The default mode is offline, so you can leave this as-is.
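
Note that !include is not standard YAML, and PyYAML has no built-in support for it, so the configs have to be parsed with a custom constructor. A minimal sketch of such a loader (illustrative only, not the repo's actual code):

import os
import yaml

def build_include_loader(base_dir):
    """Return a PyYAML SafeLoader subclass that resolves `!include path` tags."""

    class IncludeLoader(yaml.SafeLoader):
        pass

    def include(loader, node):
        rel_path = loader.construct_scalar(node)
        path = os.path.normpath(os.path.join(base_dir, rel_path))
        with open(path) as f:
            # Recurse so included files may themselves use !include.
            return yaml.load(f, build_include_loader(os.path.dirname(path)))

    IncludeLoader.add_constructor("!include", include)
    return IncludeLoader

# Hypothetical config path; !include paths resolve relative to the config file.
cfg_path = "configs/tasks/linear_regression/alibi.yaml"
with open(cfg_path) as f:
    config = yaml.load(f, build_include_loader(os.path.dirname(cfg_path)))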

For the default settings, simply run the training script:

bash train.sh

Our pipeline will save the checkpoint (state.pt) and config file (config.yaml) under models/$task_name$/$input_dim$/$run_date$/$uid$/
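
To inspect a finished run you can load those files directly; a sketch, assuming state.pt is an ordinary torch.save payload (its exact contents depend on the training code):

import torch
import yaml

run_dir = "models/linear_regression/20/2025-01-01/abc123"  # hypothetical run directory

with open(f"{run_dir}/config.yaml") as f:
    # If the saved config still contains !include tags, use the loader above.
    config = yaml.safe_load(f)

state = torch.load(f"{run_dir}/state.pt", map_location="cpu")
print(config["task"], type(state))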

If you want to modify train.sh for other tasks, here are all the options.

run_task=("linear_regression" "sparse_linear_regression" "cnf" "conjunction" "disjunction" "dnf" "int_halfspace" "majority" "parity" "sparse_disjunction" "sparse_parity" "sparse_thres")
run_pe=("alibi" "dynamic_yarn" "fire" "nope" "rope" "yarn" "mamba")
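
For illustration, a hypothetical sweep over every task/position-encoding pair could look like the following; the echo is a stand-in for whatever command train.sh actually launches per configuration:

for task in "${run_task[@]}"; do
  for pe in "${run_pe[@]}"; do
    echo "task=${task} pe=${pe}"  # replace with the real training command
  done
done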

3. Evaluation and Analysis

3.1 Main Experiments from the Paper

Normally, the training pipeline will also evaluate the model. If you want to re-evaluate a specific model, please modify the eval.sh script and run:

bash eval.sh

The output files will be stored at models/$task_name$/$input_dim$/$run_date$/$uid$/

Example: ood_length.png


3.2 Discussion

Does Increasing Model Size Enhance Performance?

Please change the model field in the config file to small or large and run the training script.

model:
  !include ../../base_model/llama/$model_size$.yaml # small.yaml or large.yaml

Can recency bias explain why Transformers fail?

Please make sure that the eval_date field in the bash script is correct and run:

bash eval_attn.sh

The output files will be stored at models/$task_name$/$input_dim$/$run_date$/$uid$/duplicated

Example: avg_attention_score.png

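Conceptually, recency bias shows up when later context positions receive disproportionately high attention. A toy illustration of averaging attention over batches, heads, and query positions (not the repo's eval_attn.py, whose procedure may differ; the toy also ignores causal masking):

import torch

# Toy attention weights: (batch, heads, query_len, key_len), rows sum to 1.
attn = torch.softmax(torch.randn(2, 4, 50, 50), dim=-1)

# Average attention each key position receives.
avg_per_position = attn.mean(dim=(0, 1, 2))  # shape: (key_len,)

# Under recency bias, the later entries would dominate.
print(avg_per_position[:5], avg_per_position[-5:])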

Does inductive bias exist in In-Context Function Learning?

Please make sure that the eval_date field in the bash script is correct and run:

bash eval_atk.sh

The output files will be stored at models/$task_name$/$input_dim$/$run_date$/$uid$/duplicated

Example: baseline_all_dup.png


Does serial-position effect exist in In-Context Function Learning?

Please make sure that the eval_date field in the bash script is correct and run:

bash eval_atk.sh

The output files will be stored at models/$task_name$/$input_dim$/$run_date$/$uid$/duplicated

Example: lost_in_mid.png


Will State Space Models Generalize under This Noise?

Our analysis experiments are based on the Mamba architecture. Make sure the model field in every config file points to Mamba, set run_pe accordingly, and run the training script.

model:
  !include ../../base_model/mamba/standard.yaml

and in train.sh:

run_pe=("mamba")

Rerun all the evaluation processes described above, except for the attention-related analysis (Mamba has no attention maps).

Citation
