<НА ГЛАВНУЮ

Обучение model-native агента: внутреннее планирование, память и использование нескольких инструментов через end-to-end RL

'Компактный нейросетевой агент учится планировать, хранить и комбинировать символические операции end-to-end с помощью RL, показывая сходящиеся многошаговые стратегии на синтетических арифметических задачах.'

Задача и символические инструменты

В этом руководстве показано, как построить компактного model-native агента, который интегрирует планирование, кратковременную память и композицию «инструментов» внутри одной нейронной модели. Экспериментальная среда — небольшой синтетический арифметический мир, где дискретные действия выступают символическими инструментами (умножить, сложить, вычесть, сохранить, восстановить, ответ). Агент должен выбрать последовательность инструментов, чтобы получить правильный ответ. Учебная кривая (curriculum) помогает модели открыть многошаговые стратегии и использование памяти с помощью end-to-end reinforcement learning.

Окружение и семантика инструментов

Окружение реализует выбор контекстов и функцию оценки последовательностей действий, которая интерпретирует инструменты и возвращает награду. Ниже код окружения и токенов, использованный в эксперименте:

import math, random, torch, torch.nn as nn, torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"; torch.manual_seed(0); random.seed(0)
V = 18; CTX = 10; MUL, ADD, SUB, ANS, STO, RCL, EOS = 11, 12, 13, 14, 15, 16, 17
tok2str = {**{i: str(i) for i in range(10)}, CTX:"[CTX]", MUL:"[MUL]", ADD:"[ADD]", SUB:"[SUB]", ANS:"[ANS]", STO:"[STO]", RCL:"[RCL]", EOS:"[EOS]"}
 
 
class ToolEnv:
   def __init__(self, max_steps=7):
       self.max_steps = max_steps
   def sample(self, stage):
       a,b,c,d,e = [random.randint(0,9) for _ in range(5)]
       if stage==0: ctx=[a,b,c]; target=a*b+c
       elif stage==1: ctx=[a,b,c,d]; target=(a*b+c)-d
       else: ctx=[a,b,c,d,e]; target=(a*b+c)-(d*e)
       return ctx, target, (a,b,c,d,e)
   def step_seq(self, actions, abc, stage):
       a,b,c,d,e = abc; last=None; mem=None; steps=0; shaped=0.0
       goal0=a*b; goal1=goal0+c; goal2=goal1-d; goal3=d*e; goal4=goal1-goal3
       for act in actions:
           steps+=1
           if act==MUL: last=(a*b if last is None else last*(d if stage>0 else 1))
           elif act==ADD and last is not None: last+=c
           elif act==SUB and last is not None:
               last -= (e if stage==2 and mem=="use_d" else (d if stage>0 else 0))
           elif act==STO: mem="use_d" if stage>=1 else "ok"
           elif act==RCL and mem is not None:
               last = (d*e) if (stage==2 and mem=="use_d") else (last if last else 0)
           elif act==ANS:
               target=[goal0,goal1,goal2,goal4][stage] if stage==2 else [goal0,goal1,goal2][stage]
               correct=(last==target)
               if stage==0: shaped += 0.25*(last==goal0)+0.5*(last==goal1)
               if stage==1: shaped += 0.25*(last==goal0)+0.5*(last==goal1)+0.75*(last==goal2)
               if stage==2: shaped += 0.2*(last==goal0)+0.4*(last==goal1)+0.6*(last==goal4)+0.6*(last==goal3)
               return (1.0 if correct else 0.0)+0.2*shaped, steps
           if steps>=self.max_steps: break
       return 0.0, steps

Окружение задаёт три стадии сложности: простая композиция, вычитание после композиции и более сложная композиция с хранением и восстановлением промежуточного результата. Награда комбинирует бинарную правильность и поощрения за достижение промежуточных целей.

Архитектура модели: stage-aware actor-critic

Политика — actor-critic с эмбеддингами и небольшой GRU. Вектор стадии добавляется к эмбеддингу контекста, чтобы одна сеть могла адаптироваться к разной сложности задач. На выходе сеть генерирует фиксированную по длине последовательность токенов-инструментов.

class ActorCritic(nn.Module):
   def __init__(self,V,d=96,nstage=3):
       super().__init__()
       self.emb=nn.Embedding(V,d); self.stage_emb=nn.Embedding(nstage,d)
       self.rnn=nn.GRU(d,d,1,batch_first=True); self.pi=nn.Linear(d,V); self.v=nn.Linear(d,1)
   def forward(self,ctx,stage,max_len=6,greedy=False):
       B=ctx.shape[0]; ce=self.emb(ctx).mean(1)+self.stage_emb(stage).unsqueeze(1)
       h=torch.tanh(ce.mean(1)).unsqueeze(0); inp=self.emb(torch.full((B,1),CTX,device=device))
       acts,logps,ents,vals=[],[],[],[]
       for _ in range(max_len):
           out,h=self.rnn(inp,h); val=self.v(out[:,-1]); logits=self.pi(out[:,-1])
           pi=F.log_softmax(logits,dim=-1).exp(); ent=-(pi*torch.log(pi+1e-9)).sum(1)
           a=torch.argmax(logits,1) if greedy else torch.distributions.Categorical(pi).sample()
           logp=F.log_softmax(logits,dim=-1).gather(1,a.unsqueeze(1)).squeeze(1)
           inp=self.emb(a.unsqueeze(1))
           acts.append(a); logps.append(logp); ents.append(ent); vals.append(val.squeeze(1))
       return torch.stack(acts,1), torch.stack(logps,1), torch.stack(ents,1), torch.stack(vals,1)

