"""Class-Incremental Learning via Knowledge Amalgamation (CFA).

One teacher model is trained per task; the teachers are then amalgamated
into a single student through common-feature learning plus soft-label
distillation, with a small replay memory of mean-of-feature exemplars.
"""
import argparse
import ssl  # noqa: F401  -- kept: presumably for HTTPS dataset downloads; verify before removing
from abc import ABC  # noqa: F401

import numpy as np
import torch
import torchvision
import torchvision.transforms as T
import random
import warnings
from tqdm import tqdm
from typing import List, Tuple
from os import mkdir, remove
from os.path import exists
import matplotlib.pyplot as plt  # noqa: F401

from avalanche.benchmarks.classic import SplitMNIST, SplitCIFAR10, SplitCIFAR100, SplitTinyImageNet, SplitOmniglot
from avalanche.benchmarks.generators import nc_benchmark
from avalanche.benchmarks.scenarios.new_classes.nc_scenario import NCExperience, NCScenario

paper_name = 'cfa'


def _str2bool(value: str) -> bool:
    """Parse a command-line boolean.

    Fixes the classic ``type=bool`` pitfall: ``bool('False')`` is True, so
    ``--force_base_retraining False`` used to be silently ignored.
    """
    return str(value).lower() in ('true', '1', 'yes', 'y')


parser = argparse.ArgumentParser(f'./{paper_name}.py',
                                 description='Class-Incremental Learning via Knowledge Amalgamation')
parser.add_argument('--dataset', type=str, default='mnist', help='Dataset to use', required=False,
                    choices=['usps', 'mnist', 'cifar10', 'cifar100', 'tiny10', 'tiny20', 'omniglot'])
parser.add_argument('--seed', type=int, default=None, metavar='N', help='Set a seed to compare runs')
parser.add_argument('--cuda', action='store_true', help='enable CUDA')
parser.add_argument('--cuda_device', type=int, default=0, help='Cuda device identifier')

# Teacher models configuration
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--batch_size', type=int, default=2 ** 6, help='Batch size for base model training',
                    choices=[2 ** 3, 2 ** 4, 2 ** 5, 2 ** 6, 2 ** 7, 2 ** 8])
parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs per task (base learning)',
                    choices=[5, 10, 50, 100, 1000])
parser.add_argument('--force_base_retraining', type=_str2bool, default=True, help='Force base model retraining')

# Amalgamation configuration
parser.add_argument('--cfl_lr', type=float, default=None, help='Common feature amalgamation learning rate')
parser.add_argument('--amalgamation_strategy', type=str, default='all_together', help='Amalgamation Strategy',
                    choices=['all_together', 'one_at_a_time'])
parser.add_argument('--amalgamation_epochs', type=int, default=100, help='Amalgamation epochs',
                    choices=[10, 100, 500, 1000])

# Memory strategy configuration
parser.add_argument('--memory_strategy', type=str, default='grow', help='Memory Strategy',
                    choices=['fixed', 'grow'])
parser.add_argument('--memory_budget', type=int, default=1000, help='Memory Budget',
                    choices=[100, 200, 500, 1000, 2000])
parser.add_argument('--alpha', type=float, default=0.5, help='Alpha')


# Helpers
def enum(**enums):
    """Build a lightweight enum-like class from keyword arguments."""
    return type('Enum', (), enums)


FIELD = enum(EXT_IDX='last_searched_external_idx',
             INT_IDX='last_searched_internal_idx',
             N_ELEM='n_elem',
             CLASSES_LIST='classes_list')


