CFA_ECML-PKDD-2022/cfa.py

767 lines
35 KiB
Python
Raw Permalink Normal View History

2022-06-23 23:48:41 +08:00
import argparse
import ssl
from abc import ABC
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
import random
import warnings
from tqdm import tqdm
from typing import List, Tuple
from os import mkdir, remove
from os.path import exists
import matplotlib.pyplot as plt
from avalanche.benchmarks.classic import SplitMNIST, SplitCIFAR10, SplitCIFAR100, SplitTinyImageNet, SplitOmniglot
from avalanche.benchmarks.generators import nc_benchmark
from avalanche.benchmarks.scenarios.new_classes.nc_scenario import NCExperience, NCScenario
paper_name = 'cfa'


def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, which made it impossible to switch
    a flag off from the command line.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(f'./{paper_name}.py',
                                 description='Class-Incremental Learning via Knowledge Amalgamation')
parser.add_argument('--dataset', type=str, default='mnist', help='Dataset to use', required=False,
                    choices=['usps', 'mnist', 'cifar10', 'cifar100', 'tiny10', 'tiny20', 'omniglot'])
parser.add_argument('--seed', type=int, default=None, metavar='N', help='Set a seed to compare runs')
parser.add_argument('--cuda', action='store_true', help='enable CUDA')
parser.add_argument('--cuda_device', type=int, default=0, help='Cuda device identifier')
# Teacher models configuration
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--batch_size', type=int, default=2 ** 6, help='Batch size for base model training',
                    choices=[2 ** 3, 2 ** 4, 2 ** 5, 2 ** 6, 2 ** 7, 2 ** 8])
parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs per task (base learning)',
                    choices=[5, 10, 50, 100, 1000])
# Parsed with _str_to_bool so that "--force_base_retraining False" really disables retraining.
parser.add_argument('--force_base_retraining', type=_str_to_bool, default=True,
                    help='Force base model retraining')
# Amalgamation configuration
parser.add_argument('--cfl_lr', type=float, default=None, help='Common feature amalgamation learning rate')
parser.add_argument('--amalgamation_strategy', type=str, default='all_together', help='Amalgamation Strategy',
                    choices=['all_together', 'one_at_a_time'])
parser.add_argument('--amalgamation_epochs', type=int, default=100, help='Amalgamation epochs',
                    choices=[10, 100, 500, 1000])
# Memory strategy configuration
parser.add_argument('--memory_strategy', type=str, default='grow', help='Memory Strategy',
                    choices=['fixed', 'grow'])
parser.add_argument('--memory_budget', type=int, default=1000, help='Memory Budget',
                    choices=[100, 200, 500, 1000, 2000])
parser.add_argument('--alpha', type=float, default=0.5, help='Alpha')
# Helpers
def enum(**enums):
    """Create an ad-hoc class named ``Enum`` whose class attributes are the given keyword arguments."""
    members = dict(enums)
    return type('Enum', (), members)
# Attribute-style string constants (FIELD.EXT_IDX, FIELD.INT_IDX, ...) built with the `enum` helper.
# NOTE(review): FIELD is not referenced anywhere else in the visible part of this file — confirm
# it is still needed before removing it.
FIELD = enum(EXT_IDX='last_searched_external_idx',
INT_IDX='last_searched_internal_idx',
N_ELEM='n_elem',
CLASSES_LIST='classes_list')
class AverageTracker():
    """Running averages keyed by name.

    For every key the tracker stores the sum of the values passed to
    :meth:`update` and how many times it was updated; :meth:`get` returns
    their ratio.
    """

    # Names of the two slots stored per key. Built exactly like the module-level
    # `enum` helper does (a throw-away class named 'Enum'), but inline so that the
    # class does not depend on definition order in the module.
    FIELD = type('Enum', (), {'VALUE': 'value', 'COUNT': 'count'})

    def __init__(self):
        # key -> {'value': running sum, 'count': number of updates}
        self.book = dict()

    def _item(self, key: str) -> dict:
        """Return the record stored for *key*.

        Raises:
            KeyError: if *key* has never been updated or reset. (An ``assert``
                was used before, which disappears under ``python -O``.)
        """
        try:
            return self.book[key]
        except KeyError:
            raise KeyError(f'no value tracked under key {key!r}') from None

    def reset(self, key: str = None) -> None:
        """Zero the record of *key*, creating it if needed; with ``key=None`` forget every key."""
        if key is None:
            self.book.clear()
            return
        item = self.book.setdefault(key, {})
        item[self.FIELD.VALUE] = 0.
        item[self.FIELD.COUNT] = 0

    def update(self, key: str, val: float) -> None:
        """Add *val* to the running sum of *key* and increment its update count."""
        item = self.book.get(key, None)
        if item is None:
            # Create the record directly rather than through reset(): reset(None)
            # means "clear everything", which used to send update(None, ...) into
            # endless recursion.
            item = self.book[key] = {self.FIELD.VALUE: 0., self.FIELD.COUNT: 0}
        item[self.FIELD.VALUE] += val
        item[self.FIELD.COUNT] += 1

    def get(self, key: str) -> float:
        """Return the mean of the values recorded for *key* (0.0 when it has no updates yet)."""
        item = self._item(key)
        n_updates = item[self.FIELD.COUNT]
        return item[self.FIELD.VALUE] / float(n_updates) if n_updates > 0 else 0.

    def count(self, key: str) -> int:
        """Return how many times *key* has been updated since its last reset."""
        return self._item(key)[self.FIELD.COUNT]
