Fix ATL_Python. Now working great.
This commit is contained in:
parent
8f306f8888
commit
db3494855e
25
ATL.py
25
ATL.py
|
@ -99,7 +99,6 @@ def width_evolution(network: NeuralNetwork, x: torch.tensor, y: torch.tensor = N
|
|||
network.forward_pass(x)
|
||||
network.run_agmm(x, y)
|
||||
|
||||
|
||||
network.feedforward(x, y)
|
||||
network.width_adaptation_stepwise(y)
|
||||
|
||||
|
@ -143,7 +142,7 @@ def test(network: NeuralNetwork, x: torch.tensor, y: torch.tensor = None, is_sou
|
|||
metrics['reconstruction_target_loss'].append(float(network.loss_value))
|
||||
|
||||
|
||||
def force_same_size(a_tensor, b_tensor, shuffle=True, strategy='max'):
|
||||
def force_same_size(a_tensor, b_tensor, shuffle=True, strategy='min'):
|
||||
common = np.min([a_tensor.shape[0], b_tensor.shape[0]])
|
||||
|
||||
if shuffle:
|
||||
|
@ -174,7 +173,7 @@ def kl(ae: NeuralNetwork, x_source: torch.tensor, x_target: torch.tensor):
|
|||
|
||||
ae.reset_grad()
|
||||
kl_loss = torch.nn.functional.kl_div(ae.forward_pass(x_target).layer_value[1],
|
||||
ae.forward_pass(x_source).layer_value[1], reduction='batchmean')
|
||||
ae.forward_pass(x_source).layer_value[1])
|
||||
|
||||
kl_loss.backward()
|
||||
ae.weight[0] = ae.weight[0] - ae.learning_rate * ae.weight[0].grad
|
||||
|
@ -359,6 +358,18 @@ def ATL(epochs: int = 1, n_batch: int = 1000, device='cpu'):
|
|||
np.mean(metrics['classification_rate_target']) * 100,
|
||||
np.min(metrics['classification_rate_target']) * 100,
|
||||
metrics['classification_rate_target'][-1] * 100))
|
||||
print(('%s %s %s %s AGMM Source:' + Fore.GREEN + ' %d ' + Fore.YELLOW + '%f' + Style.RESET_ALL + Fore.RED + ' %d' + Fore.BLUE + ' %d' + Style.RESET_ALL) % (
|
||||
string_max, string_mean, string_min, string_now,
|
||||
np.max(metrics['agmm_source_size_by_batch']),
|
||||
np.mean(metrics['agmm_source_size_by_batch']),
|
||||
np.min(metrics['agmm_source_size_by_batch']),
|
||||
metrics['agmm_source_size_by_batch'][-1]))
|
||||
print(('%s %s %s %s AGMM Target:' + Fore.GREEN + ' %d ' + Fore.YELLOW + '%f' + Style.RESET_ALL + Fore.RED + ' %d' + Fore.BLUE + ' %d' + Style.RESET_ALL) % (
|
||||
string_max, string_mean, string_min, string_now,
|
||||
np.max(metrics['agmm_target_size_by_batch']),
|
||||
np.mean(metrics['agmm_target_size_by_batch']),
|
||||
np.min(metrics['agmm_target_size_by_batch']),
|
||||
metrics['agmm_target_size_by_batch'][-1]))
|
||||
print(('%s %s %s %s Classification Source Loss:' + Fore.GREEN + ' %f' + Fore.YELLOW + ' %f' + Fore.RED + ' %f' + Fore.BLUE + ' %f' + Style.RESET_ALL) % (
|
||||
string_max, string_mean, string_min, string_now,
|
||||
np.max(metrics['classification_source_loss']),
|
||||
|
@ -377,7 +388,7 @@ def ATL(epochs: int = 1, n_batch: int = 1000, device='cpu'):
|
|||
np.mean(metrics['reconstruction_target_loss']),
|
||||
np.min(metrics['reconstruction_target_loss']),
|
||||
metrics['reconstruction_target_loss'][-1]))
|
||||
print(('%s %s %s %s Kullback-Leibler loss 1:' + Fore.GREEN + ' %f' + Fore.YELLOW + ' %f' + Fore.RED + ' %f' + Fore.BLUE + ' %f' + Style.RESET_ALL) % (
|
||||
print(('%s %s %s %s Kullback-Leibler Loss:' + Fore.GREEN + ' %f' + Fore.YELLOW + ' %f' + Fore.RED + ' %f' + Fore.BLUE + ' %f' + Style.RESET_ALL) % (
|
||||
string_max, string_mean, string_min, string_now,
|
||||
np.max(metrics['kl_loss']),
|
||||
np.mean(metrics['kl_loss']),
|
||||
|
@ -414,7 +425,7 @@ def ATL(epochs: int = 1, n_batch: int = 1000, device='cpu'):
|
|||
dm.load_custom_csv()
|
||||
dm.normalize()
|
||||
|
||||
dm.split_as_source_target_streams(n_batch, 'dallas_2', 0.5)
|
||||
dm.split_as_source_target_streams(n_batch, 0.5)
|
||||
|
||||
nn = NeuralNetwork([dm.number_features, 1, dm.number_classes])
|
||||
ae = DenoisingAutoEncoder([nn.layers[0], nn.layers[1], nn.layers[0]])
|
||||
|
@ -445,14 +456,14 @@ def ATL(epochs: int = 1, n_batch: int = 1000, device='cpu'):
|
|||
metrics['train_time'].append(time.time())
|
||||
for epoch in range(epochs):
|
||||
for x, y in [(x.view(1, x.shape[0]), y.view(1, y.shape[0])) for x, y in zip(Xs, ys)]:
|
||||
width_evolution(network=nn, x=x, y=y, agmm=agmm_source_discriminative, train_agmm=True if epoch == 1 else False)
|
||||
width_evolution(network=nn, x=x, y=y, agmm=agmm_source_discriminative, train_agmm=True if epoch == 0 else False)
|
||||
if not grow_nodes(nn, ae): prune_nodes(nn, ae)
|
||||
discriminative(network=nn, x=x, y=y, agmm=agmm_source_discriminative)
|
||||
|
||||
copy_weights(source=nn, target=ae, layer_numbers=[1])
|
||||
|
||||
for x in [x.view(1, x.shape[0]) for x in Xt]:
|
||||
width_evolution(network=ae, x=x, agmm=agmm_target_generative, train_agmm=True if epoch == 1 else False)
|
||||
width_evolution(network=ae, x=x, agmm=agmm_target_generative, train_agmm=True if epoch == 0 else False)
|
||||
if not grow_nodes(ae, nn): prune_nodes(ae, nn)
|
||||
generative(network=ae, x=x, agmm=agmm_target_generative)
|
||||
|
||||
|
|
|
@ -84,19 +84,9 @@ class DataManipulator:
|
|||
def normalize_image(self):
|
||||
raise TypeError('Not implemented')
|
||||
|
||||
def split_as_source_target_streams(self, number_fold_elements=0, method=None, sampling_ratio=0.5):
|
||||
if number_fold_elements == 0:
|
||||
self.number_fold_elements == self.data.shape[0]
|
||||
else:
|
||||
self.number_fold_elements = number_fold_elements
|
||||
|
||||
if method == None or method == 'none' or method == 'None':
|
||||
self.__split_as_source_target_streams_none(self.number_fold_elements, sampling_ratio)
|
||||
elif method == 'dallas_1' or method == 'dallas1':
|
||||
self.__split_as_source_target_streams_dallas_1(self.number_fold_elements, sampling_ratio)
|
||||
elif method == 'dallas_2' or method == 'dallas2':
|
||||
self.__split_as_source_target_streams_dallas_2(self.number_fold_elements, sampling_ratio)
|
||||
|
||||
def split_as_source_target_streams(self, number_fold_elements=0, sampling_ratio=0.5):
|
||||
self.number_fold_elements = number_fold_elements if number_fold_elements is not 0 else self.data.shape[0]
|
||||
self.__split_as_source_target_streams_dallas_2(self.number_fold_elements, sampling_ratio)
|
||||
self.__create_Xs_ys_Xt_yt()
|
||||
|
||||
def get_Xs(self, number_minibatch):
|
||||
|
|
Loading…
Reference in New Issue