<НА ГЛАВНУЮ

Построение нейронных агентов с памятью: дифференцируемая память, мета-обучение и приоритизированный реплей для непрерывной адаптации

Руководство по созданию агента в PyTorch с дифференцируемой памятью, приоритизированным реплеем и мета-обучением, позволяющее адаптироваться к новым задачам без потери предыдущих навыков.

Зачем нужна память для непрерывного обучения

Нейронным агентам, обучающимся на множестве задач, необходимы механизмы для хранения и извлечения прошлых опытов, чтобы новые данные не перезаписывали полезные представления. В этом руководстве мы реализуем агент с расширенной памятью на PyTorch, объединяющий дифференцируемую память в духе DNC, приоритизированный буфер реплея и простой мета-обучающий цикл для быстрой адаптации и уменьшения катастрофического забывания.

Настройки памяти и импорты

Начинаем с загрузки необходимых библиотек и определения конфигурационного класса, который управляет размером памяти, размерностью векторов и количеством голов для чтения/записи. Эти параметры определяют, как будут работать операции адресации и изменения памяти во время обучения.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
from dataclasses import dataclass
 
 
@dataclass
class MemoryConfig:
   memory_size: int = 128
   memory_dim: int = 64
   num_read_heads: int = 4
   num_write_heads: int = 1

Дифференцируемая банковская память и контроллер

NeuralMemoryBank хранит и извлекает вектора с помощью адресации по содержимому. MemoryController использует LSTM, который генерирует ключи чтения/записи, силы и векторы. Контроллер читает из памяти, записывает обновления и формирует выходы, комбинируя свое состояние с прочитанными векторами.

class NeuralMemoryBank(nn.Module):
   def __init__(self, config: MemoryConfig):
       super().__init__()
       self.memory_size = config.memory_size
       self.memory_dim = config.memory_dim
       self.num_read_heads = config.num_read_heads
       self.register_buffer('memory', torch.zeros(config.memory_size, config.memory_dim))
       self.register_buffer('usage', torch.zeros(config.memory_size))
   def content_addressing(self, key, beta):
       key_norm = F.normalize(key, dim=-1)
       mem_norm = F.normalize(self.memory, dim=-1)
       similarity = torch.matmul(key_norm, mem_norm.t())
       return F.softmax(beta * similarity, dim=-1)
   def write(self, write_key, write_vector, erase_vector, write_strength):
       write_weights = self.content_addressing(write_key, write_strength)
       erase = torch.outer(write_weights.squeeze(), erase_vector.squeeze())
       self.memory = (self.memory * (1 - erase)).detach()
       add = torch.outer(write_weights.squeeze(), write_vector.squeeze())
       self.memory = (self.memory + add).detach()
       self.usage = (0.99 * self.usage + write_weights.squeeze()).detach()
   def read(self, read_keys, read_strengths):
       reads = []
       for i in range(self.num_read_heads):
           weights = self.content_addressing(read_keys[i], read_strengths[i])
           read_vector = torch.matmul(weights, self.memory)
           reads.append(read_vector)
       return torch.cat(reads, dim=-1)
 
 
class MemoryController(nn.Module):
   def __init__(self, input_dim, hidden_dim, memory_config: MemoryConfig):
       super().__init__()
       self.hidden_dim = hidden_dim
       self.memory_config = memory_config
       self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
       total_read_dim = memory_config.num_read_heads * memory_config.memory_dim
       self.read_keys = nn.Linear(hidden_dim, memory_config.num_read_heads * memory_config.memory_dim)
       self.read_strengths = nn.Linear(hidden_dim, memory_config.num_read_heads)
       self.write_key = nn.Linear(hidden_dim, memory_config.memory_dim)
       self.write_vector = nn.Linear(hidden_dim, memory_config.memory_dim)
       self.erase_vector = nn.Linear(hidden_dim, memory_config.memory_dim)
       self.write_strength = nn.Linear(hidden_dim, 1)
       self.output = nn.Linear(hidden_dim + total_read_dim, input_dim)
   def forward(self, x, memory_bank, hidden=None):
       lstm_out, hidden = self.lstm(x.unsqueeze(0), hidden)
       controller_state = lstm_out.squeeze(0)
       read_k = self.read_keys(controller_state).view(self.memory_config.num_read_heads, -1)
       read_s = F.softplus(self.read_strengths(controller_state))
       write_k = self.write_key(controller_state)
       write_v = torch.tanh(self.write_vector(controller_state))
       erase_v = torch.sigmoid(self.erase_vector(controller_state))
       write_s = F.softplus(self.write_strength(controller_state))
       read_vectors = memory_bank.read(read_k, read_s)
       memory_bank.write(write_k, write_v, erase_v, write_s)
       combined = torch.cat([controller_state, read_vectors], dim=-1)
       output = self.output(combined)
       return output, hidden