# Code
class CommonFeatureLearningLoss(torch.nn.Module):
    """Loss combining a KL term on common-space features with a reconstruction MSE term.

    ``forward`` returns ``sum_i KL(softmax(ht[i]) || softmax(hs)) + beta * sum_i MSE(ft_[i], ft[i])``,
    using PyTorch's default reductions for both ``kl_div`` and ``mse_loss``.
    """

    def __init__(self, beta: float = 1.0):
        super().__init__()
        # Weight of the reconstruction (MSE) term relative to the KL term.
        self.beta = beta

    def forward(self, hs: torch.Tensor, ht: torch.Tensor, ft_: torch.Tensor, ft: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
            hs: student features; softmax is taken over dim 1.
            ht: iterable of teacher features, one entry per teacher.
            ft_: sequence of reconstructed teacher features.
            ft: sequence of original teacher features, indexed like ``ft_``.
        """
        functional = torch.nn.functional
        # kl_div warns about its default reduction; the warning is silenced on purpose.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore')
            kl_terms = [functional.kl_div(torch.log_softmax(hs, dim=1), torch.softmax(teacher_h, dim=1))
                        for teacher_h in ht]
        kl_loss = sum(kl_terms, 0.0)
        mse_terms = [functional.mse_loss(rebuilt, ft[position]) for position, rebuilt in enumerate(ft_)]
        mse_loss = sum(mse_terms, 0.0)
        return kl_loss + self.beta * mse_loss
class ResidualBlock(torch.nn.Module):
    """Two 3x3 convolutions with a skip connection (no normalisation layers).

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: spatial stride of the block, applied once by ``conv1`` and once
            by the 1x1 projection on the skip path.
    """

    def __init__(self, inplanes: int, planes: int, stride: int = 1):
        super(ResidualBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=inplanes,
                                     out_channels=planes,
                                     kernel_size=(3, 3),
                                     stride=stride,
                                     padding=1,
                                     bias=False)
        self.relu = torch.nn.ReLU(inplace=True)
        # conv2 keeps the resolution. It used to reuse `stride`, which downsampled the
        # main path twice while the skip path was downsampled once, so any stride > 1
        # failed on the addition in forward(). Behaviour for stride == 1 is unchanged.
        self.conv2 = torch.nn.Conv2d(in_channels=planes,
                                     out_channels=planes,
                                     kernel_size=(3, 3),
                                     stride=1,
                                     padding=1,
                                     bias=False)
        # A 1x1 projection is only needed when the skip path must change shape.
        self.downsample = None
        if stride > 1 or inplanes != planes:
            self.downsample = torch.nn.Sequential(torch.nn.Conv2d(in_channels=inplanes,
                                                                  out_channels=planes,
                                                                  kernel_size=(1, 1),
                                                                  stride=stride,
                                                                  bias=False)
                                                  )
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(conv2(relu(conv1(x))) + skip(x))``."""
        residual = x
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        if self.downsample is not None:
            residual = self.downsample(residual)
        x += residual
        x = self.relu(x)
        return x
class CommonFeatureBlocks(torch.nn.Module):
    """Maps student and teacher feature maps into a shared space and decodes teacher features back.

    Sub-modules:
        align_t: one 1x1 conv + ReLU per teacher, teacher channels -> ``2 * n_hidden_channel``.
        align_s: the same kind of adapter for the student.
        extractor: three ``ResidualBlock``s shared by student and teachers
            (``2 * n_hidden_channel`` -> ``n_hidden_channel``).
        dec_t: one decoder per teacher mapping ``n_hidden_channel`` back to that teacher's channels.
    """

    def __init__(self, n_student_channels: int, n_teacher_channels: List[int], n_hidden_channel: int):
        super().__init__()
        hidden = n_hidden_channel
        aligned = 2 * hidden

        def make_aligner(in_channels: int) -> torch.nn.Sequential:
            # 1x1 convolution that brings a network's features to the common width.
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=in_channels, out_channels=aligned, kernel_size=(1, 1), bias=False),
                torch.nn.ReLU(inplace=True),
            )

        def make_decoder(out_channels: int) -> torch.nn.Sequential:
            # Rebuilds one teacher's features from the common representation.
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=hidden, out_channels=out_channels, kernel_size=(3, 3),
                                stride=1, padding=1, bias=False),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, 1),
                                stride=1, padding=0, bias=False),
            )

        # Modules are created in the order align_t, align_s, extractor, dec_t.
        self.align_t = torch.nn.ModuleList([make_aligner(channels) for channels in n_teacher_channels])
        self.align_s = make_aligner(n_student_channels)
        self.extractor = torch.nn.Sequential(
            ResidualBlock(inplanes=aligned, planes=hidden, stride=1),
            ResidualBlock(inplanes=hidden, planes=hidden, stride=1),
            ResidualBlock(inplanes=hidden, planes=hidden, stride=1),
        )
        self.dec_t = torch.nn.ModuleList([make_decoder(channels) for channels in n_teacher_channels])

    def forward(self, fs: torch.Tensor, ft: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(hs, ht, ft_)``.

        Args:
            fs: student feature map.
            ft: iterable of teacher feature maps, paired positionally with ``align_t``
                and ``dec_t`` (``zip`` silently ignores surplus entries on either side).

        Returns:
            hs: student features in the common space.
            ht: list with each teacher's features in the common space.
            ft_: list with each teacher's features reconstructed from ``ht``.
        """
        aligned_teachers = []
        for aligner, teacher_features in zip(self.align_t, ft):
            aligned_teachers.append(aligner(teacher_features))
        aligned_student = self.align_s(fs)

        ht = []
        for aligned_features in aligned_teachers:
            ht.append(self.extractor(aligned_features))
        hs = self.extractor(aligned_student)

        ft_ = []
        for decoder, common in zip(self.dec_t, ht):
            ft_.append(decoder(common))
        return hs, ht, ft_
class MySimpleModel(torch.nn.Module):
    """ResNet-18 backbone (ImageNet weights) followed by a small two-layer classification head.

    The backbone's own ``fc`` layer is replaced by an empty ``Sequential`` (a pass-through)
    and kept as ``resnet.fc_backup``. Forward hooks, once registered, store the most recent
    output of ``resnet.layer4`` and of the head on those modules as ``.output``.
    """

    def __init__(self, n_output: int):
        super().__init__()
        # name -> hook handle, filled by register_hooks()
        self.handles = {}
        backbone = torchvision.models.resnet18(pretrained=True)
        original_head = backbone.fc
        backbone.fc_backup = original_head
        backbone.fc = torch.nn.Sequential()
        self.resnet = backbone
        width = original_head.in_features
        self.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(width, width // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(width // 2, n_output),
        )

    @property
    def features(self) -> torch.Tensor:
        """Output of ``resnet.layer4`` captured by the hook during the last forward pass."""
        assert self.handles is not None
        return self.resnet.layer4.output

    @property
    def feature_dimension(self) -> int:
        """Number of channels produced by the last convolution of ``resnet.layer4``."""
        last_block = self.resnet.layer4[-1]
        return last_block.conv2.out_channels

    @property
    def soft_output(self) -> torch.Tensor:
        """Output of the head (pre-softmax) captured by the hook during the last forward pass."""
        return self.fc.output

    @property
    def n_output(self) -> int:
        """Number of outputs of the head's final linear layer."""
        final_layer = self.fc[-1]
        return final_layer.out_features

    def register_hooks(self) -> None:
        """Attach forward hooks that stash each hooked module's output on the module itself."""
        def forward_hook(module: torch.nn.modules.container.Sequential, _: tuple, output: torch.Tensor):
            module.output = output

        hooked = {'conv_layer': self.resnet.layer4, 'fc_layer': self.fc}
        for name, module in hooked.items():
            self.handles[name] = module.register_forward_hook(forward_hook)

    def remove_hooks(self) -> None:
        """Detach every registered hook (the handles themselves stay in ``self.handles``)."""
        assert self.handles is not None
        for handle in self.handles.values():
            handle.remove()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's raw scores for the batch ``x``."""
        return self.fc(self.resnet(x))

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return the index of the most probable output for every sample in ``x``."""
        scores = self.forward(x)
        return torch.softmax(scores, dim=1).argmax(1)
def save_load_best_model(model: MySimpleModel, experience: NCExperience, is_train=True, pbar=None) -> Tuple[MySimpleModel, float]:
    """Evaluate ``model`` on ``experience`` and save (training) or load (evaluation) its checkpoint.

    The checkpoint lives at ``./state/cfa_<dataset>_<task number>``. Relies on the
    module-level ``args`` and ``device``.

    * ``is_train=True``: measures accuracy on ``experience.dataset``, reports it on ``pbar``
      (or prints it), overwrites the checkpoint with the current weights and puts the
      model back in training mode.
      NOTE(review): despite the function name the checkpoint is overwritten on every call,
      not only when accuracy improves.
    * ``is_train=False``: loads the checkpoint first (creating it from the current weights
      if it is missing), then measures accuracy; the model is left in eval mode.

    Labels are shifted by the smallest class id of the experience so that they index the
    model's outputs.

    Returns:
        ``(model, accuracy)`` with accuracy in ``[0, 1]``.
    """
    n_task = experience.current_experience
    path = './state'
    state_path = f'{path}/cfa_{args.dataset}_{n_task + 1}'
    if not exists(path):
        mkdir(path)
    if not exists(state_path):
        torch.save(model.state_dict(), state_path)

    with torch.no_grad():
        model.eval()
        corrects = 0
        total_task = 0
        if not is_train:
            # map_location lets a checkpoint written on a GPU be read on a CPU-only machine.
            # The weights are loaded once; nothing below modifies them.
            model.load_state_dict(torch.load(state_path, map_location=device))
        # shuffle=True is kept on purpose: it draws from the global RNG, so changing it
        # would alter the random stream seen by later training.
        data_loader = torch.utils.data.DataLoader(experience.dataset, batch_size=args.batch_size, shuffle=True)
        label_offset = min(experience.classes_in_this_experience)
        for data in data_loader:
            x = data[0].to(device)
            y = (data[1] - label_offset).to(device)
            total_task += len(x)
            corrects += int((model.predict(x) == y).sum())
        accuracy = (corrects / total_task) if total_task > 0 else 0

    if is_train:
        description = f'Train accuracy for base task {n_task + 1}: {accuracy * 100:.2f}% ({corrects}/{total_task})'
        if pbar is None:
            print(description)
        else:
            pbar.set_description_str(description)
        torch.save(model.state_dict(), state_path)
        model.train()
    return model, accuracy
def amalgamate(teachers: List[MySimpleModel], data_array: List = None, labels: List = None,
               train: NCScenario = None, test: NCScenario = None,
               epochs: int = 100) -> Tuple[MySimpleModel, List[int], np.ndarray, np.ndarray]:
    """Distil several teacher models into one newly created student.

    The student has as many outputs as all teachers together. It is trained on a replay
    memory of exemplars with two losses: a KL term between the student's outputs and the
    concatenated teacher outputs, and the common-feature loss computed through a
    ``CommonFeatureBlocks`` module. Relies on the module-level ``args`` and ``device``.

    Args:
        teachers: trained models; hooks are registered on them for the duration of the
            call and they are switched to eval mode.
        data_array: one entry per teacher; each entry is either ``None`` or a list of
            dataset indices from a previous call. The list is modified in place: every
            entry is replaced by the stacked exemplar tensors selected for that teacher.
        labels: one list of class ids per teacher.
        train: training stream; only ``train[0].dataset`` is used.
        test: test stream; only ``test[0].dataset`` is used.
        epochs: number of passes over the replay memory.

    Returns:
        ``(student, memory_indices, accuracy, b_accuracy)`` where ``memory_indices`` is the
        flat list of selected dataset indices and ``accuracy`` / ``b_accuracy`` hold the
        student's per-teacher-group test accuracy after / before training.
    """
    # Fresh lists on every call instead of shared mutable default arguments.
    data_array = [] if data_array is None else data_array
    labels = [] if labels is None else labels

    def memory_keys(all_data: NCScenario, teachers: List[MySimpleModel], labels: List[int], idx: int, previous_data_idxs: List[int] = None) -> List[int]:
        # Per-class exemplar quota: 'grow' divides the budget by the number of classes of
        # teacher `idx`, 'fixed' divides it by the number of classes of all teachers.
        if args.memory_strategy == 'grow':
            n_elem = args.memory_budget // len(labels[idx])
        elif args.memory_strategy == 'fixed':
            n_elem = args.memory_budget // sum(len(v) for v in labels)
        else:
            raise ValueError(f'Unknown memory strategy: {args.memory_strategy!r}')
        return get_mean_exemplar_keys(all_data, teachers[idx], labels[idx], n_elem, previous_data_idxs)

    def get_mean_exemplar_keys(all_data: NCScenario, teacher: MySimpleModel, labels: List[int], n_elem_per_class: int, previous_data_idxs: List[int] = None) -> List[int]:
        # Pass 1 accumulates a mean feature per class; pass 2 ranks every sample of the class by
        # its distance to that mean and keeps the `n_elem_per_class` closest dataset indices.
        #
        # NOTE(review): the batching below is reproduced unchanged because fixing it alters which
        # exemplars are selected (and therefore reported results). Known quirks to confirm:
        #   * a batch is only processed once it is full, so the trailing partial batch of a
        #     class is never used, and the sample that triggers processing is dropped;
        #   * `current_batch_size` is not reset between classes, so leftovers of one class can
        #     be counted for the next one;
        #   * `sum(teacher.features, 1)` is Python's builtin sum with start value 1 (it adds 1 to
        #     the summed features) — presumably `torch.sum(teacher.features, 0)` was intended;
        #   * `n_samples` stays 0 for a class with fewer than `batch_size` usable samples, which
        #     makes the division below produce inf/nan.
        batch_sample = torch.empty((args.batch_size, 3, 224, 224))
        current_batch_size = 0
        n_samples = 0
        label_mean = {}
        with torch.no_grad():
            for label in labels:
                # Dummy forward pass so that `teacher.features` exists and gives the shape of the mean.
                teacher(torch.rand((1, 3, 224, 224)).to(device))
                label_mean[label] = torch.zeros_like(teacher.features)
                if previous_data_idxs is None:
                    for _, [x, y, _] in enumerate(all_data[0].dataset):
                        if x is None:
                            break
                        if current_batch_size < args.batch_size and x is not None:
                            if y == label:
                                batch_sample[current_batch_size] = x
                                current_batch_size += 1
                        elif x is None and current_batch_size == 0:
                            break
                        elif current_batch_size > 0 or x is None:
                            batch_sample = batch_sample[:current_batch_size].to(device)
                            teacher(batch_sample)
                            label_mean[label] += sum(teacher.features, 1).unsqueeze(0)
                            n_samples += current_batch_size
                            current_batch_size = 0
                else:
                    for _, idx in enumerate(previous_data_idxs):
                        x, y, _ = all_data[0].dataset[idx]
                        if current_batch_size < args.batch_size:
                            if y == label:
                                batch_sample[current_batch_size] = x
                                current_batch_size += 1
                        elif current_batch_size == 0:
                            break
                        elif current_batch_size > 0:
                            batch_sample = batch_sample[:current_batch_size].to(device)
                            teacher(batch_sample)
                            label_mean[label] += sum(teacher.features, 1).unsqueeze(0)
                            n_samples += current_batch_size
                            current_batch_size = 0
                label_mean[label] /= n_samples
                n_samples = 0
        batch_sample = torch.empty((args.batch_size, 3, 224, 224))
        batch_idx = np.empty(args.batch_size, dtype=int)
        current_batch_index = 0
        label_idx_distance = {}
        label_idx = {}
        with torch.no_grad():
            for label in labels:
                label_idx_distance[label] = {}
                label_idx[label] = []
                for idx, [x, y, _] in enumerate(all_data[0].dataset):
                    if x is None:
                        break
                    if current_batch_index < args.batch_size and x is not None:
                        if y == label:
                            batch_sample[current_batch_index] = x
                            batch_idx[current_batch_index] = idx
                            # Samples whose batch is never processed keep an infinite distance.
                            label_idx_distance[label][idx] = np.inf
                            current_batch_index += 1
                    elif x is None and current_batch_index == 0:
                        break
                    elif current_batch_index > 0 or x is None:
                        batch_sample = batch_sample[:current_batch_index].to(device)
                        teacher(batch_sample)
                        for idx, elem in enumerate(batch_idx):
                            label_idx_distance[label][elem] = float(torch.dist(teacher.features[idx], label_mean[label], 2))
                        current_batch_index = 0
                label_idx[label] = [idx for idx, _ in sorted(label_idx_distance[label].items(), key=lambda x: x[1])][:n_elem_per_class]
        return np.concatenate([v for k, v in label_idx.items()], 0)

    def get_conf_keys(all_data: NCScenario, teacher: MySimpleModel, labels: List[int], n_elem: int = args.memory_budget) -> List[int]:
        # Alternative selection: per class, keep the samples with the largest margin between the
        # teacher's two highest softmax scores.
        # NOTE(review): not called anywhere in the visible code; kept unchanged.
        batch_sample = torch.empty((args.batch_size, 3, 224, 224))
        batch_idx = np.empty(args.batch_size, dtype=int)
        current_batch_size = 0
        conf = {}
        for label in labels:
            conf[label] = {}
            for idx, [x, y, _] in enumerate(all_data[0].dataset):
                if x is None:
                    break
                if current_batch_size < args.batch_size and x is not None:
                    if y == label:
                        batch_sample[current_batch_size] = x
                        batch_idx[current_batch_size] = idx
                        conf[label][idx] = 0
                        current_batch_size += 1
                elif x is None and current_batch_size == 0:
                    break
                elif current_batch_size > 0 or x is None:
                    batch_sample = batch_sample[:current_batch_size].to(device)
                    batch_idx = batch_idx[:current_batch_size]
                    soft_top_2 = torch.softmax(teacher(batch_sample), 1).topk(2)[0].tolist()
                    for i, j in enumerate(batch_idx):
                        conf[label][j] = soft_top_2[i][0] - soft_top_2[i][1]
                    batch_sample = torch.empty((args.batch_size, 3, 224, 224))
                    batch_idx = np.empty(args.batch_size, dtype=int)
                    current_batch_size = 0
        idxs = []
        for label in conf.keys():
            idxs = idxs + list(dict(sorted(conf[label].items(),
                                           key=lambda x: x[1],
                                           reverse=True)).keys())[:(n_elem // len(labels))]
        return idxs[:n_elem]

    def evaluate_student() -> np.ndarray:
        # Test accuracy of the current student, one value per teacher's group of classes.
        # A sample's target is its position in the flattened label list, i.e. the student's
        # output index under the teachers' concatenation order.
        corrects = np.zeros((len(teachers)), int)
        total_samples = np.zeros((len(teachers)), int)
        group_accuracy = np.zeros((len(teachers)))
        labels_ = [label for l in labels for label in l]
        with torch.no_grad():
            for _, [x, y, _] in enumerate(test[0].dataset):
                if int(y) not in labels_:
                    continue
                label = torch.tensor(labels_.index(y)).to(device)
                sample = x.view(1, 3, 224, 224).to(device)
                pred = student.predict(sample)
                for idx, task_labels in enumerate(labels):
                    if label in task_labels:
                        corrects[idx] = corrects[idx] + int(pred == label)
                        total_samples[idx] = total_samples[idx] + 1
        for idx, _ in enumerate(teachers):
            group_accuracy[idx] = (corrects[idx] / total_samples[idx]) if total_samples[idx] > 0 else 0
        return group_accuracy

    student = MySimpleModel(sum([teacher.n_output for teacher in teachers])).to(device)
    # One aligner/decoder per teacher. Only teachers[0] and teachers[1] used to be listed here,
    # so with more than two teachers the remaining ones were silently left out of the
    # common-feature loss (zip() in CommonFeatureBlocks.forward truncates to the shorter list).
    cfl_blk = CommonFeatureBlocks(student.feature_dimension,
                                  [teacher.feature_dimension for teacher in teachers],
                                  int(sum([teacher.feature_dimension for teacher in teachers]) / len(teachers))).to(device)
    cfl_lr = args.lr * 10 if args.cfl_lr is None else args.cfl_lr
    # The head ('fc' in the parameter name) trains with a 10x larger learning rate than the rest.
    params_10x = [param for name, param in student.named_parameters() if 'fc' in name]
    params_1x = [param for name, param in student.named_parameters() if 'fc' not in name]
    optimizer = torch.optim.Adam([{'params': params_1x, 'lr': args.lr},
                                  {'params': params_10x, 'lr': args.lr * 10},
                                  {'params': cfl_blk.parameters(), 'lr': cfl_lr}])
    student.train()
    for teacher in teachers:
        teacher.register_hooks()
    for teacher in teachers:
        teacher.eval()
    student.register_hooks()
    average_tracker = AverageTracker()
    common_feature_learning_criterion = CommonFeatureLearningLoss().to(device)

    print('Adjusting replay memory - sorry the delay, this part of the code is not optimized')
    data_idx = []
    for idx, data in enumerate(data_array):
        data_idx.append(memory_keys(train, teachers, labels, idx, data))
        data_array[idx] = torch.stack([train[0].dataset[idx_][0] for idx_ in data_idx[idx]])
    print('Replay memory adjusted')
    all_data = torch.cat([data for data in data_array])
    p = torch.randperm(len(all_data))
    all_data = all_data[p]

    # Accuracy of the untrained student, reported by the caller as the baseline.
    student.eval()
    b_accuracy = evaluate_student()
    student.train()

    with tqdm(unit='Epoch', total=epochs) as pbar:
        while pbar.n < epochs:
            average_tracker.reset()
            batch_sample = torch.empty((args.batch_size, 3, 224, 224))
            current_batch_index = 0
            # NOTE(review): as in the exemplar selection, a step only runs once the buffer is
            # full; the sample that triggers the step and the trailing partial batch are skipped.
            for _, data in enumerate(all_data):
                if current_batch_index < args.batch_size and data is not None:
                    batch_sample[current_batch_index] = data
                    current_batch_index += 1
                elif data is None and current_batch_index == 0:
                    break
                elif current_batch_index > 0 or data is None:
                    batch_sample = batch_sample[:current_batch_index].to(device)
                    current_batch_index = 0
                    optimizer.zero_grad()
                    with torch.no_grad():
                        for teacher in teachers:
                            teacher(batch_sample)
                        teacher_soft = torch.cat(tuple([teacher.soft_output for teacher in teachers]), dim=1)
                    student(batch_sample)
                    batch_sample = torch.empty((args.batch_size, 3, 224, 224))
                    student_soft = student.soft_output
                    # kl_div warns about its default reduction; silenced on purpose.
                    with warnings.catch_warnings():
                        warnings.filterwarnings('ignore')
                        cross_entropy_loss = torch.nn.functional.kl_div(torch.log_softmax(student_soft, dim=1),
                                                                        torch.softmax(teacher_soft, dim=1))
                    hs, ht, ft_ = cfl_blk(student.features, [teacher.features for teacher in teachers])
                    common_features_loss = 10 * common_feature_learning_criterion(hs, ht, ft_, [teacher.features for teacher in teachers])
                    loss = args.alpha * cross_entropy_loss + (1 - args.alpha) * common_features_loss
                    loss.backward()
                    optimizer.step()
                    average_tracker.update('loss', loss.item())
                    average_tracker.update('ce', cross_entropy_loss.item())
                    average_tracker.update('cf', common_features_loss.item())
                    description = f'Amalgamating ' \
                                  f'Loss={average_tracker.get("loss"):.2f} '\
                                  f'(cross entropy={average_tracker.get("ce"):.2f}, '\
                                  f'common features={average_tracker.get("cf"):.2f})'
                    pbar.set_description_str(description)
                    pbar.refresh()
            pbar.update()
            # Reshuffle the replay memory for the next epoch.
            all_data = torch.cat([data for data in data_array])
            p = torch.randperm(len(all_data))
            all_data = all_data[p]

    for teacher in teachers:
        teacher.remove_hooks()
    student.remove_hooks()
    student.eval()
    accuracy = evaluate_student()
    return student, [data for d in data_idx for data in d], accuracy, b_accuracy
def load_dataset(dataset: str, force_unique_task: bool = False) -> Tuple[NCScenario, NCScenario]:
    """Build the class-incremental benchmark for ``dataset`` and return its train and test streams.

    Every image is resized to 224x224 and converted to a tensor; for ``mnist``, ``usps`` and
    ``omniglot`` a ``Grayscale(3)`` transform is applied first. Classes are kept in
    ascending order.

    Side effect: stores the number of experiences in the module-level ``args.n_tasks``
    (1 when ``force_unique_task`` is set, the dataset's task count otherwise).

    Args:
        dataset: one of ``mnist``, ``usps``, ``omniglot``, ``cifar10``, ``cifar100``,
            ``tiny10``/``tinyImageNet10`` or ``tiny20``/``tinyImageNet20``.
        force_unique_task: put all classes in a single experience.

    Raises:
        ValueError: for an unknown dataset name (this used to surface as an
            ``UnboundLocalError`` at the return statement).
    """
    # dataset name -> (number of tasks of the incremental split, number of classes)
    splits = {
        'mnist': (5, 10),
        'usps': (5, 10),
        'omniglot': (241, 964),
        'cifar10': (5, 10),
        'cifar100': (10, 100),
        'tinyImageNet10': (10, 200),
        'tiny10': (10, 200),
        'tinyImageNet20': (20, 200),
        'tiny20': (20, 200),
    }
    if dataset not in splits:
        raise ValueError(f'Unsupported dataset: {dataset!r}')
    n_tasks, n_classes = splits[dataset]

    steps = [T.Resize((224, 224)), T.ToTensor()]
    if dataset in ('mnist', 'usps', 'omniglot'):
        steps.insert(0, T.Grayscale(3))
    transforms = T.Compose(steps)

    args.n_tasks = 1 if force_unique_task else n_tasks
    common = dict(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, n_classes),
                  train_transform=transforms, eval_transform=transforms)

    if dataset == 'usps':
        # The benchmark for USPS is assembled from the torchvision datasets
        # (downloaded into ./data on first use).
        usps_train = torchvision.datasets.USPS(root='./data', train=True, download=True)
        usps_test = torchvision.datasets.USPS(root='./data', train=False, download=True)
        data = nc_benchmark(usps_train, usps_test, task_labels=True, **common)
    else:
        factories = {'mnist': SplitMNIST,
                     'omniglot': SplitOmniglot,
                     'cifar10': SplitCIFAR10,
                     'cifar100': SplitCIFAR100}
        # Every remaining name is one of the Tiny ImageNet aliases.
        data = factories.get(dataset, SplitTinyImageNet)(**common)
    return data.train_stream, data.test_stream
def main(args):
    """Run the whole experiment: train one base model per task, amalgamate them, print metrics.

    Steps:
        1. Train a base model per experience, unless every checkpoint
           ``./state/<paper_name>_<dataset>_<k>`` already exists and retraining is not forced.
        2. Reload every base model and print its test accuracy.
        3. Amalgamate the base models following ``args.amalgamation_strategy``, using streams
           that hold all classes in a single experience.
        4. Print the accuracy matrices and the ACC / BWT / FWT summaries.

    Relies on the module-level ``device``.
    """
    path = './state'
    state_path = f'{path}/{paper_name}_{args.dataset}'
    # Prepare data (also sets args.n_tasks)
    train_stream, test_stream = load_dataset(args.dataset)
    checkpoints = [f'{state_path}_{i + 1}' for i in range(args.n_tasks)]

    # Training base model
    if args.force_base_retraining is not None and args.force_base_retraining:
        for checkpoint in checkpoints:
            if exists(checkpoint):
                remove(checkpoint)
    # Train when the state directory or any per-task checkpoint is missing.
    is_training_base_model = (not exists(path)) or any(not exists(checkpoint) for checkpoint in checkpoints)

    if is_training_base_model:
        print('Training base model')
        for experience in train_stream:
            model = MySimpleModel(len(experience.classes_in_this_experience)).to(device)
            criterion = torch.nn.CrossEntropyLoss().to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
            # Labels are shifted so that they start at 0 for every task.
            label_offset = min(experience.classes_in_this_experience)
            with tqdm(unit='Epoch', total=args.epochs) as pbar:
                train_loader = torch.utils.data.DataLoader(experience.dataset, batch_size=args.batch_size, shuffle=True)
                while pbar.n < args.epochs:
                    model.train()
                    for _, data in enumerate(train_loader):
                        x = data[0].to(device)
                        y = (data[1] - label_offset).to(device)
                        optimizer.zero_grad()
                        output = model(x)
                        criterion(output, y).backward()
                        optimizer.step()
                    # Once per epoch: report train accuracy and write the checkpoint.
                    save_load_best_model(model, experience, pbar=pbar)
                    pbar.update()

    base_models = []
    experiences = []
    print('Base model performance')
    for experience in test_stream:
        model = MySimpleModel(len(experience.classes_in_this_experience)).to(device)
        model, accuracy = save_load_best_model(model, experience, False)
        print(f'Test accuracy for base task {experience.current_experience + 1} {experience.classes_in_this_experience}: {accuracy * 100:.2f}')
        if experience.current_experience == 0:
            accuracy_0 = accuracy
        base_models.append(model.cpu())
        experiences.append(experience)

    args.n_tasks_original = args.n_tasks
    n_tasks = args.n_tasks
    # accuracies[i, j]: accuracy on task group i after amalgamation step j;
    # b_accuracies: the same, measured on the untrained student.
    accuracies = np.zeros((n_tasks, n_tasks))
    b_accuracies = np.zeros((n_tasks, n_tasks))
    accuracies[0, 0] = accuracy_0
    # Streams with every class in one experience; this call overwrites args.n_tasks,
    # hence the local `n_tasks` copy above.
    train_stream, test_stream = load_dataset(args.dataset, True)

    if args.amalgamation_strategy == 'one_at_a_time':
        amalgamated_model, data, accuracy, b_accuracy = amalgamate(teachers=[base_models[0].to(device), base_models[1].to(device)],
                                                                   data_array=[None, None],
                                                                   labels=[experiences[0].classes_seen_so_far, experiences[1].classes_in_this_experience],
                                                                   train=train_stream,
                                                                   test=test_stream,
                                                                   epochs=args.amalgamation_epochs)
        accuracies[0, 1] = accuracy[0]
        accuracies[1, 1] = accuracy[1]
        b_accuracies[0, 1] = b_accuracy[0]
        b_accuracies[1, 1] = b_accuracy[1]
        # Tasks 0 and 1 are already merged, so continue with task 2. The loop used to start
        # at 1, merging task 1 a second time with a label list ([classes of task 0],
        # [classes of task 1]) that no longer matched the output width of the amalgamated
        # teacher, which misaligned every later student output.
        for i in range(2, n_tasks):
            amalgamated_model, data, accuracy, b_accuracy = amalgamate(teachers=[amalgamated_model.to(device), base_models[i].to(device)],
                                                                       data_array=[data, None],
                                                                       labels=[experiences[i - 1].classes_seen_so_far, experiences[i].classes_in_this_experience],
                                                                       train=train_stream,
                                                                       test=test_stream,
                                                                       epochs=args.amalgamation_epochs)
            accuracies[i - 1, i] = accuracy[0]
            accuracies[i, i] = accuracy[1]
            b_accuracies[i - 1, i] = b_accuracy[0]
            b_accuracies[i, i] = b_accuracy[1]
    elif args.amalgamation_strategy == 'all_together':
        # Re-amalgamate from scratch with the first 2, 3, ..., n_tasks base models.
        for n_task in range(2, n_tasks + 1, 1):
            _, _, accuracy, b_accuracy = amalgamate(teachers=[base_models[idx].to(device) for idx in range(n_task)],
                                                    data_array=[None] * n_task,
                                                    labels=[experiences[idx].classes_in_this_experience for idx in range(n_task)],
                                                    train=train_stream,
                                                    test=test_stream,
                                                    epochs=args.amalgamation_epochs)
            for i in range(len(accuracy)):
                accuracies[i, n_task - 1] = accuracy[i]
                b_accuracies[i, n_task - 1] = b_accuracy[i]

    print(f'accuracies \n {accuracies}')
    print(f'b_accuracies (random initialization) \n {b_accuracies}')
    # Mean of the non-zero entries of the last column.
    acc = np.nanmean(np.where(accuracies != 0, accuracies, np.nan), 0)[-1]
    print(f'ACC: {acc * 100:.2f}%')

    # Averages run over n_tasks - 1 terms; guard the single-task case against division by zero.
    n_pairs = max(n_tasks - 1, 1)
    bwt = 0
    for i in range(n_tasks - 1):
        j = {'one_at_a_time': i + 1, 'all_together': -1}
        bwt += accuracies[i, j[args.amalgamation_strategy]] - accuracies[i, i]
    bwt = bwt / n_pairs
    print(f'BWT: {bwt * 100:.2f}%')
    fwt = 0
    for i in range(1, n_tasks):
        j = {'one_at_a_time': i, 'all_together': -1}
        fwt += accuracies[i - 1, j[args.amalgamation_strategy]] - b_accuracies[i, i]
    fwt = fwt / n_pairs
    print(f'FWT: {fwt * 100:.2f}%')
if __name__ == '__main__':
    # `args` and `device` are module-level globals read by the functions of this script.
    args = parser.parse_args()

    # Configure random seed and devices: draw a seed when none was given, then seed
    # torch, random and numpy with it.
    if args.seed is None:
        args.seed = random.randint(0, 10000)
    for seed_library in (torch.manual_seed, random.seed, np.random.seed):
        seed_library(args.seed)

    # CUDA is used only when it is both available and requested with --cuda.
    use_cuda = torch.cuda.is_available() and args.cuda
    device_name = f'cuda:{args.cuda_device}' if use_cuda else 'cpu'
    device = torch.device(device_name)
    print(f'Device: {device}')

    main(args)