class AverageTracker:
    """Bookkeeping helper that tracks running averages keyed by name."""

    FIELD = enum(VALUE='value',
                 COUNT='count')

    def __init__(self):
        # key -> {VALUE: accumulated sum, COUNT: number of updates}
        self.book = dict()

    def reset(self, key: str = None) -> None:
        """Reset a single key's accumulator, or drop all keys when key is None."""
        if key is None:
            self.book.clear()
        else:
            self.book[key] = {self.FIELD.VALUE: 0.,
                              self.FIELD.COUNT: 0}

    def update(self, key: str, val) -> None:
        """Add one observation for `key`, creating the accumulator on first use."""
        if key not in self.book:
            self.reset(key)
        item = self.book[key]
        item[self.FIELD.VALUE] += val
        item[self.FIELD.COUNT] += 1

    def get(self, key: str) -> float:
        """Return the running average for `key` (0 if nothing was recorded)."""
        item = self.book.get(key, None)
        assert item is not None
        count = float(item[self.FIELD.COUNT])
        return item[self.FIELD.VALUE] / count if count > 0. else 0.

    def count(self, key: str) -> float:
        """Return how many observations were recorded for `key`."""
        item = self.book.get(key, None)
        assert item is not None
        return item[self.FIELD.COUNT]


# Code
class CommonFeatureLearningLoss(torch.nn.Module):
    """KL divergence between student/teacher common features plus a
    reconstruction MSE term, weighted by `beta`."""

    def __init__(self, beta: float = 1.0):
        super(CommonFeatureLearningLoss, self).__init__()
        self.beta = beta

    def forward(self, hs: torch.Tensor, ht: torch.Tensor, ft_: torch.Tensor, ft: torch.Tensor) -> torch.Tensor:
        """hs/ht: student/teacher common-space features; ft_/ft: decoded vs.
        original teacher features."""
        kl_loss = 0.0
        mse_loss = 0.0
        for ht_i in ht:
            # Default reduction triggers a UserWarning; kept (suppressed) so
            # the loss magnitude matches the original implementation.
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore')
                kl_loss += torch.nn.functional.kl_div(torch.log_softmax(hs, dim=1), torch.softmax(ht_i, dim=1))
        for decoded, original in zip(ft_, ft):
            mse_loss += torch.nn.functional.mse_loss(decoded, original)

        return kl_loss + self.beta * mse_loss


class ResidualBlock(torch.nn.Module):
    """Basic two-convolution residual block with an optional 1x1 projection
    shortcut when the shape changes."""

    def __init__(self, inplanes: int, planes: int, stride: int = 1):
        super(ResidualBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=inplanes,
                                     out_channels=planes,
                                     kernel_size=(3, 3),
                                     stride=stride,
                                     padding=1,
                                     bias=False)

        self.relu = torch.nn.ReLU(inplace=True)

        # FIX: the second conv must keep stride 1. The original applied
        # `stride` twice on the main path but only once on the shortcut,
        # so the residual add broke for any stride > 1.
        self.conv2 = torch.nn.Conv2d(in_channels=planes,
                                     out_channels=planes,
                                     kernel_size=(3, 3),
                                     stride=1,
                                     padding=1,
                                     bias=False)
        self.downsample = None
        if stride > 1 or inplanes != planes:
            # 1x1 projection so the shortcut matches the main path's shape.
            self.downsample = torch.nn.Sequential(torch.nn.Conv2d(in_channels=inplanes,
                                                                  out_channels=planes,
                                                                  kernel_size=(1, 1),
                                                                  stride=stride,
                                                                  bias=False))

        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)

        if self.downsample is not None:
            residual = self.downsample(residual)

        x += residual
        x = self.relu(x)
        return x