Ключевые моменты: компактные эмбеддинги и GRU, эмбеддинг стадии для условного поведения, смешанный словарь цифровых токенов и инструментов.

Цикл обучения и обновление RL

Обучение использует A2C-подобный градиент с энтропийной регуляризацией. Ниже — батчинг, паддинг, оценка траекторий и шаг обновления:

env=ToolEnv(); net=ActorCritic(V).to(device)
opt=torch.optim.Adam(net.parameters(),lr=3e-4)
def pad_batch(ctxs):
   L=max(len(c)+1 for c in ctxs)
   out=torch.full((len(ctxs),L),EOS,dtype=torch.long,device=device)
   for i,c in enumerate(ctxs): out[i,:len(c)+1]=torch.tensor(c+[CTX],device=device)
   return out
def run_batch(stage,batch=128,train=True,greedy=False):
   ctxs=[]; metas=[]
   for _ in range(batch):
       c,t,abc=env.sample(stage); ctxs.append(c); metas.append((t,abc))
   ctx=pad_batch(ctxs); stage_t=torch.full((batch,),stage,device=device,dtype=torch.long)
   acts,logps,ents,vals=net(ctx,stage_t,max_len=6,greedy=greedy)
   rewards=[]
   for i in range(batch):
       traj = acts[i].tolist()
       abc = metas[i][1]
       r,_ = env.step_seq(traj,abc,stage)
       rewards.append(r)
   R=torch.tensor(rewards,device=device).float()
   adv=(R-vals.sum(1)).detach()
   if not train: return R.mean().item(), 0.0
   pg=-(logps.sum(1)*adv).mean(); vloss=F.mse_loss(vals.sum(1),R); ent=-ents.mean()
   loss=pg+0.5*vloss+0.01*ent
   opt.zero_grad(); loss.backward(); nn.utils.clip_grad_norm_(net.parameters(),1.0); opt.step()
   return R.mean().item(), loss.item()

Паддинг гарантирует, что переменной длины контексты подаются как тензор фиксированного размера. Комбинация политической и value-ошибки плюс энтропия стимулирует исследование.

Курс обучения и инспекция результатов

Простая учебная стратегия переводит задачи от лёгких к сложным, а затем проводится периодическая оценка на всех стадиях:

print("Training…")
stages=[0,0,0,1,1,2]
for ep in range(1,61):
   stage=stages[min((ep-1)//10,len(stages)-1)]
   acc,loss=run_batch(stage,batch=192,train=True)
   if ep%5==0:
       with torch.no_grad():
           evals=[run_batch(s,train=False,greedy=True)[0] for s in [0,1,2]]
       print(f"ep={ep:02d} stage={stage} acc={acc:.3f} | eval T0={evals[0]:.3f} "
             f"T1={evals[1]:.3f} T2={evals[2]:.3f} loss={loss:.3f}")

Наблюдается, что простые стадии сходятся первыми, а более сложные улучшаются по мере того, как модель находит способы хранить промежуточные результаты и комбинировать операции.

Пробные прогонки и выводы

Инспекция жадных траекторий и вывод точности на тестовых батчах помогают понять, как модель кодирует планы и память:

def explain(stage):
   c,t,abc=env.sample(stage)
   ctx=pad_batch([c]); stage_t=torch.tensor([stage],device=device)
   with torch.no_grad(): a,_,_,_=net(ctx,stage_t,greedy=True)
   seq=[tok2str[x] for x in a[0].tolist()]
   r,_=env.step_seq(a[0].tolist(),abc,stage)
   return dict(stage=stage,ctx=c,target=t,actions=" ".join(seq),reward=round(float(r),2))
with torch.no_grad():
   for s in [0,1,2]:
       print(f"\nStage {s} samples:")
       for _ in range(5): print(explain(s))
with torch.no_grad():
   finals=[run_batch(s,train=False,greedy=True,batch=1000)[0] for s in [0,1,2]]
print(f"\nFinal greedy accuracies → T0={finals[0]:.3f}, T1={finals[1]:.3f}, T2={finals[2]:.3f}")

Выводы: даже компактная нейросеть может выработать внутренние механизмы планирования и использования инструментов при правильно поставленной задаче, функции награды и куррикулуме.

🇬🇧

Switch Language

Read this article in English

Switch to English