Commit 1403ee6

first commit
0 parents  commit 1403ee6

43 files changed

Lines changed: 63953 additions & 0 deletions

.gitignore

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# project specific
2+
data/
3+
saves/
4+
logs/
5+
6+
# Byte-compiled / optimized / DLL files
7+
__pycache__/
8+
*.py[cod]
9+
*$py.class
10+
11+
# C extensions
12+
*.so
13+
14+
# Distribution / packaging
15+
.Python
16+
build/
17+
develop-eggs/
18+
dist/
19+
downloads/
20+
eggs/
21+
.eggs/
22+
lib/
23+
lib64/
24+
parts/
25+
sdist/
26+
var/
27+
wheels/
28+
*.egg-info/
29+
.installed.cfg
30+
*.egg
31+
MANIFEST
32+
33+
# PyInstaller
34+
# Usually these files are written by a python script from a template
35+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
36+
*.manifest
37+
*.spec
38+
39+
# Installer logs
40+
pip-log.txt
41+
pip-delete-this-directory.txt
42+
43+
# Unit test / coverage reports
44+
htmlcov/
45+
.tox/
46+
.coverage
47+
.coverage.*
48+
.cache
49+
nosetests.xml
50+
coverage.xml
51+
*.cover
52+
.hypothesis/
53+
.pytest_cache/
54+
55+
# Translations
56+
*.mo
57+
*.pot
58+
59+
# Django stuff:
60+
*.log
61+
local_settings.py
62+
db.sqlite3
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
target/
76+
77+
# Jupyter Notebook
78+
.ipynb_checkpoints
79+
80+
# pyenv
81+
.python-version
82+
83+
# celery beat schedule file
84+
celerybeat-schedule
85+
86+
# SageMath parsed files
87+
*.sage.py
88+
89+
# Environments
90+
.env
91+
.venv
92+
env/
93+
venv/
94+
ENV/
95+
env.bak/
96+
venv.bak/
97+
98+
# Spyder project settings
99+
.spyderproject
100+
.spyproject
101+
102+
# Rope project settings
103+
.ropeproject
104+
105+
# mkdocs documentation
106+
/site
107+
108+
# mypy
109+
.mypy_cache/
110+
111+
*.txt
112+
*.sw[opn]
113+
*.pt
114+
*.hdf5
115+
train_dir/
116+
*tgz
117+
*tar.gz

LICENSE

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
MIT License
2+
3+
Copyright (c) 2018 Tristan Deleu
4+
Copyright (c) 2018 Risto Vuorio
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a copy
7+
of this software and associated documentation files (the "Software"), to deal
8+
in the Software without restriction, including without limitation the rights
9+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
copies of the Software, and to permit persons to whom the Software is
11+
furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in all
14+
copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
SOFTWARE.

README.md

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Multimodal Model-Agnostic Meta-Learning for Few-shot Classification
2+
3+
This project is an implementation of [Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation](https://arxiv.org/abs/1910.13616), published at [NeurIPS 2019](https://neurips.cc/Conferences/2019/).
4+
5+
Model-agnostic meta-learners aim to acquire meta-prior parameters from a distribution of tasks and adapt to novel tasks with few gradient updates. Yet, seeking a common initialization shared across the entire task distribution substantially limits the diversity of the task distributions that they are able to learn from. We propose a multimodal MAML (MMAML) framework, which is able to modulate its meta-learned prior according to the identified mode, allowing more efficient fast adaptation. An illustration of the proposed framework is as follows.
6+
7+
<p align="center">
8+
<img src="asset/model.png" width="480"/>
9+
</p>
10+
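To make "adapt to novel tasks with few gradient updates" concrete, here is a minimal sketch of inner-loop adaptation on a toy one-parameter task. It is not code from this repository: the `adapt` function, the quadratic loss, and the default values are illustrative assumptions, with `fast_lr` and `num_updates` named after the `--fast-lr` and `--num-updates` options.

```python
def adapt(theta, target, fast_lr=0.1, num_updates=5):
    """Run a few gradient steps on the toy loss (theta - target)**2."""
    for _ in range(num_updates):
        grad = 2.0 * (theta - target)   # derivative of (theta - target)**2
        theta = theta - fast_lr * grad  # one inner-loop (fast) update
    return theta


# One shared initialization (the meta-learned prior) adapted to two toy tasks.
prior = 0.0
adapted_a = adapt(prior, target=1.0)
adapted_b = adapt(prior, target=-3.0)
print(adapted_a, adapted_b)
```

With a single shared `prior`, both toy tasks start from the same point regardless of how different their targets are. This is the limitation that modulating the prior per identified mode is meant to address.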
11+
## Datasets
12+
13+
Run the following command to download and preprocess the datasets:
14+
15+
```bash
16+
python download.py --dataset aircraft bird cifar miniimagenet
17+
```
18+
19+
## Getting started
20+
To avoid any conflict with your existing Python setup, and to keep this project self-contained, it is suggested to work in a virtual environment with [`virtualenv`](http://docs.python-guide.org/en/latest/dev/virtualenvs/). To install `virtualenv`:
21+
```
22+
pip install --upgrade virtualenv
23+
```
24+
Create a virtual environment, activate it and install the requirements in [`requirements.txt`](requirements.txt).
25+
```
26+
virtualenv mmaml_venv
27+
source mmaml_venv/bin/activate
28+
pip install -r requirements.txt
29+
```
30+
31+
## Training commands
32+
33+
### Train
34+
35+
```bash
36+
python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s
37+
```
38+
- Selected arguments (see `trainer.py` for more details)
39+
- --output-folder: a nickname for the training
40+
- --dataset: choose among `omniglot`, `miniimagenet`, `cifar`, `bird` (CUB), and `aircraft`. You can also add your own datasets.
41+
- Checkpoints: specify the path to a pre-trained checkpoint
42+
- --checkpoint: load all the parameters (e.g. `train_dir/mmaml_5mode_5w1s/maml_gatedconv_60000.pt`).
43+
- Hyperparameters
44+
- --num-batches: number of batches
45+
- --meta-batch-size: number of tasks per batch
46+
- --slow-lr: learning rate for the global update of MAML
47+
- --fast-lr: learning rate for the adapted models
48+
- --num-updates: how many update steps in the inner loop
49+
- --num-classes-per-batch: how many classes per task (`N`-way)
50+
- --num-samples-per-class: how many samples per class for training (`K`-shot)
51+
- --num-val-samples: how many samples per class for validation
52+
- --max\_steps: the max training iterations
53+
- Logging
54+
- --log-interval: number of batches between tensorboard writes
55+
- --save-interval: number of batches between model saves
56+
- Model
57+
- --maml-model: set to `True` to train a MAML model
58+
- --mmaml-model: set to `True` to train an MMAML (ours) model
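As a rough illustration of how the `N`-way and `K`-shot options determine the size of one task, the sketch below computes the number of training (support) and validation (query) samples per task. The helper `episode_sizes` is hypothetical, and the value 15 for `--num-val-samples` is only an example, not a default taken from this repository.

```python
def episode_sizes(num_classes_per_batch, num_samples_per_class, num_val_samples):
    """Return (support_size, query_size) for one N-way K-shot task."""
    support = num_classes_per_batch * num_samples_per_class  # N * K training samples
    query = num_classes_per_batch * num_val_samples          # N * validation samples per class
    return support, query


five_way_one_shot = episode_sizes(5, 1, 15)
twenty_way_one_shot = episode_sizes(20, 1, 15)
print(five_way_one_shot, twenty_way_one_shot)
```

Under these example values, a 20-way task holds four times as many samples as a 5-way task.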
59+
60+
### 2 Modes (Omniglot and Mini-ImageNet)
61+
62+
| Setup | Method | Command |
63+
| :-----: | :------: | ---------------------------------------- |
64+
| 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 600000 --output-folder maml_2mode_5w1s``` |
65+
| 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 600000 --output-folder mmaml_2mode_5w1s``` |
66+
| 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_2mode_5w5s``` |
67+
| 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_2mode_5w5s``` |
68+
| 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_2mode_20w1s``` |
69+
| 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_2mode_20w1s``` |
70+
71+
### 3 Modes (Omniglot, Mini-ImageNet, and FC100)
72+
73+
| Setup | Method | Command |
74+
| :-----: | :------: | ---------------------------------------- |
75+
| 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 600000 --output-folder maml_3mode_5w1s``` |
76+
| 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 600000 --output-folder mmaml_3mode_5w1s``` |
77+
| 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_3mode_5w5s``` |
78+
| 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_3mode_5w5s``` |
79+
| 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_3mode_20w1s``` |
80+
| 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_3mode_20w1s``` |
81+
82+
### 5 Modes (Omniglot, Mini-ImageNet, FC100, Aircraft, and CUB)
83+
84+
| Setup | Method | Command |
85+
| :-----: | :------: | ---------------------------------------- |
86+
| 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --maml-model True --num-batches 600000 --output-folder maml_5mode_5w1s``` |
87+
| 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 600000 --output-folder maml_5mode_5w1s``` |
88+
| 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s``` |
89+
| 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s``` |
90+
| 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_5mode_5w5s``` |
91+
| 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_5mode_5w5s``` |
92+
| 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_5mode_5w5s``` |
93+
| 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_5mode_5w5s``` |
94+
| 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_5mode_20w1s``` |
95+
| 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_5mode_20w1s``` |
96+
| 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_5mode_20w1s``` |
97+
| 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_5mode_20w1s``` |
98+
99+
### Interpret TensorBoard
100+
Launch TensorBoard and open the specified port in a browser; the different losses appear under the **scalars** tab.
101+
102+
103+
### Multi-MAML
104+
105+
| Setup | Dataset | Command |
106+
| :-----: | :------: | ---------------------------------------- |
107+
| 5w1s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.4 --num-updates 1 --num-batches 600000 --output-folder multi_omniglot_5w1s``` |
108+
| 5w1s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_miniimagenet_5w1s``` |
109+
| 5w1s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_cifar_5w1s``` |
110+
| 5w1s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_bird_5w1s```|
111+
| 5w1s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_aircraft_5w1s```|
112+
| 5w5s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.4 --num-updates 1 --num-batches 600000 --num-samples-per-class 5 --output-folder multi_omniglot_5w5s``` |
113+
| 5w5s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_miniimagenet_5w5s``` |
114+
| 5w5s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_cifar_5w5s``` |
115+
| 5w5s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_bird_5w5s```|
116+
| 5w5s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_aircraft_5w5s```|
117+
| 20w1s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.1 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_omniglot_20w1s``` |
118+
| 20w1s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_miniimagenet_20w1s``` |
119+
| 20w1s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_cifar_20w1s``` |
120+
| 20w1s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_bird_20w1s``` |
121+
| 20w1s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_aircraft_20w1s``` |
122+
123+
## Results
124+
125+
### 2 Modes (Omniglot and Mini-ImageNet)
126+
127+
| Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot |
128+
| :----------: | :----------: | :----------: | :----------: |
129+
| MAML | 66.80% | 77.79% | 44.69% |
130+
| Multi-MAML | 66.85% | 73.07% | 53.15% |
131+
| MMAML (Ours) | 69.93% | 78.73% | 47.80% |
132+
133+
### 3 Modes (Omniglot, Mini-ImageNet, and FC100)
134+
135+
| Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot |
136+
| :----------: | :----------: | :----------: | :----------: |
137+
| MAML | 54.55% | 67.97% | 28.22% |
138+
| Multi-MAML | 55.90% | 62.20% | 39.77% |
139+
| MMAML (Ours) | 57.47% | 70.15% | 36.27% |
140+
141+
### 5 Modes (Omniglot, Mini-ImageNet, FC100, Aircraft, and CUB)
142+
143+
| Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot |
144+
| :----------: | :----------: | :----------: | :----------: |
145+
| MAML | 44.09% | 54.41% | 28.85% |
146+
| Multi-MAML | 45.46% | 55.92% | 33.78% |
147+
| MMAML (Ours) | 49.06% | 60.83% | 33.97% |
148+
149+
Please check out our paper for more comprehensive results.
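For a quick read of the 5-mode table, the snippet below computes the absolute accuracy gain of MMAML over MAML in percentage points. The numbers are copied from the table above; the dictionary names are illustrative.

```python
# Accuracies (%) copied from the 5-mode results table.
maml = {"5-way 1-shot": 44.09, "5-way 5-shot": 54.41, "20-way 1-shot": 28.85}
mmaml = {"5-way 1-shot": 49.06, "5-way 5-shot": 60.83, "20-way 1-shot": 33.97}

# Absolute gain in percentage points, rounded to two decimals.
gains = {setup: round(mmaml[setup] - maml[setup], 2) for setup in maml}
for setup, gain in gains.items():
    print(f"{setup}: +{gain:.2f} points")
```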
150+
151+
## Related work
152+
- \[MAML\] [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400) in ICML 2017
153+
154+
## Cite the paper
155+
If you find this useful, please cite:
156+
```
157+
@inproceedings{vuorio2019multimodal,
158+
title={Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation},
159+
author={Vuorio, Risto and Sun, Shao-Hua and Hu, Hexiang and Lim, Joseph J.},
160+
booktitle={Neural Information Processing Systems},
161+
year={2019},
162+
}
163+
```
164+
165+
## Author
166+
[Shao-Hua Sun](http://shaohua0116.github.io/)

download.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
import subprocess
import argparse


parser = argparse.ArgumentParser(description='Download datasets for MMAML.')
# One or more dataset names; each must be one of the listed choices.
parser.add_argument('--dataset', metavar='N', type=str, nargs='+',
                    choices=['aircraft', 'bird', 'cifar', 'miniimagenet'])


def download(dataset):
    # Delegate to the per-dataset script, e.g. get_dataset_script/get_bird.py.
    cmd = ['python', 'get_dataset_script/get_{}.py'.format(dataset)]
    print(' '.join(cmd))
    subprocess.call(cmd)


if __name__ == '__main__':
    args = parser.parse_args()
    # args.dataset is None when --dataset is omitted, so test truthiness
    # rather than calling len() on it.
    if args.dataset:
        for dataset in args.dataset:
            download(dataset)
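The parser's behavior can be checked in isolation. The sketch below rebuilds the same `--dataset` argument and parses two explicit argument lists; the lists passed to `parse_args` are illustrative inputs, not part of the script.

```python
import argparse

parser = argparse.ArgumentParser(description='Download datasets for MMAML.')
parser.add_argument('--dataset', metavar='N', type=str, nargs='+',
                    choices=['aircraft', 'bird', 'cifar', 'miniimagenet'])

# nargs='+' collects one or more values into a list.
args = parser.parse_args(['--dataset', 'bird', 'cifar'])
print(args.dataset)

# With the flag omitted, the attribute falls back to its default, None.
no_args = parser.parse_args([])
print(no_args.dataset)
```

Because the attribute is `None` when the flag is absent, any code that reads `args.dataset` needs to handle that case.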
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
# https://drive.google.com/drive/folders/0B-r7apOz1BHAWXYwT1lGb3J1Yjg
# Download files on Google Drive
import requests


def download_file_from_google_drive(id, destination):
    URL = "https://drive.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params={'id': id}, stream=True)
    token = get_confirm_token(response)

    if token:
        params = {'id': id, 'confirm': token}
        response = session.get(URL, params=params, stream=True)

    save_response_content(response, destination)


def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None


def save_response_content(response, destination):
    CHUNK_SIZE = 32768

    with open(destination, "wb") as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)


if __name__ == "__main__":
    file_id = "1HkgrkAwukzEZA0TpO7010PkAOREb2Nuk"
    destination = './mini-imagenet.zip'
    download_file_from_google_drive(file_id, destination)
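The confirm-token lookup can be exercised without any network access. In the sketch below, `FakeResponse` is a hypothetical stand-in that exposes only the `cookies` mapping the function reads, and the cookie names and values are made up for illustration.

```python
class FakeResponse:
    """Hypothetical stand-in exposing only the `cookies` mapping used here."""

    def __init__(self, cookies):
        self.cookies = cookies


def get_confirm_token(response):
    # Same lookup as in the script: return the value of the first cookie
    # whose name starts with 'download_warning'.
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


with_token = get_confirm_token(FakeResponse({'download_warning_123': 'abcd'}))
without_token = get_confirm_token(FakeResponse({'NID': 'xyz'}))
print(with_token, without_token)
```

When no matching cookie is present, the function returns `None` and the script saves the first response as it is.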

get_dataset_script/get_aircraft.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
import subprocess


cmds = []
cmds.append(['wget', 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'])
cmds.append(['tar', 'xvzf', 'fgvc-aircraft-2013b.tar.gz'])
cmds.append(['python', 'get_dataset_script/proc_aircraft.py'])
cmds.append(['rm', '-rf', 'fgvc-aircraft-2013b.tar.gz', 'fgvc-aircraft-2013b'])

for cmd in cmds:
    print(' '.join(cmd))
    subprocess.call(cmd)
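`subprocess.call` returns the exit status without raising, so the loop above continues even if an earlier step such as the download fails. A possible fail-fast variant, sketched here with a hypothetical `run_all` helper and placeholder commands rather than the real dataset pipeline, uses `subprocess.run(..., check=True)`:

```python
import subprocess
import sys


def run_all(cmds):
    """Run commands in order; stop at the first non-zero exit status."""
    for cmd in cmds:
        print(' '.join(cmd))
        subprocess.run(cmd, check=True)  # raises CalledProcessError on failure


# Illustrative commands only: the second one exits with status 3.
demo = [
    [sys.executable, '-c', 'print("step 1")'],
    [sys.executable, '-c', 'import sys; sys.exit(3)'],
    [sys.executable, '-c', 'print("never reached")'],
]

failed_code = None
try:
    run_all(demo)
except subprocess.CalledProcessError as err:
    failed_code = err.returncode
    print('stopped with exit code', failed_code)
```

Whether stopping early is preferable depends on the pipeline. In this sketch the third command never runs, so a cleanup step placed last would be skipped after a failure.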

get_dataset_script/get_bird.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
import subprocess


cmds = []
cmds.append(['wget', 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'])
cmds.append(['tar', 'xvzf', 'CUB_200_2011.tgz'])
cmds.append(['python', 'get_dataset_script/proc_bird.py'])
cmds.append(['rm', '-rf', 'CUB_200_2011', 'CUB_200_2011.tgz'])

for cmd in cmds:
    print(' '.join(cmd))
    subprocess.call(cmd)