class CommonFeatureBlocks(torch.nn.Module):
    """Maps student and teacher features into a shared space (align + shared
    extractor) and decodes the common features back to each teacher's space."""

    def __init__(self, n_student_channels: int, n_teacher_channels: List[int], n_hidden_channel: int):
        super(CommonFeatureBlocks, self).__init__()

        ch_s = n_student_channels   # Readability
        ch_ts = n_teacher_channels  # Readability
        ch_h = n_hidden_channel     # Readability

        # Per-teacher 1x1 alignment into the (2 * ch_h)-channel common space.
        self.align_t = torch.nn.ModuleList()
        for ch_t in ch_ts:
            self.align_t.append(
                torch.nn.Sequential(torch.nn.Conv2d(in_channels=ch_t,
                                                    out_channels=2 * ch_h,
                                                    kernel_size=(1, 1),
                                                    bias=False),
                                    torch.nn.ReLU(inplace=True)))

        self.align_s = torch.nn.Sequential(torch.nn.Conv2d(in_channels=ch_s,
                                                           out_channels=2 * ch_h,
                                                           kernel_size=(1, 1),
                                                           bias=False),
                                           torch.nn.ReLU(inplace=True))

        # Shared feature extractor applied to both student and teachers.
        self.extractor = torch.nn.Sequential(ResidualBlock(inplanes=2 * ch_h, planes=ch_h, stride=1),
                                             ResidualBlock(inplanes=ch_h, planes=ch_h, stride=1),
                                             ResidualBlock(inplanes=ch_h, planes=ch_h, stride=1))

        # Per-teacher decoders reconstructing the original teacher features.
        self.dec_t = torch.nn.ModuleList()
        for ch_t in ch_ts:
            self.dec_t.append(
                torch.nn.Sequential(torch.nn.Conv2d(in_channels=ch_h,
                                                    out_channels=ch_t,
                                                    kernel_size=(3, 3),
                                                    stride=1,
                                                    padding=1,
                                                    bias=False),
                                    torch.nn.ReLU(inplace=True),
                                    torch.nn.Conv2d(in_channels=ch_t,
                                                    out_channels=ch_t,
                                                    kernel_size=(1, 1),
                                                    stride=1,
                                                    padding=0,
                                                    bias=False)))

    def forward(self, fs: torch.Tensor, ft: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (student common feature, teacher common features, decoded
        teacher features)."""
        aligned_t = [align(f) for align, f in zip(self.align_t, ft)]
        aligned_s = self.align_s(fs)
        ht = [self.extractor(f) for f in aligned_t]
        hs = self.extractor(aligned_s)
        ft_ = [dec(h) for dec, h in zip(self.dec_t, ht)]
        return hs, ht, ft_


class MySimpleModel(torch.nn.Module):
    """ResNet18 backbone + small MLP head; forward hooks expose intermediate
    features (layer4) and pre-softmax logits for distillation."""

    @property
    def features(self) -> torch.Tensor:
        # Only valid after register_hooks() and at least one forward pass.
        assert self.handles is not None
        return self.resnet.layer4.output

    @property
    def feature_dimension(self) -> int:
        return self.resnet.layer4[-1].conv2.out_channels

    @property
    def soft_output(self) -> torch.Tensor:
        # Only valid after register_hooks() and at least one forward pass.
        return self.fc.output

    @property
    def n_output(self) -> int:
        return self.fc[-1].out_features

    def __init__(self, n_output: int):
        super(MySimpleModel, self).__init__()
        self.handles = {}

        self.resnet = torchvision.models.resnet18(pretrained=True)
        # Replace the classifier with identity and keep the original around
        # only to read its in_features.
        self.resnet.fc_backup = self.resnet.fc
        self.resnet.fc = torch.nn.Sequential()
        self.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(self.resnet.fc_backup.in_features,
                            self.resnet.fc_backup.in_features // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(self.resnet.fc_backup.in_features // 2,
                            n_output)
        )

    def register_hooks(self) -> None:
        """Attach forward hooks that stash layer4 features and head logits."""
        def forward_hook(module: torch.nn.modules.container.Sequential, _: tuple, output: torch.Tensor):
            module.output = output

        self.handles['conv_layer'] = self.resnet.layer4.register_forward_hook(forward_hook)
        self.handles['fc_layer'] = self.fc.register_forward_hook(forward_hook)

    def remove_hooks(self) -> None:
        """Detach all hooks and clear the handle book-keeping."""
        assert self.handles is not None
        for handle in self.handles.values():
            handle.remove()
        self.handles.clear()  # FIX: original left stale handles behind

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.resnet(x)
        x = self.fc(x)
        return x

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return the argmax class index per sample."""
        x = self.forward(x)
        return torch.softmax(x, dim=1).argmax(1)


def save_load_best_model(model: MySimpleModel, experience: NCExperience, is_train=True, pbar=None) -> Tuple[MySimpleModel, float]:
    """Checkpoint helper: in train mode, evaluate on the experience and save
    the current weights; in eval mode, load the stored checkpoint and
    evaluate it. Returns (model, accuracy).

    NOTE(review): despite the name, this stores the *latest* weights, not the
    best-so-far -- kept as-is to preserve existing checkpoints' semantics.
    """
    n_task = experience.current_experience
    path = f'./state'
    state_path = f'{path}/cfa_{args.dataset}_{n_task + 1}'

    if not exists(path):
        mkdir(path)
    if not exists(state_path):
        torch.save(model.state_dict(), state_path)

    with torch.no_grad():
        model.eval()
        corrects = 0
        total_task = 0

        if not is_train:
            # Evaluate the stored checkpoint rather than the in-memory weights.
            model.load_state_dict(torch.load(state_path))

        data_loader = torch.utils.data.DataLoader(experience.dataset, batch_size=args.batch_size, shuffle=True)

        for data in data_loader:
            x = data[0].to(device)
            # Remap labels so each task's classes start at 0.
            y = (data[1] - min(experience.classes_in_this_experience)).to(device)
            total_task += len(x)
            corrects += int(sum(model.predict(x) == y))
        accuracy = (corrects / total_task) if total_task > 0 else 0

        if is_train:
            description = f'Train accuracy for base task {n_task + 1}: {accuracy * 100:.2f}% ({corrects}/{total_task})'

            if pbar is None:
                print(description)
            else:
                pbar.set_description_str(description)
            torch.save(model.state_dict(), state_path)
            model.train()

    return model, accuracy


def amalgamate(teachers: List[MySimpleModel], data_array: List = None, labels: List = None,
               train: NCScenario = None, test: NCScenario = None,
               epochs: int = 100) -> Tuple[MySimpleModel, List[int], np.ndarray, np.ndarray]:
    """Merge several task-specific teachers into one student model.

    teachers: frozen per-task models; data_array: one entry per teacher, either
    None (use the whole stream) or a tensor of previously kept exemplars;
    labels: per-teacher class-id lists; train/test: avalanche streams.

    Returns (student, flat exemplar indices, final per-task accuracies,
    per-task accuracies of the randomly initialised student).
    """
    # FIX: avoid mutable default arguments (data_array is mutated below).
    data_array = [] if data_array is None else data_array
    labels = [] if labels is None else labels

    def stack_batch(dataset, idxs) -> torch.Tensor:
        # Load a chunk of samples by index and move it to the compute device.
        return torch.stack([dataset[i][0] for i in idxs]).to(device)

    def memory_keys(all_data: NCScenario, teachers: List[MySimpleModel], labels: List[int], idx: int,
                    previous_data_idxs: List[int] = None) -> List[int]:
        """Pick exemplar indices for teacher `idx` under the memory strategy."""
        if args.memory_strategy == 'grow':
            # Budget per task: divide by this task's class count.
            n_elem = args.memory_budget // len(labels[idx])
        else:  # 'fixed': divide the budget across all classes seen so far.
            n_elem = args.memory_budget // sum(len(v) for v in labels)
        return get_mean_exemplar_keys(all_data, teachers[idx], labels[idx], n_elem, previous_data_idxs)

    def get_mean_exemplar_keys(all_data: NCScenario, teacher: MySimpleModel, labels: List[int],
                               n_elem_per_class: int, previous_data_idxs: List[int] = None) -> List[int]:
        """Herding-style selection: per class, keep the n samples whose teacher
        features are closest (L2) to the class's mean feature.

        FIX vs original: the hand-rolled batching dropped the sample that
        triggered a batch flush, dropped the final partial batch, and read
        stale indices past a truncated batch. Chunked iteration processes
        every candidate exactly once, and the mean no longer carries the
        spurious +1 from misusing builtin sum(tensor, 1).
        """
        dataset = all_data[0].dataset
        candidates = range(len(dataset)) if previous_data_idxs is None else list(previous_data_idxs)

        label_idx = {}
        with torch.no_grad():
            for label in labels:
                idxs = [i for i in candidates if dataset[i][1] == label]
                if not idxs:
                    label_idx[label] = []
                    continue

                # Pass 1: mean teacher feature for this class.
                total = None
                for start in range(0, len(idxs), args.batch_size):
                    chunk = idxs[start:start + args.batch_size]
                    teacher(stack_batch(dataset, chunk))
                    chunk_sum = teacher.features.sum(dim=0)
                    total = chunk_sum if total is None else total + chunk_sum
                mean = (total / len(idxs)).unsqueeze(0)

                # Pass 2: distance of every candidate to the class mean.
                distances = {}
                for start in range(0, len(idxs), args.batch_size):
                    chunk = idxs[start:start + args.batch_size]
                    teacher(stack_batch(dataset, chunk))
                    for k, i in enumerate(chunk):
                        distances[i] = float(torch.dist(teacher.features[k], mean, 2))

                ranked = sorted(distances.items(), key=lambda kv: kv[1])
                label_idx[label] = [i for i, _ in ranked][:n_elem_per_class]

        return np.concatenate([np.asarray(v, dtype=int) for v in label_idx.values()], 0)

    def get_conf_keys(all_data: NCScenario, teacher: MySimpleModel, labels: List[int],
                      n_elem: int = args.memory_budget) -> List[int]:
        """Alternative selection: keep the samples the teacher is most
        confident about (largest top1 - top2 softmax margin)."""
        dataset = all_data[0].dataset
        conf = {}
        with torch.no_grad():
            for label in labels:
                conf[label] = {}
                idxs = [i for i in range(len(dataset)) if dataset[i][1] == label]
                for start in range(0, len(idxs), args.batch_size):
                    chunk = idxs[start:start + args.batch_size]
                    soft_top_2 = torch.softmax(teacher(stack_batch(dataset, chunk)), 1).topk(2)[0].tolist()
                    for k, i in enumerate(chunk):
                        conf[label][i] = soft_top_2[k][0] - soft_top_2[k][1]

        idxs = []
        for label in conf.keys():
            ranked = sorted(conf[label], key=conf[label].get, reverse=True)
            idxs = idxs + ranked[:(n_elem // len(labels))]
        return idxs[:n_elem]

    def evaluate_student(model: MySimpleModel) -> np.ndarray:
        """Per-task top-1 accuracy on the test stream (model must be in eval
        mode). Labels are remapped to their position in the flattened list."""
        corrects = np.zeros((len(teachers)), int)
        total_samples = np.zeros((len(teachers)), int)
        task_accuracy = np.zeros((len(teachers)))
        flat_labels = [label for task in labels for label in task]
        with torch.no_grad():
            for x, y, _ in test[0].dataset:
                if int(y) not in flat_labels:
                    continue
                label = torch.tensor(flat_labels.index(y)).to(device)
                pred = model.predict(x.view(1, 3, 224, 224).to(device))
                for idx, task_labels in enumerate(labels):
                    if label in task_labels:
                        corrects[idx] += int(pred == label)
                        total_samples[idx] += 1
        for idx in range(len(teachers)):
            task_accuracy[idx] = (corrects[idx] / total_samples[idx]) if total_samples[idx] > 0 else 0
        return task_accuracy

    # Student head covers every teacher's classes.
    student = MySimpleModel(sum(teacher.n_output for teacher in teachers)).to(device)
    # FIX: teacher channel list was hard-coded to the first two teachers, so
    # extra teachers (all_together with >2 tasks) were silently dropped from
    # the common-feature loss.
    cfl_blk = CommonFeatureBlocks(student.feature_dimension,
                                  [teacher.feature_dimension for teacher in teachers],
                                  int(sum(teacher.feature_dimension for teacher in teachers) / len(teachers))).to(device)

    cfl_lr = args.lr * 10 if args.cfl_lr is None else args.cfl_lr

    # Fresh head learns 10x faster than the pretrained backbone.
    params_10x = [param for name, param in student.named_parameters() if 'fc' in name]
    params_1x = [param for name, param in student.named_parameters() if 'fc' not in name]
    optimizer = torch.optim.Adam([{'params': params_1x, 'lr': args.lr},
                                  {'params': params_10x, 'lr': args.lr * 10},
                                  {'params': cfl_blk.parameters(), 'lr': cfl_lr}])

    student.train()
    [teacher.register_hooks() for teacher in teachers]
    [teacher.eval() for teacher in teachers]
    student.register_hooks()
    average_tracker = AverageTracker()

    common_feature_learning_criterion = CommonFeatureLearningLoss().to(device)

    print('Adjusting replay memory - sorry the delay, this part of the code is not optimized')
    data_idx = []
    for idx, data in enumerate(data_array):
        data_idx.append(memory_keys(train, teachers, labels, idx, data))
        data_array[idx] = torch.stack([train[0].dataset[idx_][0] for idx_ in data_idx[idx]])
    print('Replay memory adjusted')

    all_data = torch.cat([data for data in data_array])
    all_data = all_data[torch.randperm(len(all_data))]

    # Baseline: accuracy of the randomly initialised student (used for FWT).
    student.eval()
    b_accuracy = evaluate_student(student)

    student.train()
    with tqdm(unit='Epoch', total=epochs) as pbar:
        while pbar.n < epochs:
            average_tracker.reset()
            # FIX: chunked batching replaces the accumulator loop that dropped
            # the batch-boundary sample and the final partial batch.
            for batch_sample in torch.split(all_data, args.batch_size):
                batch_sample = batch_sample.to(device)

                optimizer.zero_grad()
                with torch.no_grad():
                    [teacher(batch_sample) for teacher in teachers]
                    teacher_soft = torch.cat(tuple(teacher.soft_output for teacher in teachers), dim=1)

                student(batch_sample)
                student_soft = student.soft_output

                # Despite the name this is a KL distillation loss; default
                # reduction kept (warning suppressed) to preserve magnitudes.
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore')
                    cross_entropy_loss = torch.nn.functional.kl_div(torch.log_softmax(student_soft, dim=1),
                                                                    torch.softmax(teacher_soft, dim=1))

                hs, ht, ft_ = cfl_blk(student.features, [teacher.features for teacher in teachers])
                common_features_loss = 10 * common_feature_learning_criterion(hs, ht, ft_,
                                                                              [teacher.features for teacher in teachers])

                loss = args.alpha * cross_entropy_loss + (1 - args.alpha) * common_features_loss
                loss.backward()
                optimizer.step()

                average_tracker.update('loss', loss.item())
                average_tracker.update('ce', cross_entropy_loss.item())
                average_tracker.update('cf', common_features_loss.item())

            description = f'Amalgamating ' \
                          f'Loss={average_tracker.get("loss"):.2f} ' \
                          f'(cross entropy={average_tracker.get("ce"):.2f}, ' \
                          f'common features={average_tracker.get("cf"):.2f})'
            pbar.set_description_str(description)
            pbar.refresh()
            pbar.update()
            # Reshuffle the replay memory for the next epoch.
            all_data = torch.cat([data for data in data_array])
            all_data = all_data[torch.randperm(len(all_data))]

    [teacher.remove_hooks() for teacher in teachers]
    student.remove_hooks()

    student.eval()
    accuracy = evaluate_student(student)

    return student, [data for d in data_idx for data in d], accuracy, b_accuracy


def load_dataset(dataset: str, force_unique_task: bool = False) -> Tuple[NCScenario, NCScenario]:
    """Build the avalanche train/test streams for `dataset`.

    Sets args.n_tasks as a side effect; force_unique_task collapses the whole
    dataset into a single experience (used for amalgamation-time streams).
    """
    if dataset in ['mnist', 'usps', 'omniglot']:
        # Grayscale datasets are replicated to 3 channels for the ResNet.
        transforms = T.Compose([T.Grayscale(3), T.Resize((224, 224)), T.ToTensor()])

        if dataset == 'mnist':
            args.n_tasks = 1 if force_unique_task else 5
            data = SplitMNIST(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 10),
                              train_transform=transforms, eval_transform=transforms)
        elif dataset == 'usps':
            args.n_tasks = 1 if force_unique_task else 5
            usps_train = torchvision.datasets.USPS(root='./data', train=True, download=True)
            usps_test = torchvision.datasets.USPS(root='./data', train=False, download=True)

            data = nc_benchmark(usps_train, usps_test, n_experiences=args.n_tasks, seed=args.seed, task_labels=True,
                                fixed_class_order=range(0, 10), train_transform=transforms, eval_transform=transforms)
        else:  # omniglot
            args.n_tasks = 1 if force_unique_task else 241
            data = SplitOmniglot(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 964),
                                 train_transform=transforms, eval_transform=transforms)

    elif dataset in ['cifar10', 'cifar100', 'tinyImageNet10', 'tiny10', 'tinyImageNet20', 'tiny20']:
        transforms = T.Compose([T.Resize((224, 224)), T.ToTensor()])

        if dataset == 'cifar10':
            args.n_tasks = 1 if force_unique_task else 5
            data = SplitCIFAR10(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 10),
                                train_transform=transforms, eval_transform=transforms)
        elif dataset == 'cifar100':
            args.n_tasks = 1 if force_unique_task else 10
            data = SplitCIFAR100(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 100),
                                 train_transform=transforms, eval_transform=transforms)
        elif dataset in ['tinyImageNet10', 'tiny10']:
            args.n_tasks = 1 if force_unique_task else 10
            data = SplitTinyImageNet(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 200),
                                     train_transform=transforms, eval_transform=transforms)
        else:  # tinyImageNet20 / tiny20
            args.n_tasks = 1 if force_unique_task else 20
            data = SplitTinyImageNet(n_experiences=args.n_tasks, seed=args.seed, fixed_class_order=range(0, 200),
                                     train_transform=transforms, eval_transform=transforms)
    else:
        # FIX: original fell through and raised UnboundLocalError on `data`.
        raise ValueError(f'Unknown dataset: {dataset}')

    return data.train_stream, data.test_stream


