# cfa.py — Class-Incremental Learning via Knowledge Amalgamation (CFA)
import argparse
|
|
import ssl
|
|
from abc import ABC
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torchvision
|
|
import torchvision.transforms as T
|
|
import random
|
|
import warnings
|
|
from tqdm import tqdm
|
|
from typing import List, Tuple
|
|
from os import mkdir, remove
|
|
from os.path import exists
|
|
import matplotlib.pyplot as plt
|
|
|
|
from avalanche.benchmarks.classic import SplitMNIST, SplitCIFAR10, SplitCIFAR100, SplitTinyImageNet, SplitOmniglot
|
|
from avalanche.benchmarks.generators import nc_benchmark
|
|
from avalanche.benchmarks.scenarios.new_classes.nc_scenario import NCExperience, NCScenario
|
|
|
|
# Short identifier used for the CLI program name and checkpoint file names.
paper_name = 'cfa'

parser = argparse.ArgumentParser(f'./{paper_name}.py',
                                 description='Class-Incremental Learning via Knowledge Amalgamation')
parser.add_argument('--dataset', type=str, default='mnist', help='Dataset to use', required=False,
                    choices=['usps', 'mnist', 'cifar10', 'cifar100', 'tiny10', 'tiny20', 'omniglot'])
parser.add_argument('--seed', type=int, default=None, metavar='N', help='Set a seed to compare runs')
parser.add_argument('--cuda', action='store_true', help='enable CUDA')
parser.add_argument('--cuda_device', type=int, default=0, help='Cuda device identifier')

# Teacher models configuration
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--batch_size', type=int, default=2 ** 6, help='Batch size for base model training',
                    choices=[2 ** 3, 2 ** 4, 2 ** 5, 2 ** 6, 2 ** 7, 2 ** 8])
parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs per task (base learning)',
                    choices=[5, 10, 50, 100, 1000])
# NOTE(review): type=bool on argparse parses any non-empty string as True
# (e.g. "--force_base_retraining False" is still truthy) — confirm intended.
parser.add_argument('--force_base_retraining', type=bool, default=True, help='Force base model retraining')

# Amalgamation configuration
parser.add_argument('--cfl_lr', type=float, default=None, help='Common feature amalgamation learning rate')
parser.add_argument('--amalgamation_strategy', type=str, default='all_together', help='Amalgamation Strategy',
                    choices=['all_together', 'one_at_a_time'])
parser.add_argument('--amalgamation_epochs', type=int, default=100, help='Amalgamation epochs',
                    choices=[10, 100, 500, 1000])

# Memory strategy configuration
parser.add_argument('--memory_strategy', type=str, default='grow', help='Memory Strategy',
                    choices=['fixed', 'grow'])
parser.add_argument('--memory_budget', type=int, default=1000, help='Memory Budget',
                    choices=[100, 200, 500, 1000, 2000])
# alpha balances the knowledge-distillation (KL) loss vs the common-feature loss.
parser.add_argument('--alpha', type=float, default=0.5, help='Alpha')
|
|
|
|
|
|
# Helpers
|
|
def enum(**enums):
    """Build an ad-hoc enumeration-like type whose class attributes are *enums*."""
    members = dict(enums)
    return type('Enum', (), members)
|
|
|
|
|
|
# Bookkeeping keys (attribute-style access) shared by the replay-memory helpers.
FIELD = enum(EXT_IDX='last_searched_external_idx',
             INT_IDX='last_searched_internal_idx',
             N_ELEM='n_elem',
             CLASSES_LIST='classes_list')
|
|
|
|
|
|
class AverageTracker():
    """Tracks running averages of named scalar values (e.g. per-epoch losses).

    Each key maps to a dict holding a running sum (FIELD.VALUE) and the number
    of updates (FIELD.COUNT); `get` returns their ratio.
    """

    # Dictionary keys for each tracked entry.
    FIELD = enum(VALUE='value',
                 COUNT='count')

    def __init__(self):
        # key -> {FIELD.VALUE: running sum, FIELD.COUNT: number of updates}
        self.book = dict()

    def reset(self, key: str = None) -> None:
        """Clear every entry (key is None) or zero out a single entry.

        Zeroing a missing key creates it.
        """
        if key is None:
            self.book.clear()
        else:
            # Overwrite (or create) the entry with fresh accumulators.
            self.book[key] = {self.FIELD.VALUE: 0., self.FIELD.COUNT: 0}

    def update(self, key: str, val: torch.Tensor) -> None:
        """Accumulate `val` under `key`, creating the entry on first use."""
        if key not in self.book:
            self.reset(key)
        item = self.book[key]
        item[self.FIELD.VALUE] += val
        item[self.FIELD.COUNT] += 1

    def get(self, key: str) -> float:
        """Return the average for `key` (0. if no updates yet).

        Raises AssertionError if the key was never seen.
        """
        item = self.book.get(key)
        assert item is not None
        count = float(item[self.FIELD.COUNT])
        return item[self.FIELD.VALUE] / count if count > 0. else 0.

    def count(self, key: str) -> float:
        """Return how many times `key` was updated (AssertionError if unseen)."""
        item = self.book.get(key)
        assert item is not None
        return item[self.FIELD.COUNT]
