Построение нейронных агентов с памятью: дифференцируемая память, мета-обучение и приоритизированный реплей для непрерывной адаптации
Руководство по созданию агента в PyTorch с дифференцируемой памятью, приоритизированным реплеем и мета-обучением, позволяющее адаптироваться к новым задачам без потери предыдущих навыков.
Зачем нужна память для непрерывного обучения
Нейронным агентам, обучающимся на множестве задач, необходимы механизмы для хранения и извлечения прошлых опытов, чтобы новые данные не перезаписывали полезные представления. В этом руководстве мы реализуем агент с расширенной памятью на PyTorch, объединяющий дифференцируемую память в духе DNC, приоритизированный буфер реплея и простой мета-обучающий цикл для быстрой адаптации и уменьшения катастрофического забывания.
Настройки памяти и импорты
Начинаем с загрузки необходимых библиотек и определения конфигурационного класса, который управляет размером памяти, размерностью векторов и количеством голов для чтения/записи. Эти параметры определяют, как будут работать операции адресации и изменения памяти во время обучения.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
from dataclasses import dataclass
@dataclass
class MemoryConfig:
memory_size: int = 128
memory_dim: int = 64
num_read_heads: int = 4
num_write_heads: int = 1Дифференцируемая банковская память и контроллер
NeuralMemoryBank хранит и извлекает вектора с помощью адресации по содержимому. MemoryController использует LSTM, который генерирует ключи чтения/записи, силы и векторы. Контроллер читает из памяти, записывает обновления и формирует выходы, комбинируя свое состояние с прочитанными векторами.
class NeuralMemoryBank(nn.Module):
def __init__(self, config: MemoryConfig):
super().__init__()
self.memory_size = config.memory_size
self.memory_dim = config.memory_dim
self.num_read_heads = config.num_read_heads
self.register_buffer('memory', torch.zeros(config.memory_size, config.memory_dim))
self.register_buffer('usage', torch.zeros(config.memory_size))
def content_addressing(self, key, beta):
key_norm = F.normalize(key, dim=-1)
mem_norm = F.normalize(self.memory, dim=-1)
similarity = torch.matmul(key_norm, mem_norm.t())
return F.softmax(beta * similarity, dim=-1)
def write(self, write_key, write_vector, erase_vector, write_strength):
write_weights = self.content_addressing(write_key, write_strength)
erase = torch.outer(write_weights.squeeze(), erase_vector.squeeze())
self.memory = (self.memory * (1 - erase)).detach()
add = torch.outer(write_weights.squeeze(), write_vector.squeeze())
self.memory = (self.memory + add).detach()
self.usage = (0.99 * self.usage + write_weights.squeeze()).detach()
def read(self, read_keys, read_strengths):
reads = []
for i in range(self.num_read_heads):
weights = self.content_addressing(read_keys[i], read_strengths[i])
read_vector = torch.matmul(weights, self.memory)
reads.append(read_vector)
return torch.cat(reads, dim=-1)
class MemoryController(nn.Module):
def __init__(self, input_dim, hidden_dim, memory_config: MemoryConfig):
super().__init__()
self.hidden_dim = hidden_dim
self.memory_config = memory_config
self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
total_read_dim = memory_config.num_read_heads * memory_config.memory_dim
self.read_keys = nn.Linear(hidden_dim, memory_config.num_read_heads * memory_config.memory_dim)
self.read_strengths = nn.Linear(hidden_dim, memory_config.num_read_heads)
self.write_key = nn.Linear(hidden_dim, memory_config.memory_dim)
self.write_vector = nn.Linear(hidden_dim, memory_config.memory_dim)
self.erase_vector = nn.Linear(hidden_dim, memory_config.memory_dim)
self.write_strength = nn.Linear(hidden_dim, 1)
self.output = nn.Linear(hidden_dim + total_read_dim, input_dim)
def forward(self, x, memory_bank, hidden=None):
lstm_out, hidden = self.lstm(x.unsqueeze(0), hidden)
controller_state = lstm_out.squeeze(0)
read_k = self.read_keys(controller_state).view(self.memory_config.num_read_heads, -1)
read_s = F.softplus(self.read_strengths(controller_state))
write_k = self.write_key(controller_state)
write_v = torch.tanh(self.write_vector(controller_state))
erase_v = torch.sigmoid(self.erase_vector(controller_state))
write_s = F.softplus(self.write_strength(controller_state))
read_vectors = memory_bank.read(read_k, read_s)
memory_bank.write(write_k, write_v, erase_v, write_s)
combined = torch.cat([controller_state, read_vectors], dim=-1)
output = self.output(combined)
return output, hiddenЭти классы реализуют основную дифференцируемую память: адресация по содержимому, мягкие операции чтения/записи и взаимодействие контроллера с памятью.
Приоритизированный реплей и мета-ученик
Для повторного использования важных прошлых опытов реализован приоритизированный буфер реплея, который выбирает примеры пропорционально их приоритетам. MetaLearner показывает идею MAML-подобной быстрой адаптации параметров контроллера на поддерживающем наборе данных.
class ExperienceReplay:
def __init__(self, capacity=10000, alpha=0.6):
self.capacity = capacity
self.alpha = alpha
self.buffer = deque(maxlen=capacity)
self.priorities = deque(maxlen=capacity)
def push(self, experience, priority=1.0):
self.buffer.append(experience)
self.priorities.append(priority ** self.alpha)
def sample(self, batch_size, beta=0.4):
if len(self.buffer) == 0:
return [], []
probs = np.array(self.priorities)
probs = probs / probs.sum()
indices = np.random.choice(len(self.buffer), min(batch_size, len(self.buffer)), p=probs, replace=False)
samples = [self.buffer[i] for i in indices]
weights = (len(self.buffer) * probs[indices]) ** (-beta)
weights = weights / weights.max()
return samples, torch.FloatTensor(weights)
class MetaLearner(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def adapt(self, support_x, support_y, num_steps=5, lr=0.01):
adapted_params = {name: param.clone() for name, param in self.model.named_parameters()}
for _ in range(num_steps):
pred, _ = self.model(support_x, self.model.memory_bank)
loss = F.mse_loss(pred, support_y)
grads = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
adapted_params = {name: param - lr * grad for (name, param), grad in zip(adapted_params.items(), grads)}
return adapted_paramsВажно: для продакшн-реализации MAML потребуется аккуратно подменять параметры и выполнять прямые проходы с адаптированными весами.
Агент непрерывного обучения
Объединяем память, контроллер, реплей и мета-ученик в ContinualLearningAgent. Обучение включает обычные шаги градиентного спуска с добавлением реплей-выборок, усечением градиентов и последующим обновлением оптимизатором; оценка переводит контроллер в режим eval и вычисляет среднюю квадратичную ошибку.
class ContinualLearningAgent:
def __init__(self, input_dim=64, hidden_dim=128):
self.config = MemoryConfig()
self.memory_bank = NeuralMemoryBank(self.config)
self.controller = MemoryController(input_dim, hidden_dim, self.config)
self.replay_buffer = ExperienceReplay(capacity=5000)
self.meta_learner = MetaLearner(self.controller)
self.optimizer = torch.optim.Adam(self.controller.parameters(), lr=0.001)
self.task_history = []
def train_step(self, x, y, use_replay=True):
self.optimizer.zero_grad()
pred, _ = self.controller(x, self.memory_bank)
current_loss = F.mse_loss(pred, y)
self.replay_buffer.push((x.detach().clone(), y.detach().clone()), priority=current_loss.item() + 1e-6)
total_loss = current_loss
if use_replay and len(self.replay_buffer.buffer) > 16:
samples, weights = self.replay_buffer.sample(8)
for (replay_x, replay_y), weight in zip(samples, weights):
with torch.enable_grad():
replay_pred, _ = self.controller(replay_x, self.memory_bank)
replay_loss = F.mse_loss(replay_pred, replay_y)
total_loss = total_loss + 0.3 * replay_loss * weight
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.controller.parameters(), 1.0)
self.optimizer.step()
return total_loss.item()
def evaluate(self, test_data):
self.controller.eval()
total_error = 0
with torch.no_grad():
for x, y in test_data:
pred, _ = self.controller(x, self.memory_bank)
total_error += F.mse_loss(pred, y).item()
self.controller.train()
return total_error / len(test_data)Синтетические задачи и демонстрация
Генератор задач создает разные задачи (синус, косинус, tanh) для проверки способности агента к непрерывной адаптации. Цикл демонстрации последовательно обучает на задачах, оценивает на предыдущих и визуализирует состояние памяти и кривые ошибок.
def create_task_data(task_id, num_samples=100):
torch.manual_seed(task_id)
x = torch.randn(num_samples, 64)
if task_id == 0:
y = torch.sin(x.mean(dim=1, keepdim=True).expand(-1, 64))
elif task_id == 1:
y = torch.cos(x.mean(dim=1, keepdim=True).expand(-1, 64)) * 0.5
else:
y = torch.tanh(x * 0.5 + task_id)
return [(x[i], y[i]) for i in range(num_samples)]
def run_continual_learning_demo():
print(" Neural Memory Agent - Continual Learning Demo\n")
print("=" * 60)
agent = ContinualLearningAgent()
num_tasks = 4
results = {'tasks': [], 'without_memory': [], 'with_memory': []}
for task_id in range(num_tasks):
print(f"\n Learning Task {task_id + 1}/{num_tasks}")
train_data = create_task_data(task_id, num_samples=50)
test_data = create_task_data(task_id, num_samples=20)
for epoch in range(20):
total_loss = 0
for x, y in train_data:
loss = agent.train_step(x, y, use_replay=(task_id > 0))
total_loss += loss
if epoch % 5 == 0:
avg_loss = total_loss / len(train_data)
print(f" Epoch {epoch:2d}: Loss = {avg_loss:.4f}")
print(f"\n Evaluation on all tasks:")
for eval_task_id in range(task_id + 1):
eval_data = create_task_data(eval_task_id, num_samples=20)
error = agent.evaluate(eval_data)
print(f" Task {eval_task_id + 1}: Error = {error:.4f}")
if eval_task_id == task_id:
results['tasks'].append(eval_task_id + 1)
results['with_memory'].append(error)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax = axes[0]
memory_matrix = agent.memory_bank.memory.detach().numpy()
im = ax.imshow(memory_matrix, aspect='auto', cmap='viridis')
ax.set_title('Neural Memory Bank State', fontsize=14, fontweight='bold')
ax.set_xlabel('Memory Dimension')
ax.set_ylabel('Memory Slots')
plt.colorbar(im, ax=ax)
ax = axes[1]
ax.plot(results['tasks'], results['with_memory'], marker='o', linewidth=2, markersize=8, label='With Memory Replay')
ax.set_title('Continual Learning Performance', fontsize=14, fontweight='bold')
ax.set_xlabel('Task Number')
ax.set_ylabel('Test Error')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('neural_memory_results.png', dpi=150, bbox_inches='tight')
print("\n Results saved to 'neural_memory_results.png'")
plt.show()
print("\n" + "=" * 60)
print(" Key Insights:")
print(" • Memory bank stores compressed task representations")
print(" • Experience replay mitigates catastrophic forgetting")
print(" • Agent maintains performance on earlier tasks")
print(" • Content-based addressing enables efficient retrieval")
if __name__ == "__main__":
run_continual_learning_demo()Практические выводы
- Приоритизированный реплей помогает модели повторно посещать примеры с высокой ошибкой, которые наиболее подвержены забыванию.
- Адресация по содержимому эффективна для извлечения релевантных представлений при наличии общей структуры между задачами.
- Мета-обучение предоставляет механизм для быстрой адаптации, но требует аккуратной реализации для контроля вычислительных затрат и устойчивости градиентов.
- Визуализация матрицы памяти и кривых ошибок по задачам даёт полезные диагностические сигналы о сохранении знаний при добавлении новых задач.
Switch Language
Read this article in English