def main(args):
    """Train per-task teachers, amalgamate them, and report ACC/BWT/FWT."""
    path = f'./state'
    state_path = f'{path}/{paper_name}_{args.dataset}'
    is_training_base_model = False

    # Prepare data
    train_stream, test_stream = load_dataset(args.dataset)

    # Training base model: retrain when forced or when any checkpoint is missing.
    if args.force_base_retraining is not None and args.force_base_retraining:
        for i in range(args.n_tasks):
            if exists(f'{state_path}_{i + 1}'):
                remove(f'{state_path}_{i + 1}')
    if exists(path):
        for i in range(args.n_tasks):
            if not exists(f'{state_path}_{i + 1}'):
                is_training_base_model = True
                break
    else:
        is_training_base_model = True

    if is_training_base_model:
        print('Training base model')
        for experience in train_stream:
            model = MySimpleModel(len(experience.classes_in_this_experience)).to(device)
            criterion = torch.nn.CrossEntropyLoss().to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

            with tqdm(unit='Epoch', total=args.epochs) as pbar:
                train_loader = torch.utils.data.DataLoader(experience.dataset, batch_size=args.batch_size, shuffle=True)
                while pbar.n < args.epochs:
                    model.train()
                    for data in train_loader:
                        x = data[0].to(device)
                        y = (data[1] - min(experience.classes_in_this_experience)).to(device)
                        optimizer.zero_grad()
                        output = model(x)
                        criterion(output, y).backward()
                        optimizer.step()
                    save_load_best_model(model, experience, pbar=pbar)
                    pbar.update()

    base_models = []
    experiences = []
    accuracy_0 = 0
    print('Base model performance')
    for experience in test_stream:
        model = MySimpleModel(len(experience.classes_in_this_experience)).to(device)
        model, accuracy = save_load_best_model(model, experience, False)
        print(f'Test accuracy for base task {experience.current_experience + 1} {experience.classes_in_this_experience}: {accuracy * 100:.2f}')
        if experience.current_experience == 0:
            accuracy_0 = accuracy
        base_models.append(model.cpu())
        experiences.append(experience)

    args.n_tasks_original = args.n_tasks
    n_tasks = args.n_tasks
    # accuracies[i, j]: accuracy on task i after the amalgamation step j.
    accuracies = np.zeros((n_tasks, n_tasks))
    b_accuracies = np.zeros((n_tasks, n_tasks))
    accuracies[0, 0] = accuracy_0
    train_stream, test_stream = load_dataset(args.dataset, True)
    if args.amalgamation_strategy == 'one_at_a_time':
        amalgamated_model, data, accuracy, b_accuracy = amalgamate(teachers=[base_models[0].to(device), base_models[1].to(device)],
                                                                   data_array=[None, None],
                                                                   labels=[experiences[0].classes_seen_so_far, experiences[1].classes_in_this_experience],
                                                                   train=train_stream,
                                                                   test=test_stream,
                                                                   epochs=args.amalgamation_epochs)
        accuracies[0, 1] = accuracy[0]
        accuracies[1, 1] = accuracy[1]
        b_accuracies[0, 1] = b_accuracy[0]
        b_accuracies[1, 1] = b_accuracy[1]
        if n_tasks > 2:
            # FIX: start at task index 2 -- the original re-ran i == 1,
            # amalgamating task 2 twice and overwriting its accuracies.
            for i in range(2, n_tasks):
                amalgamated_model, data, accuracy, b_accuracy = amalgamate(teachers=[amalgamated_model.to(device), base_models[i].to(device)],
                                                                           data_array=[data, None],
                                                                           labels=[experiences[i - 1].classes_seen_so_far, experiences[i].classes_in_this_experience],
                                                                           train=train_stream,
                                                                           test=test_stream,
                                                                           epochs=args.amalgamation_epochs)
                accuracies[i - 1, i] = accuracy[0]
                accuracies[i, i] = accuracy[1]
                b_accuracies[i - 1, i] = b_accuracy[0]
                b_accuracies[i, i] = b_accuracy[1]
    elif args.amalgamation_strategy == 'all_together':
        for n_task in range(2, n_tasks + 1, 1):
            _, _, accuracy, b_accuracy = amalgamate(teachers=[base_models[idx].to(device) for idx in range(n_task)],
                                                    data_array=[None] * n_task,
                                                    labels=[experiences[idx].classes_in_this_experience for idx in range(n_task)],
                                                    train=train_stream,
                                                    test=test_stream,
                                                    epochs=args.amalgamation_epochs)
            for i in range(len(accuracy)):
                accuracies[i, n_task - 1] = accuracy[i]
                b_accuracies[i, n_task - 1] = b_accuracy[i]

    print(f'accuracies \n {accuracies}')
    print(f'b_accuracies (random initialization) \n {b_accuracies}')

    # ACC: mean over tasks of the last column (ignoring never-filled zeros).
    acc = np.nanmean(np.where(accuracies != 0, accuracies, np.nan), 0)[-1]
    print(f'ACC: {acc * 100:.2f}%')

    # BWT: how much earlier tasks degraded after later amalgamations.
    bwt = 0
    for i in range(n_tasks - 1):
        j = {'one_at_a_time': i + 1, 'all_together': -1}
        bwt += accuracies[i, j[args.amalgamation_strategy]] - accuracies[i, i]
    bwt = bwt / (n_tasks - 1)
    print(f'BWT: {bwt * 100:.2f}%')

    # FWT: transfer to a task relative to a randomly initialised student.
    fwt = 0
    for i in range(1, n_tasks):
        j = {'one_at_a_time': i, 'all_together': -1}
        fwt += accuracies[i - 1, j[args.amalgamation_strategy]] - b_accuracies[i, i]
    fwt = fwt / (n_tasks - 1)
    print(f'FWT: {fwt * 100:.2f}%')


if __name__ == '__main__':
    args = parser.parse_args()

    # Configure random seed and devices
    if args.seed is None:
        args.seed = random.randint(0, 10000)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # no-op without CUDA
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device(f'cuda:{args.cuda_device}' if (torch.cuda.is_available() and args.cuda) else 'cpu')
    print(f'Device: {device}')

    main(args)