|
|
|
|
# Code
|
|
class CommonFeatureLearningLoss(torch.nn.Module):
    """Common-feature-learning loss: KL between the student's common-space
    logits and each teacher's, plus a beta-weighted MSE reconstruction term
    between decoded teacher features and the originals.
    """

    def __init__(self, beta: float = 1.0):
        super(CommonFeatureLearningLoss, self).__init__()
        # Weight of the reconstruction (MSE) term.
        self.beta = beta

    def forward(self, hs: torch.Tensor, ht: torch.Tensor, ft_: torch.Tensor, ft: torch.Tensor) -> torch.Tensor:
        """Return kl(hs || ht_i) summed over teachers + beta * sum of MSE(ft_[i], ft[i])."""
        # kl_div's default reduction triggers a warning; keep it silenced as before.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore')
            kl_loss = sum((torch.nn.functional.kl_div(torch.log_softmax(hs, dim=1),
                                                      torch.softmax(teacher_h, dim=1))
                           for teacher_h in ht), 0.0)

        mse_loss = sum((torch.nn.functional.mse_loss(reconstructed, ft[i])
                        for i, reconstructed in enumerate(ft_)), 0.0)

        return kl_loss + self.beta * mse_loss
|
|
|
|
|
|
class ResidualBlock(torch.nn.Module):
    """Residual block: 3x3 conv -> ReLU -> 3x3 conv, plus an identity shortcut.

    When the stride or channel count changes, a 1x1 strided convolution
    projects the shortcut so the element-wise addition is shape-compatible.
    (No batch norm, matching the original design of this file.)
    """

    def __init__(self, inplanes: int, planes: int, stride: int = 1):
        """
        Args:
            inplanes: number of input channels.
            planes: number of output channels.
            stride: spatial downsampling factor, applied exactly once
                (by conv1 and, in parallel, by the shortcut projection).
        """
        super(ResidualBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=inplanes,
                                     out_channels=planes,
                                     kernel_size=(3, 3),
                                     stride=stride,
                                     padding=1,
                                     bias=False)

        self.relu = torch.nn.ReLU(inplace=True)

        # BUGFIX: the second conv must use stride 1 (standard ResNet BasicBlock).
        # Re-applying `stride` here made the main path downsample twice while
        # the shortcut downsampled once, so `x += residual` raised a shape
        # mismatch for any stride > 1. Behavior for stride == 1 is unchanged.
        self.conv2 = torch.nn.Conv2d(in_channels=planes,
                                     out_channels=planes,
                                     kernel_size=(3, 3),
                                     stride=1,
                                     padding=1,
                                     bias=False)
        self.downsample = None
        if stride > 1 or inplanes != planes:
            # 1x1 projection for the identity branch when shapes differ.
            self.downsample = torch.nn.Sequential(torch.nn.Conv2d(in_channels=inplanes,
                                                                  out_channels=planes,
                                                                  kernel_size=(1, 1),
                                                                  stride=stride,
                                                                  bias=False)
                                                  )

        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the two-conv main path, add the (possibly projected) shortcut, ReLU."""
        residual = x
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)

        if self.downsample is not None:
            residual = self.downsample(residual)

        x += residual
        x = self.relu(x)
        return x
|
|
|
|
|
|
class CommonFeatureBlocks(torch.nn.Module):
    """Common feature space for amalgamation.

    Aligns student and teacher feature maps into a shared 2*ch_h-channel space
    (1x1 conv + ReLU), extracts a common representation with three residual
    blocks, and decodes the common space back to each teacher's channel count
    so a reconstruction loss can be applied.
    """

    def __init__(self, n_student_channels: int, n_teacher_channels: List[int], n_hidden_channel: int):
        super(CommonFeatureBlocks, self).__init__()

        ch_s = n_student_channels   # Readability
        ch_ts = n_teacher_channels  # Readability
        ch_h = n_hidden_channel     # Readability

        def _aligner(in_ch: int) -> torch.nn.Sequential:
            # 1x1 conv + ReLU mapping `in_ch` channels into the shared space.
            return torch.nn.Sequential(torch.nn.Conv2d(in_channels=in_ch,
                                                       out_channels=2 * ch_h,
                                                       kernel_size=(1, 1),
                                                       bias=False),
                                       torch.nn.ReLU(inplace=True))

        def _decoder(out_ch: int) -> torch.nn.Sequential:
            # 3x3 conv + ReLU + 1x1 conv mapping the common space back to `out_ch`.
            return torch.nn.Sequential(torch.nn.Conv2d(in_channels=ch_h,
                                                       out_channels=out_ch,
                                                       kernel_size=(3, 3),
                                                       stride=1,
                                                       padding=1,
                                                       bias=False),
                                       torch.nn.ReLU(inplace=True),
                                       torch.nn.Conv2d(in_channels=out_ch,
                                                       out_channels=out_ch,
                                                       kernel_size=(1, 1),
                                                       stride=1,
                                                       padding=0,
                                                       bias=False))

        # One alignment branch per teacher, then one for the student
        # (creation order kept identical to preserve parameter-init order).
        self.align_t = torch.nn.ModuleList([_aligner(ch_t) for ch_t in ch_ts])
        self.align_s = _aligner(ch_s)

        # Shared extractor producing the common feature representation.
        self.extractor = torch.nn.Sequential(ResidualBlock(inplanes=2 * ch_h, planes=ch_h, stride=1),
                                             ResidualBlock(inplanes=ch_h, planes=ch_h, stride=1),
                                             ResidualBlock(inplanes=ch_h, planes=ch_h, stride=1))

        # One decoder per teacher for feature reconstruction.
        self.dec_t = torch.nn.ModuleList([_decoder(ch_t) for ch_t in ch_ts])

    def forward(self, fs: torch.Tensor, ft: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (hs, ht, ft_): student common features, per-teacher common
        features, and per-teacher reconstructed features."""
        hs = self.extractor(self.align_s(fs))
        ht = []
        for align, teacher_f in zip(self.align_t, ft):
            ht.append(self.extractor(align(teacher_f)))
        ft_ = [decode(h) for decode, h in zip(self.dec_t, ht)]
        return hs, ht, ft_
