|
| 1 | +# Multimodal Model-Agnostic Meta-Learning for Few-shot Classification |
| 2 | + |
| 3 | +This project is an implementation of [Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation](https://arxiv.org/abs/1910.13616), which is published in [NeurIPS 2019](https://neurips.cc/Conferences/2019/). |
| 4 | + |
| 5 | +Model-agnostic meta-learners aim to acquire meta-prior parameters from a distribution of tasks and adapt to novel tasks with few gradient updates. Yet, seeking a common initialization shared across the entire task distribution substantially limits the diversity of the task distributions that they are able to learn from. We propose a multimodal MAML (MMAML) framework, which is able to modulate its meta-learned prior according to the identified mode, allowing more efficient fast adaptation. An illustration of the proposed framework is as follows. |
| 6 | + |
| 7 | +<p align="center"> |
| 8 | + <img src="asset/model.png" width="480"/> |
| 9 | +</p> |
| 10 | + |
| 11 | +## Datasets |
| 12 | + |
| 13 | +Run the following command to download and preprocess the datasets |
| 14 | + |
| 15 | +```bash |
| 16 | +python download.py --dataset aircraft bird cifar miniimagenet |
| 17 | +``` |
| 18 | + |
| 19 | +## Getting started |
| 20 | +To avoid any conflict with your existing Python setup, and to keep this project self-contained, it is suggested to work in a virtual environment with [`virtualenv`](http://docs.python-guide.org/en/latest/dev/virtualenvs/). To install `virtualenv`: |
| 21 | +``` |
| 22 | +pip install --upgrade virtualenv |
| 23 | +``` |
| 24 | +Create a virtual environment, activate it and install the requirements in [`requirements.txt`](requirements.txt). |
| 25 | +``` |
| 26 | +virtualenv mmaml_venv |
| 27 | +source mmaml_venv/bin/activate |
| 28 | +pip install -r requirements.txt |
| 29 | +``` |
| 30 | + |
| 31 | +## Training commands |
| 32 | + |
| 33 | +### Train |
| 34 | + |
| 35 | +```bash |
| 36 | +$ python main.py -dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s |
| 37 | +``` |
| 38 | +- Selected arguments (see the `trainer.py` for more details) |
| 39 | + - --output-folder: a nickname for the training |
| 40 | + - --dataset: choose among `omniglot`, `miniimagenet`, `cifar`, `bird` (CUB), and `aircraft`. You can also add your own datasets. |
| 41 | + - Checkpoints: specify the path to a pre-trained checkpoint |
| 42 | + - --checkpoint: load all the parameters (e.g. `train_dir/mmaml_5mode_5w1s/maml_gatedconv_60000.pt`). |
| 43 | + - Hyperparameters |
| 44 | + - --num-batches: number of batches |
| 45 | + - --meta-batch-size: number of tasks per batch |
| 46 | + - --slow-lr: learning rate for the global update of MAML |
| 47 | + - --fast-lr: learning rate for the adapted models |
| 48 | + - --num-updates: how many update steps in the inner loop |
| 49 | + - --num-classes-per-batch: how many classes per task (`N`-way) |
| 50 | + - --num-samples-per-class: how many samples per class for training (`K`-shot) |
| 51 | + - --num-val-samples: how many samples per class for validation |
| 52 | + - --max\_steps: the max training iterations |
| 53 | + - Logging |
| 54 | + - --log-interval: number of batches between tensorboard writes |
| 55 | + - --save-interval: number of batches between model saves |
| 56 | + - Model |
| 57 | + - maml-model: set to `True` to train a MAML model |
| 58 | + - mmaml-model: set to `True` to train a MMAML (our) model |
| 59 | + |
| 60 | +### 2 Modes (Omniglot and Mini-ImageNet) |
| 61 | + |
| 62 | +| Setup | Method | Command | |
| 63 | +| :-----: | :------: | ---------------------------------------- | |
| 64 | +| 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 600000 --output-folder maml_2mode_5w1s``` | |
| 65 | +| 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 600000 --output-folder mmaml_2mode_5w1s``` | |
| 66 | +| 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_2mode_5w5s``` | |
| 67 | +| 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_2mode_5w5s``` | |
| 68 | +| 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_2mode_20w1s``` | |
| 69 | +| 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_2mode_20w1s``` | |
| 70 | + |
| 71 | +### 3 Modes (Omniglot, Mini-ImageNet, and FC100) |
| 72 | + |
| 73 | +| Setup | Method | Command | |
| 74 | +| :-----: | :------: | ---------------------------------------- | |
| 75 | +| 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 600000 --output-folder maml_3mode_5w1s``` | |
| 76 | +| 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 600000 --output-folder mmaml_3mode_5w1s``` | |
| 77 | +| 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_5mode_5w5s``` | |
| 78 | +| 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_5mode_5w5s``` | |
| 79 | +| 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_3mode_20w1s``` | |
| 80 | +| 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_3mode_20w1s``` | |
| 81 | + |
| 82 | +### 5 Modes (Omniglot, Mini-ImageNet, FC100, Aircraft, and CUB) |
| 83 | + |
| 84 | +| Setup | Method | Command | |
| 85 | +| :-----: | :------: | ---------------------------------------- | |
| 86 | +| 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --maml-model True --num-batches 600000 --output-folder maml_5mode_5w1s``` | |
| 87 | +| 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 600000 --output-folder maml_5mode_5w1s``` | |
| 88 | +| 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s``` | |
| 89 | +| 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s``` | |
| 90 | +| 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_5mode_5w5s``` | |
| 91 | +| 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_5mode_5w5s``` | |
| 92 | +| 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_5mode_5w5s``` | |
| 93 | +| 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_5mode_5w5s``` | |
| 94 | +| 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_5mode_20w1s``` | |
| 95 | +| 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_5mode_20w1s``` | |
| 96 | +| 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_5mode_20w1s``` | |
| 97 | +| 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_5mode_20w1s``` | |
| 98 | + |
| 99 | +### Interpret TensorBoard |
| 100 | +Launch Tensorboard and go to the specified port, you can see differernt losses in the **scalars** tab. |
| 101 | + |
| 102 | + |
| 103 | +### Multi-MAML |
| 104 | + |
| 105 | +| Setup | Dataset | Command | |
| 106 | +| :-----: | :------: | ---------------------------------------- | |
| 107 | +| 5w1s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.4 --num-update 1 --num-batches 600000 --output-folder multi_omniglot_5w1s``` | |
| 108 | +| 5w1s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_miniimagenet_5w1s``` | |
| 109 | +| 5w1s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_cifar_5w1s``` | |
| 110 | +| 5w1s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_bird_5w1s```| |
| 111 | +| 5w1s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_aircraft_5w1s```| |
| 112 | +| 5w5s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.4 --num-update 1 --num-batches 600000 --num-samples-per-class 5 --output-folder multi_omniglot_5w5s``` | |
| 113 | +| 5w5s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_miniimagenet_5w5s``` | |
| 114 | +| 5w5s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_cifar_5w5s``` | |
| 115 | +| 5w5s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_bird_5w5s```| |
| 116 | +| 5w5s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_aircraft_5w5s```| |
| 117 | +| 20w1s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.1 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_omniglot_20w1s``` | |
| 118 | +| 20w1s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_miniimagenet_20w1s``` | |
| 119 | +| 20w1s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_cifar_20w1s``` | |
| 120 | +| 20w1s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_bird_20w1s``` | |
| 121 | +| 20w1s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_aircraft_20w1s``` | |
| 122 | + |
| 123 | +## Results |
| 124 | + |
| 125 | +### 2 Modes (Omniglot and Mini-ImageNet) |
| 126 | + |
| 127 | +| Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | |
| 128 | +| :----------: | :----------: | :----------: | :----------: | |
| 129 | +| MAML | 66.80% | 77.79% | 44.69% | |
| 130 | +| Multi-MAML | 66.85% | 73.07% | 53.15% | |
| 131 | +| MMAML (Ours) | 69.93% | 78.73% | 47.80% | |
| 132 | + |
| 133 | +### 3 Modes (Omniglot, Mini-ImageNet, and FC100) |
| 134 | + |
| 135 | +| Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | |
| 136 | +| :----------: | :----------: | :----------: | :----------: | |
| 137 | +| MAML | 54.55% | 67.97% | 28.22% | |
| 138 | +| Multi-MAML | 55.90% | 62.20% | 39.77% | |
| 139 | +| MMAML (Ours) | 57.47% | 70.15% | 36.27% | |
| 140 | + |
| 141 | +### 5 Modes (Omniglot, Mini-ImageNet, FC100, Aircraft, and CUB) |
| 142 | + |
| 143 | +| Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | |
| 144 | +| :----------: | :----------: | :----------: | :----------: | |
| 145 | +| MAML | 44.09% | 54.41% | 28.85% | |
| 146 | +| Multi-MAML | 45.46% | 55.92% | 33.78% | |
| 147 | +| MMAML (Ours) | 49.06% | 60.83% | 33.97% | |
| 148 | + |
| 149 | +Please check out our paper for more comprehensive results. |
| 150 | + |
| 151 | +## Related work |
| 152 | +- \[MAML\] [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400) in ICML 2017 |
| 153 | + |
| 154 | +## Cite the paper |
| 155 | +If you find this useful, please cite |
| 156 | +``` |
| 157 | +@inproceedings{vuorio2019multimodal, |
| 158 | + title={Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation}, |
| 159 | + author={Vuorio, Risto and Sun, Shao-Hua and Hu, Hexiang and Lim, Joseph J.}, |
| 160 | + booktitle={Neural Information Processing Systems}, |
| 161 | + year={2019}, |
| 162 | +} |
| 163 | +``` |
| 164 | + |
| 165 | +## Author |
| 166 | +[Shao-Hua Sun](http://shaohua0116.github.io/) |
0 commit comments