Skip to content

[ENH] Add Fast Geometric Ensembling #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Changelog
Ver 0.1.*
---------

* |Feature| |API| Add :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu <https://github.com/xuyxu>`__
* |Enhancement| Add flexible instantiation of optimizers and schedulers | `@cspsampedro <https://github.com/cspsampedro>`__
* |Feature| |API| Add support on accepting instantiated base estimators as valid input | `@xuyxu <https://github.com/xuyxu>`__
* |Fix| Fix missing base estimators when calling :meth:`load()` for all ensembles | `@xuyxu <https://github.com/xuyxu>`__
Expand Down
39 changes: 39 additions & 0 deletions docs/parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,42 @@ AdversarialTrainingRegressor

.. autoclass:: torchensemble.adversarial_training.AdversarialTrainingRegressor
:members:

Fast Geometric Ensemble
-----------------------

Motivated by geometric insights on the loss surface of deep neural networks,
Fast Geometirc Ensembling (FGE) is an efficient ensemble that uses a
customized learning rate scheduler to generate base estimators, similar to
snapshot ensemble.

Reference:
T. Garipov, P. Izmailov, D. Podoprikhin et al., Loss Surfaces, Mode
Connectivity, and Fast Ensembling of DNNs, NeurIPS, 2018.

Notice that unlike all ensembles above, using fast geometric ensemble (FGE) is
**a two-staged process**. Concretely, you first need to call :meth:`fit` to
build a dummy base estimator that will be used to generate ensembles. Second,
you need to call :meth:`ensemble` to generate real base estimators in the
ensemble. The pipeline is shown in the following code snippet:

.. code:: python

model = FastGeometricClassifier(**ensemble_related_args)
estimator = model.fit(train_loader, **base_estimator_related_args) # train the base estimator
model.ensemble(estimator, train_loader, **fge_related_args) # generate the ensemble using the base estimator

You can refer to scripts in `examples <https://github.com/xuyxu/Ensemble-Pytorch/tree/master/examples>`__ for
a detailed example.

FastGeometricClassifier
***********************

.. autoclass:: torchensemble.fast_geometric.FastGeometricClassifier
:members:

FastGeometricRegressor
***********************

.. autoclass:: torchensemble.fast_geometric.FastGeometricRegressor
:members:
1 change: 1 addition & 0 deletions examples/classification_cifar10_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torchensemble.bagging import BaggingClassifier
from torchensemble.gradient_boosting import GradientBoostingClassifier
from torchensemble.snapshot_ensemble import SnapshotEnsembleClassifier
from torchensemble.fast_geometric import FastGeometricClassifier

from torchensemble.utils.logging import set_logger

Expand Down
172 changes: 172 additions & 0 deletions examples/fast_geometric_ensemble_cifar10_resnet18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from torchensemble import FastGeometricClassifier
from torchensemble.utils.logging import set_logger


# The class `BasicBlock` and `ResNet` is modified from:
# https://github.com/kuangliu/pytorch-cifar
class BasicBlock(nn.Module):
expansion = 1

def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(planes)

self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(self.expansion * planes),
)

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out


class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_planes = 64

self.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512 * block.expansion, num_classes)

def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out


if __name__ == "__main__":

# Hyper-parameters
n_estimators = 10
lr = 1e-1
weight_decay = 5e-4
momentum = 0.9
epochs = 200

# Utils
batch_size = 128
data_dir = "../../Dataset/cifar" # MODIFY THIS IF YOU WANT
torch.manual_seed(0)
torch.cuda.set_device(0)

# Load data
train_transformer = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, 4),
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
),
]
)

test_transformer = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
),
]
)

train_loader = DataLoader(
datasets.CIFAR10(
data_dir, train=True, download=True, transform=train_transformer
),
batch_size=batch_size,
shuffle=True,
)

test_loader = DataLoader(
datasets.CIFAR10(data_dir, train=False, transform=test_transformer),
batch_size=batch_size,
shuffle=True,
)

# Set the Logger
logger = set_logger("FastGeometricClassifier_cifar10_resnet")

# Choose the Ensemble Method
model = FastGeometricClassifier(
estimator=ResNet,
estimator_args={"block": BasicBlock, "num_blocks": [2, 2, 2, 2]},
n_estimators=n_estimators,
cuda=True,
)

# Set the Optimizer
model.set_optimizer(
"SGD", lr=lr, weight_decay=weight_decay, momentum=momentum
)

# Set the Scheduler
model.set_scheduler("CosineAnnealingLR", T_max=epochs)

# Train
estimator = model.fit(train_loader, epochs=epochs, test_loader=test_loader)

# Ensemble
model.ensemble(
estimator,
train_loader,
epochs=40,
lr_1=5e-2,
lr_2=5e-4,
test_loader=test_loader,
)

# Evaluate
acc = model.predict(test_loader)
print("Testing Acc: {:.3f}".format(acc))
4 changes: 4 additions & 0 deletions torchensemble/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from .snapshot_ensemble import SnapshotEnsembleRegressor
from .adversarial_training import AdversarialTrainingClassifier
from .adversarial_training import AdversarialTrainingRegressor
from .fast_geometric import FastGeometricClassifier
from .fast_geometric import FastGeometricRegressor


__all__ = [
Expand All @@ -25,4 +27,6 @@
"SnapshotEnsembleRegressor",
"AdversarialTrainingClassifier",
"AdversarialTrainingRegressor",
"FastGeometricClassifier",
"FastGeometricRegressor",
]
1 change: 1 addition & 0 deletions torchensemble/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def get_doc(item):
"""Return the selected item."""
__doc = {
"model": const.__model_doc,
"seq_model": const.__seq_model_doc,
"fit": const.__fit_doc,
"set_optimizer": const.__set_optimizer_doc,
"set_scheduler": const.__set_scheduler_doc,
Expand Down
27 changes: 27 additions & 0 deletions torchensemble/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@
"""


__seq_model_doc = """
Parameters
----------
estimator : torch.nn.Module
The class or object of your base estimator.

- If :obj:`class`, it should inherit from :mod:`torch.nn.Module`.
- If :obj:`object`, it should be instantiated from a class inherited
from :mod:`torch.nn.Module`.
n_estimators : int
The number of base estimators in the ensemble.
estimator_args : dict, default=None
The dictionary of hyper-parameters used to instantiate base
estimators. This parameter will have no effect if ``estimator`` is a
base estimator object after instantiation.
cuda : bool, default=True

- If ``True``, use GPU to train and evaluate the ensemble.
- If ``False``, use CPU to train and evaluate the ensemble.

Attributes
----------
estimators_ : torch.nn.ModuleList
An internal container that stores all fitted base estimators.
"""


__set_optimizer_doc = """
Parameters
----------
Expand Down
Loading