|
|
|
|
class MySimpleModel(torch.nn.Module):
    """ResNet-18 backbone (ImageNet-pretrained) with a small two-layer head.

    Forward hooks (see `register_hooks`) cache the last conv block's feature
    map and the head's logits on the modules themselves, so teachers and the
    student can expose intermediate activations during amalgamation.
    """

    @property
    def features(self) -> torch.Tensor:
        # Feature map cached by the forward hook on resnet.layer4.
        # NOTE(review): `self.handles` starts as {} (not None), so this assert
        # passes even before register_hooks(); without hooks the `.output`
        # attribute does not exist and this raises AttributeError instead.
        assert self.handles is not None
        return self.resnet.layer4.output

    @property
    def feature_dimension(self) -> int:
        # Channel count produced by the last residual block (static; no hook needed).
        return self.resnet.layer4[-1].conv2.out_channels

    @property
    def soft_output(self) -> torch.Tensor:
        # Logits cached by the forward hook on the classifier head.
        return self.fc.output

    @property
    def n_output(self) -> int:
        # Number of classes this model predicts.
        return self.fc[-1].out_features

    def __init__(self, n_output: int):
        """Build the pretrained backbone and a fresh head with `n_output` classes."""
        super(MySimpleModel, self).__init__()
        # Hook-name -> RemovableHandle; populated by register_hooks().
        self.handles = {}

        self.resnet = torchvision.models.resnet18(pretrained=True)
        # Keep the original fc only to read its in_features below; replace it
        # with an identity so the backbone outputs raw pooled features.
        self.resnet.fc_backup = self.resnet.fc
        self.resnet.fc = torch.nn.Sequential()
        self.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(self.resnet.fc_backup.in_features,
                            self.resnet.fc_backup.in_features // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(self.resnet.fc_backup.in_features // 2,
                            n_output)
        )

    def register_hooks(self) -> None:
        """Attach forward hooks that cache layer4 features and head logits."""
        def forward_hook(module: torch.nn.modules.container.Sequential, _: tuple, output: torch.Tensor):
            # Stash the output on the module so the properties above can read it.
            module.output = output

        self.handles['conv_layer'] = self.resnet.layer4.register_forward_hook(forward_hook)
        self.handles['fc_layer'] = self.fc.register_forward_hook(forward_hook)

    def remove_hooks(self) -> None:
        """Detach every registered forward hook (entries stay in the dict)."""
        assert self.handles is not None
        for k, v in self.handles.items():
            self.handles[k].remove()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Backbone features -> head logits."""
        x = self.resnet(x)
        x = self.fc(x)
        return x

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return the predicted class index per sample (argmax of softmax)."""
        x = self.forward(x)
        return torch.softmax(x, dim=1).argmax(1)
|
|
|
|
|
|
def save_load_best_model(model: MySimpleModel, experience: NCExperience, is_train=True, pbar=None) -> Tuple[MySimpleModel, float]:
    """Evaluate `model` on `experience` and checkpoint it to ./state.

    In train mode (is_train=True) the current weights are saved after
    evaluation; in eval mode the stored checkpoint is loaded first and
    returned. Returns (model, accuracy on the experience's dataset).

    Relies on module-level `args` (dataset name, batch_size) and `device`.
    """
    n_task = experience.current_experience
    path = f'./state'
    state_path = f'{path}/cfa_{args.dataset}_{n_task + 1}'

    if not exists(path):
        mkdir(path)
    if not exists(state_path):
        # First call for this task: seed the checkpoint with the current weights.
        torch.save(model.state_dict(), state_path)
    assert exists(path)
    assert exists(state_path)

    with torch.no_grad():
        model.eval()
        corrects = 0
        total_task = 0

        if not is_train:
            # Evaluation mode: score the stored checkpoint, not the passed-in weights.
            model.load_state_dict(torch.load(state_path))

        data_loader = torch.utils.data.DataLoader(experience.dataset, batch_size=args.batch_size, shuffle=True)

        for _, data in enumerate(data_loader):
            x = data[0].to(device)
            # Shift labels into this task's local 0..K-1 range.
            y = (data[1] - min(experience.classes_in_this_experience)).to(device)
            total_task += len(x)
            corrects += int(sum(model.predict(x) == y))
        accuracy = (corrects / total_task) if total_task > 0 else 0

        if is_train:
            description = f'Train accuracy for base task {n_task + 1}: {accuracy * 100:.2f}% ({corrects}/{total_task})'

            if pbar is None:
                print(description)
            else:
                pbar.set_description_str(description)
            # NOTE(review): the checkpoint is overwritten unconditionally, so
            # despite the name this stores the *latest* model, not the best one.
            torch.save(model.state_dict(), state_path)
            model.train()
        else:
            model.load_state_dict(torch.load(state_path))

    return model, accuracy