Эти классы реализуют основную дифференцируемую память: адресация по содержимому, мягкие операции чтения/записи и взаимодействие контроллера с памятью.

Приоритизированный реплей и мета-ученик

Для повторного использования важных прошлых опытов реализован приоритизированный буфер реплея, который выбирает примеры пропорционально их приоритетам. MetaLearner показывает идею MAML-подобной быстрой адаптации параметров контроллера на поддерживающем наборе данных.

class ExperienceReplay:
   def __init__(self, capacity=10000, alpha=0.6):
       self.capacity = capacity
       self.alpha = alpha
       self.buffer = deque(maxlen=capacity)
       self.priorities = deque(maxlen=capacity)
   def push(self, experience, priority=1.0):
       self.buffer.append(experience)
       self.priorities.append(priority ** self.alpha)
   def sample(self, batch_size, beta=0.4):
       if len(self.buffer) == 0:
           return [], []
       probs = np.array(self.priorities)
       probs = probs / probs.sum()
       indices = np.random.choice(len(self.buffer), min(batch_size, len(self.buffer)), p=probs, replace=False)
       samples = [self.buffer[i] for i in indices]
       weights = (len(self.buffer) * probs[indices]) ** (-beta)
       weights = weights / weights.max()
       return samples, torch.FloatTensor(weights)
 
 
class MetaLearner(nn.Module):
   def __init__(self, model):
       super().__init__()
       self.model = model
   def adapt(self, support_x, support_y, num_steps=5, lr=0.01):
       adapted_params = {name: param.clone() for name, param in self.model.named_parameters()}
       for _ in range(num_steps):
           pred, _ = self.model(support_x, self.model.memory_bank)
           loss = F.mse_loss(pred, support_y)
           grads = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
           adapted_params = {name: param - lr * grad for (name, param), grad in zip(adapted_params.items(), grads)}
       return adapted_params

Важно: для продакшн-реализации MAML потребуется аккуратно подменять параметры и выполнять прямые проходы с адаптированными весами.

Агент непрерывного обучения

Объединяем память, контроллер, реплей и мета-ученик в ContinualLearningAgent. Обучение включает обычные шаги градиентного спуска с добавлением реплей-выборок, усечением градиентов и последующим обновлением оптимизатором; оценка переводит контроллер в режим eval и вычисляет среднюю квадратичную ошибку.

class ContinualLearningAgent:
   def __init__(self, input_dim=64, hidden_dim=128):
       self.config = MemoryConfig()
       self.memory_bank = NeuralMemoryBank(self.config)
       self.controller = MemoryController(input_dim, hidden_dim, self.config)
       self.replay_buffer = ExperienceReplay(capacity=5000)
       self.meta_learner = MetaLearner(self.controller)
       self.optimizer = torch.optim.Adam(self.controller.parameters(), lr=0.001)
       self.task_history = []
   def train_step(self, x, y, use_replay=True):
       self.optimizer.zero_grad()
       pred, _ = self.controller(x, self.memory_bank)
       current_loss = F.mse_loss(pred, y)
       self.replay_buffer.push((x.detach().clone(), y.detach().clone()), priority=current_loss.item() + 1e-6)
       total_loss = current_loss
       if use_replay and len(self.replay_buffer.buffer) > 16:
           samples, weights = self.replay_buffer.sample(8)
           for (replay_x, replay_y), weight in zip(samples, weights):
               with torch.enable_grad():
                   replay_pred, _ = self.controller(replay_x, self.memory_bank)
                   replay_loss = F.mse_loss(replay_pred, replay_y)
                   total_loss = total_loss + 0.3 * replay_loss * weight
       total_loss.backward()
       torch.nn.utils.clip_grad_norm_(self.controller.parameters(), 1.0)
       self.optimizer.step()
       return total_loss.item()
   def evaluate(self, test_data):
       self.controller.eval()
       total_error = 0
       with torch.no_grad():
           for x, y in test_data:
               pred, _ = self.controller(x, self.memory_bank)
               total_error += F.mse_loss(pred, y).item()
       self.controller.train()
       return total_error / len(test_data)