|
|
|
|
|
|
def amalgamate(teachers: List[MySimpleModel], data_array: List = [], labels: List = [],
               train: NCScenario = None, test: NCScenario = None, epochs: int = 100) -> Tuple[MySimpleModel, List[int], np.ndarray, np.ndarray]:
    """Distil the `teachers` into a single student via common-feature learning.

    Builds a fresh student sized for all teachers' outputs, selects a replay
    memory (herding-style, per class) from `train`, then optimizes a weighted
    sum of a KL distillation loss and the common-feature loss for `epochs`.

    Returns (student, flat list of replay-memory sample indices, per-task
    accuracy after training, per-task accuracy of the untrained student).

    Relies on module-level `args` and `device`.

    NOTE(review): `data_array` and `labels` use mutable default arguments;
    both callers in this file always pass them explicitly, but the defaults
    should be None-sentinels to avoid shared-state surprises.
    """

    def memory_keys(all_data: NCScenario, teachers: List[MySimpleModel], labels: List[int], idx: int, previous_data_idxs: List[int] = None) -> List[int]:
        """Per-class exemplar budget for task `idx` under the chosen memory strategy.

        'grow': budget is split over this task's classes only (total memory grows);
        'fixed': budget is split over all classes seen so far.
        """
        if args.memory_strategy == 'grow':
            n_elem = args.memory_budget // len(labels[idx])
            return get_mean_exemplar_keys(all_data, teachers[idx], labels[idx], n_elem, previous_data_idxs)
        elif args.memory_strategy == 'fixed':
            n_elem = args.memory_budget // sum(len(v) for v in labels)
            return get_mean_exemplar_keys(all_data, teachers[idx], labels[idx], n_elem, previous_data_idxs)

    def get_mean_exemplar_keys(all_data: NCScenario, teacher: MySimpleModel, labels: List[int], n_elem_per_class: int, previous_data_idxs: List[int] = None) -> List[int]:
        """Pick, per class, the `n_elem_per_class` sample indices whose teacher
        features lie closest to the class mean feature (herding-style selection).

        If `previous_data_idxs` is given, the class mean is computed over that
        subset only; otherwise over the whole dataset.
        """
        batch_sample = torch.empty((args.batch_size, 3, 224, 224))
        current_batch_size = 0
        n_samples = 0
        label_mean = {}

        # Pass 1: per-class mean feature map.
        with torch.no_grad():
            for label in labels:
                # Dummy forward so teacher.features has the right shape to zero.
                teacher(torch.rand((1,3,224,224)).to(device))
                label_mean[label] = torch.zeros_like(teacher.features)
                if previous_data_idxs is None:
                    for _, [x, y, _] in enumerate(all_data[0].dataset):
                        if x is None:
                            break

                        if current_batch_size < args.batch_size and x is not None:
                            if y == label:
                                batch_sample[current_batch_size] = x
                                current_batch_size += 1
                        elif x is None and current_batch_size == 0:
                            break
                        elif current_batch_size > 0 or x is None:
                            # Batch full: accumulate features and restart the batch.
                            batch_sample = batch_sample[:current_batch_size].to(device)
                            teacher(batch_sample)
                            # NOTE(review): builtin sum(tensor, 1) adds the slices
                            # along dim 0 *plus the scalar 1* to every element —
                            # torch.sum(teacher.features, dim=0) looks intended; confirm.
                            label_mean[label] += sum(teacher.features, 1).unsqueeze(0)
                            n_samples += current_batch_size
                            current_batch_size = 0
                else:
                    # Restrict the mean to the previously selected replay indices.
                    for _, idx in enumerate(previous_data_idxs):
                        x, y, _ = all_data[0].dataset[idx]
                        if current_batch_size < args.batch_size:
                            if y == label:
                                batch_sample[current_batch_size] = x
                                current_batch_size += 1
                        elif current_batch_size == 0:
                            break
                        elif current_batch_size > 0:
                            batch_sample = batch_sample[:current_batch_size].to(device)
                            teacher(batch_sample)
                            label_mean[label] += sum(teacher.features, 1).unsqueeze(0)
                            n_samples += current_batch_size
                            current_batch_size = 0

                label_mean[label] /= n_samples
                n_samples = 0

        batch_sample = torch.empty((args.batch_size, 3, 224, 224))
        batch_idx = np.empty(args.batch_size, dtype=int)
        current_batch_index = 0
        label_idx_distance = {}
        label_idx = {}

        # Pass 2: distance of every sample's features to its class mean,
        # then keep the n_elem_per_class closest indices per class.
        with torch.no_grad():
            for label in labels:

                label_idx_distance[label] = {}
                label_idx[label] = []
                for idx, [x, y, _] in enumerate(all_data[0].dataset):
                    if x is None:
                        break

                    if current_batch_index < args.batch_size and x is not None:
                        if y == label:
                            batch_sample[current_batch_index] = x
                            batch_idx[current_batch_index] = idx
                            # Placeholder distance until the batch is scored.
                            label_idx_distance[label][idx] = np.inf
                            current_batch_index += 1
                    elif x is None and current_batch_index == 0:
                        break
                    elif current_batch_index > 0 or x is None:
                        batch_sample = batch_sample[:current_batch_index].to(device)
                        teacher(batch_sample)
                        for idx, elem in enumerate(batch_idx):
                            # Euclidean (p=2) distance to the class-mean feature map.
                            label_idx_distance[label][elem] = float(torch.dist(teacher.features[idx], label_mean[label], 2))
                        current_batch_index = 0

                # Closest-first selection of exemplar indices for this class.
                label_idx[label] = [idx for idx, _ in sorted(label_idx_distance[label].items(), key=lambda x: x[1])][:n_elem_per_class]

        return np.concatenate([v for k, v in label_idx.items()], 0)

    def get_conf_keys(all_data: NCScenario, teacher: MySimpleModel, labels: List[int], n_elem: int = args.memory_budget) -> List[int]:
        """Alternative exemplar selection by teacher confidence margin
        (top-1 minus top-2 softmax). NOTE(review): defined but never called.
        """
        batch_sample = torch.empty((args.batch_size, 3, 224, 224))
        batch_idx = np.empty(args.batch_size, dtype=int)
        current_batch_size = 0
        conf = {}
        for label in labels:
            conf[label] = {}
            for idx, [x, y, _] in enumerate(all_data[0].dataset):
                if x is None:
                    break

                if current_batch_size < args.batch_size and x is not None:
                    if y == label:
                        batch_sample[current_batch_size] = x
                        batch_idx[current_batch_size] = idx
                        conf[label][idx] = 0
                        current_batch_size += 1
                elif x is None and current_batch_size == 0:
                    break
                elif current_batch_size > 0 or x is None:
                    batch_sample = batch_sample[:current_batch_size].to(device)
                    batch_idx = batch_idx[:current_batch_size]

                    # Confidence margin: top-1 minus top-2 softmax probability.
                    soft_top_2 = torch.softmax(teacher(batch_sample), 1).topk(2)[0].tolist()
                    for i, j in enumerate(batch_idx):
                        conf[label][j] = soft_top_2[i][0] - soft_top_2[i][1]

                    batch_sample = torch.empty((args.batch_size, 3, 224, 224))
                    batch_idx = np.empty(args.batch_size, dtype=int)
                    current_batch_size = 0

        # Keep the highest-margin indices, evenly budgeted across classes.
        idxs = []
        for label in conf.keys():
            idxs = idxs + list(dict(sorted(conf[label].items(),
                                           key=lambda x: x[1],
                                           reverse=True)).keys())[:(n_elem // len(labels))]

        return idxs[:n_elem]

    # Student output layer covers the union of all teachers' classes.
    student = MySimpleModel(sum([teacher.n_output for teacher in teachers])).to(device)
    # Common feature space sized as the average teacher feature dimension.
    cfl_blk = CommonFeatureBlocks(student.feature_dimension,
                                  [teachers[0].feature_dimension, teachers[1].feature_dimension],
                                  int(sum([teacher.feature_dimension for teacher in teachers])/len(teachers))).to(device)

    cfl_lr = args.lr * 10 if args.cfl_lr is None else args.cfl_lr

    # Head parameters get a 10x learning rate; backbone keeps the base rate.
    params_10x = [param for name, param in student.named_parameters() if 'fc' in name]
    params_1x = [param for name, param in student.named_parameters() if 'fc' not in name]
    optimizer = torch.optim.Adam([{'params': params_1x, 'lr': args.lr},
                                  {'params': params_10x, 'lr': args.lr * 10},
                                  {'params': cfl_blk.parameters(), 'lr': cfl_lr}])

    student.train()
    # Hooks expose .features / .soft_output on teachers and student.
    [teacher.register_hooks() for teacher in teachers]
    [teacher.eval() for teacher in teachers]
    student.register_hooks()
    average_tracker = AverageTracker()

    common_feature_learning_criterion = CommonFeatureLearningLoss().to(device)

    # Select the replay memory for each task (slow full-dataset scans).
    print('Adjusting replay memory - sorry the delay, this part of the code is not optimized')
    data_idx = []
    for idx, data in enumerate(data_array):
        data_idx.append(memory_keys(train, teachers, labels, idx, data))
        # Materialize the selected samples as a tensor per task.
        data_array[idx] = torch.stack([train[0].dataset[idx_][0] for idx_ in data_idx[idx]])
    print('Replay memory adjusted')

    # Pool and shuffle the replay samples of all tasks.
    all_data = torch.cat([data for data in data_array])
    p = torch.randperm(len(all_data))
    all_data = all_data[p]

    # Baseline: accuracy of the untrained (randomly initialized) student.
    student.eval()
    with torch.no_grad():
        corrects = np.zeros((len(teachers)), int)
        total_samples = np.zeros((len(teachers)), int)
        b_accuracy = np.zeros((len(teachers)))
        # Flattened class list; its position gives the student's output index.
        labels_ = [label for l in labels for label in l]
        for _, [x, y, _] in enumerate(test[0].dataset):
            if int(y) not in labels_:
                continue

            label = torch.tensor(labels_.index(y)).to(device)
            sample = x.view(1, 3, 224, 224).to(device)
            pred = student.predict(sample)
            for idx, task_labels in enumerate(labels):
                if label in task_labels:
                    corrects[idx] = corrects[idx] + int(pred == label)
                    total_samples[idx] = total_samples[idx] + 1

        for idx, _ in enumerate(teachers):
            b_accuracy[idx] = (corrects[idx] / total_samples[idx]) if total_samples[idx] > 0 else 0

    # Amalgamation training loop over the replay memory.
    student.train()
    with tqdm(unit='Epoch', total=epochs) as pbar:
        while pbar.n < epochs:
            average_tracker.reset()
            batch_sample = torch.empty((args.batch_size, 3, 224, 224))
            current_batch_index = 0
            for _, data in enumerate(all_data):
                if current_batch_index < args.batch_size and data is not None:
                    batch_sample[current_batch_index] = data
                    current_batch_index += 1
                elif data is None and current_batch_index == 0:
                    break
                elif current_batch_index > 0 or data is None:
                    # Batch full: run one optimization step.
                    # NOTE(review): a trailing partial batch is never processed.
                    batch_sample = batch_sample[:current_batch_index].to(device)
                    current_batch_index = 0

                    optimizer.zero_grad()
                    with torch.no_grad():
                        # Teacher forwards only populate the hooked outputs.
                        [teacher(batch_sample) for teacher in teachers]
                        teacher_soft = torch.cat(tuple([teacher.soft_output for teacher in teachers]), dim=1)

                    student(batch_sample)
                    batch_sample = torch.empty((args.batch_size, 3, 224, 224))
                    student_soft = student.soft_output

                    # Distillation loss (kl_div default reduction warns; silenced).
                    with warnings.catch_warnings():
                        warnings.filterwarnings('ignore')
                        cross_entropy_loss = torch.nn.functional.kl_div(torch.log_softmax(student_soft, dim=1),
                                                                        torch.softmax(teacher_soft, dim=1))

                    hs, ht, ft_ = cfl_blk(student.features, [teacher.features for teacher in teachers])
                    common_features_loss = 10 * common_feature_learning_criterion(hs, ht, ft_, [teacher.features for teacher in teachers])

                    # alpha balances distillation vs common-feature learning.
                    loss = args.alpha * cross_entropy_loss + (1 - args.alpha) * common_features_loss
                    loss.backward()
                    optimizer.step()

                    average_tracker.update('loss', loss.item())
                    average_tracker.update('ce', cross_entropy_loss.item())
                    average_tracker.update('cf', common_features_loss.item())

                    description = f'Amalgamating ' \
                                  f'Loss={average_tracker.get("loss"):.2f} '\
                                  f'(cross entropy={average_tracker.get("ce"):.2f}, '\
                                  f'common features={average_tracker.get("cf"):.2f})'
                    pbar.set_description_str(description)
                    pbar.refresh()
            pbar.update()
            # Reshuffle the replay memory for the next epoch.
            all_data = torch.cat([data for data in data_array])
            p = torch.randperm(len(all_data))
            all_data = all_data[p]
    [teacher.remove_hooks() for teacher in teachers]
    student.remove_hooks()

    # Final per-task evaluation of the amalgamated student.
    student.eval()
    with torch.no_grad():
        corrects = np.zeros((len(teachers)), int)
        total_samples = np.zeros((len(teachers)), int)
        accuracy = np.zeros((len(teachers)))
        labels_ = [label for l in labels for label in l]
        for _, [x, y, _] in enumerate(test[0].dataset):
            if int(y) not in labels_:
                continue

            label = torch.tensor(labels_.index(y)).to(device)
            sample = x.view(1, 3, 224, 224).to(device)
            pred = student.predict(sample)
            for idx, task_labels in enumerate(labels):
                if label in task_labels:
                    corrects[idx] = corrects[idx] + int(pred == label)
                    total_samples[idx] = total_samples[idx] + 1

        for idx, _ in enumerate(teachers):
            accuracy[idx] = (corrects[idx] / total_samples[idx]) if total_samples[idx] > 0 else 0

    return student, [data for d in data_idx for data in d], accuracy, b_accuracy
|
|
|
|
|
|
def load_dataset(dataset: str, force_unique_task: bool = False) -> Tuple[NCScenario, NCScenario]:
    """Build the Avalanche benchmark for `dataset` and return its streams.

    Side effect: sets args.n_tasks to the task count for the dataset
    (1 when force_unique_task is True, so the whole dataset is a single
    experience for amalgamation/evaluation).

    Args:
        dataset: one of the names accepted by the --dataset CLI choice.
        force_unique_task: load everything as a single experience.

    Returns:
        (train_stream, test_stream) of the benchmark.

    Raises:
        ValueError: for an unrecognized dataset name (previously this fell
        through and crashed with a NameError on `data`).
    """
    if dataset in ['mnist', 'usps', 'omniglot']:
        # Grayscale sources: replicate to 3 channels for the ResNet backbone.
        transforms = T.Compose([T.Grayscale(3), T.Resize((224, 224)), T.ToTensor()])

        if dataset == 'mnist':
            args.n_tasks = 1 if force_unique_task else 5
            data = SplitMNIST(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 10),
                              train_transform=transforms, eval_transform=transforms)
        elif dataset == 'usps':
            args.n_tasks = 1 if force_unique_task else 5
            usps_train = torchvision.datasets.USPS(root='./data', train=True, download=True)
            usps_test = torchvision.datasets.USPS(root='./data', train=False, download=True)

            data = nc_benchmark(usps_train, usps_test, n_experiences=args.n_tasks, seed=args.seed, task_labels=True,
                                fixed_class_order=range(0, 10), train_transform=transforms, eval_transform=transforms)
        elif dataset == 'omniglot':
            args.n_tasks = 1 if force_unique_task else 241
            data = SplitOmniglot(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 964),
                                 train_transform=transforms, eval_transform=transforms)

    elif dataset in ['cifar10', 'cifar100', 'tinyImageNet10', 'tiny10', 'tinyImageNet20', 'tiny20']:
        # RGB sources: resize only.
        transforms = T.Compose([T.Resize((224, 224)), T.ToTensor()])

        if dataset == 'cifar10':
            args.n_tasks = 1 if force_unique_task else 5
            data = SplitCIFAR10(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 10),
                                train_transform=transforms, eval_transform=transforms)
        elif dataset == 'cifar100':
            args.n_tasks = 1 if force_unique_task else 10
            data = SplitCIFAR100(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 100),
                                 train_transform=transforms, eval_transform=transforms)
        elif dataset in ['tinyImageNet10', 'tiny10']:
            args.n_tasks = 1 if force_unique_task else 10
            data = SplitTinyImageNet(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 200),
                                     train_transform=transforms, eval_transform=transforms)
        elif dataset in ['tinyImageNet20', 'tiny20']:
            args.n_tasks = 1 if force_unique_task else 20
            data = SplitTinyImageNet(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 200),
                                     train_transform=transforms, eval_transform=transforms)

    else:
        # Guard against silent fall-through: `data` would otherwise be unbound.
        raise ValueError(f'Unknown dataset: {dataset}')

    return data.train_stream, data.test_stream
|
|
|
|
|
|
def main(args):
    """Train one base model per task, amalgamate them into a single student,
    and report the continual-learning metrics ACC, BWT and FWT.

    Relies on module-level `device` and the globals defined above.
    """
    path = f'./state'
    state_path = f'{path}/{paper_name}_{args.dataset}'
    is_training_base_model = False

    # Prepare data
    train_stream, test_stream = load_dataset(args.dataset)

    # Training base model
    if args.force_base_retraining is not None and args.force_base_retraining:
        # Remove stale checkpoints so every task is retrained from scratch.
        for i in range(args.n_tasks):
            if exists(f'{state_path}_{i + 1}'):
                remove(f'{state_path}_{i + 1}')
    if exists(path):
        # Retrain if any per-task checkpoint is missing.
        for i in range(args.n_tasks):
            if not exists(f'{state_path}_{i + 1}'):
                is_training_base_model = True
                break
    else:
        is_training_base_model = True

    if is_training_base_model:
        print('Training base model')
        for experience in train_stream:
            # One independent model per task; these become the teachers.
            model = MySimpleModel(len(experience.classes_in_this_experience)).to(device)
            criterion = torch.nn.CrossEntropyLoss().to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

            with tqdm(unit='Epoch', total=args.epochs) as pbar:
                train_loader = torch.utils.data.DataLoader(experience.dataset, batch_size=args.batch_size, shuffle=True)
                while pbar.n < args.epochs:
                    model.train()
                    for _, data in enumerate(train_loader):
                        x = data[0].to(device)
                        # Shift labels into this task's local 0..K-1 range.
                        y = (data[1] - min(experience.classes_in_this_experience)).to(device)
                        optimizer.zero_grad()
                        output = model(x)
                        criterion(output, y).backward()
                        optimizer.step()
                    # Checkpoint after every epoch (see save_load_best_model).
                    save_load_best_model(model, experience, pbar=pbar)
                    pbar.update()

    # Reload the trained teachers and record their individual test accuracy.
    base_models = []
    experiences = []
    print('Base model performance')
    for experience in test_stream:
        model = MySimpleModel(len(experience.classes_in_this_experience)).to(device)
        model, accuracy = save_load_best_model(model, experience, False)
        print(f'Test accuracy for base task {experience.current_experience + 1} {experience.classes_in_this_experience}: {accuracy * 100:.2f}')
        if experience.current_experience == 0:
            accuracy_0 = accuracy
        base_models.append(model.cpu())
        experiences.append(experience)

    args.n_tasks_original = args.n_tasks
    n_tasks = args.n_tasks
    # accuracies[i, j]: accuracy on task i after amalgamation step j;
    # b_accuracies holds the same grid for the untrained (random-init) student.
    accuracies = np.zeros((n_tasks, n_tasks))
    b_accuracies = np.zeros((n_tasks, n_tasks))
    accuracies[0, 0] = accuracy_0
    # Reload the dataset as a single experience for amalgamation/evaluation.
    train_stream, test_stream = load_dataset(args.dataset, True)
    if args.amalgamation_strategy == 'one_at_a_time':
        # Start by merging the first two teachers.
        amalgamated_model, data, accuracy, b_accuracy = amalgamate(teachers=[base_models[0].to(device), base_models[1].to(device)],
                                                                   data_array=[None, None],
                                                                   labels=[experiences[0].classes_seen_so_far, experiences[1].classes_in_this_experience],
                                                                   train=train_stream,
                                                                   test=test_stream,
                                                                   epochs=args.amalgamation_epochs)
        accuracies[0, 1] = accuracy[0]
        accuracies[1, 1] = accuracy[1]
        b_accuracies[0, 1] = b_accuracy[0]
        b_accuracies[1, 1] = b_accuracy[1]
        if n_tasks > 2:
            # NOTE(review): this loop starts at i = 1, which re-amalgamates the
            # pair handled just above — confirm whether range(2, n_tasks) was intended.
            for i in range(1, n_tasks):
                amalgamated_model, data, accuracy, b_accuracy = amalgamate(teachers=[amalgamated_model.to(device), base_models[i].to(device)],
                                                                           data_array=[data, None],
                                                                           labels=[experiences[i-1].classes_seen_so_far, experiences[i].classes_in_this_experience],
                                                                           train=train_stream,
                                                                           test=test_stream,
                                                                           epochs=args.amalgamation_epochs)
                accuracies[i - 1, i] = accuracy[0]
                accuracies[i, i] = accuracy[1]
                b_accuracies[i - 1, i] = b_accuracy[0]
                b_accuracies[i, i] = b_accuracy[1]
    elif args.amalgamation_strategy == 'all_together':
        # Jointly amalgamate the first n_task teachers, for growing n_task.
        for n_task in range(2, n_tasks + 1, 1):
            _, _, accuracy, b_accuracy = amalgamate(teachers=[base_models[idx].to(device) for idx in range(n_task)],
                                                    data_array=[None] * n_task,
                                                    labels=[experiences[idx].classes_in_this_experience for idx in range(n_task)],
                                                    train=train_stream,
                                                    test=test_stream,
                                                    epochs=args.amalgamation_epochs)
            for i in range(len(accuracy)):
                accuracies[i, n_task - 1] = accuracy[i]
                b_accuracies[i, n_task - 1] = b_accuracy[i]

    print(f'accuracies \n {accuracies}')
    print(f'b_accuracies (random initialization) \n {b_accuracies}')

    # ACC: mean over tasks of the final-column accuracies (zeros treated as missing).
    acc = np.nanmean(np.where(accuracies != 0, accuracies, np.nan), 0)[-1]
    print(f'ACC: {acc * 100:.2f}%')

    # BWT: change on earlier tasks after later amalgamation steps.
    bwt = 0
    for i in range(n_tasks - 1):
        j = {'one_at_a_time': i + 1, 'all_together': -1}
        bwt += accuracies[i, j[args.amalgamation_strategy]] - accuracies[i, i]
    bwt = bwt / (n_tasks - 1)
    print(f'BWT: {bwt * 100:.2f}%')

    # FWT: transfer to each new task relative to its random-init baseline.
    fwt = 0
    for i in range(1, n_tasks):
        j = {'one_at_a_time': i, 'all_together': -1}
        fwt += accuracies[i - 1, j[args.amalgamation_strategy]] - b_accuracies[i, i]
    fwt = fwt / (n_tasks - 1)
    print(f'FWT: {fwt * 100:.2f}%')
|
|
|
|
|
|
if __name__ == '__main__':
    args = parser.parse_args()

    # Configure random seed and devices
    if args.seed is None:
        # No seed supplied: draw one so the run is still reproducible if logged.
        args.seed = random.randint(0, 10000)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # `device` is read as a module-level global throughout the file.
    device = torch.device(f'cuda:{args.cuda_device}' if (torch.cuda.is_available() and args.cuda) else 'cpu')
    print(f'Device: {device}')

    main(args)
|