Синтетические задачи и демонстрация

Генератор задач создает разные задачи (синус, косинус, tanh) для проверки способности агента к непрерывной адаптации. Цикл демонстрации последовательно обучает на задачах, оценивает на предыдущих и визуализирует состояние памяти и кривые ошибок.

def create_task_data(task_id, num_samples=100):
   torch.manual_seed(task_id)
   x = torch.randn(num_samples, 64)
   if task_id == 0:
       y = torch.sin(x.mean(dim=1, keepdim=True).expand(-1, 64))
   elif task_id == 1:
       y = torch.cos(x.mean(dim=1, keepdim=True).expand(-1, 64)) * 0.5
   else:
       y = torch.tanh(x * 0.5 + task_id)
   return [(x[i], y[i]) for i in range(num_samples)]
 
 
def run_continual_learning_demo():
   print(" Neural Memory Agent - Continual Learning Demo\n")
   print("=" * 60)
   agent = ContinualLearningAgent()
   num_tasks = 4
   results = {'tasks': [], 'without_memory': [], 'with_memory': []}
   for task_id in range(num_tasks):
       print(f"\n Learning Task {task_id + 1}/{num_tasks}")
       train_data = create_task_data(task_id, num_samples=50)
       test_data = create_task_data(task_id, num_samples=20)
       for epoch in range(20):
           total_loss = 0
           for x, y in train_data:
               loss = agent.train_step(x, y, use_replay=(task_id > 0))
               total_loss += loss
           if epoch % 5 == 0:
               avg_loss = total_loss / len(train_data)
               print(f"  Epoch {epoch:2d}: Loss = {avg_loss:.4f}")
       print(f"\n   Evaluation on all tasks:")
       for eval_task_id in range(task_id + 1):
           eval_data = create_task_data(eval_task_id, num_samples=20)
           error = agent.evaluate(eval_data)
           print(f"    Task {eval_task_id + 1}: Error = {error:.4f}")
           if eval_task_id == task_id:
               results['tasks'].append(eval_task_id + 1)
               results['with_memory'].append(error)
   fig, axes = plt.subplots(1, 2, figsize=(14, 5))
   ax = axes[0]
   memory_matrix = agent.memory_bank.memory.detach().numpy()
   im = ax.imshow(memory_matrix, aspect='auto', cmap='viridis')
   ax.set_title('Neural Memory Bank State', fontsize=14, fontweight='bold')
   ax.set_xlabel('Memory Dimension')
   ax.set_ylabel('Memory Slots')
   plt.colorbar(im, ax=ax)
   ax = axes[1]
   ax.plot(results['tasks'], results['with_memory'], marker='o', linewidth=2, markersize=8, label='With Memory Replay')
   ax.set_title('Continual Learning Performance', fontsize=14, fontweight='bold')
   ax.set_xlabel('Task Number')
   ax.set_ylabel('Test Error')
   ax.legend()
   ax.grid(True, alpha=0.3)
   plt.tight_layout()
   plt.savefig('neural_memory_results.png', dpi=150, bbox_inches='tight')
   print("\n Results saved to 'neural_memory_results.png'")
   plt.show()
   print("\n" + "=" * 60)
   print(" Key Insights:")
   print("  • Memory bank stores compressed task representations")
   print("  • Experience replay mitigates catastrophic forgetting")
   print("  • Agent maintains performance on earlier tasks")
   print("  • Content-based addressing enables efficient retrieval")
 
 
if __name__ == "__main__":
   run_continual_learning_demo()

Практические выводы

  • Приоритизированный реплей помогает модели повторно посещать примеры с высокой ошибкой, которые наиболее подвержены забыванию.
  • Адресация по содержимому эффективна для извлечения релевантных представлений при наличии общей структуры между задачами.
  • Мета-обучение предоставляет механизм для быстрой адаптации, но требует аккуратной реализации для контроля вычислительных затрат и устойчивости градиентов.
  • Визуализация матрицы памяти и кривых ошибок по задачам даёт полезные диагностические сигналы о сохранении знаний при добавлении новых задач.
🇬🇧

Switch Language

Read this article in English

Switch to English