Train a Model-Native Agent to Internalize Planning, Memory and Tool Use with End-to-End RL
A compact neural agent learns to plan, store and compose symbolic tools end-to-end with reinforcement learning, demonstrating emergent multi-step reasoning on synthetic arithmetic tasks.
Task and symbolic tools
This tutorial shows how to build a compact model-native agent that internalizes planning, short-term memory and multi-tool composition within a single neural network. The setup is a small synthetic arithmetic world where discrete actions act as symbolic "tools" (multiply, add, subtract, store, recall, answer) and the agent must plan a sequence of these tools to reach the correct result. A curriculum of increasingly complex stages encourages the agent to discover multi-step reasoning strategies and internal memory use via end-to-end reinforcement learning.
Environment and symbolic tool semantics
The environment samples numeric contexts and provides a step evaluator that interprets a tool sequence and returns a shaped reward. Below are the token vocabulary and the environment used in the experiments:
```python
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0); random.seed(0)

# Vocabulary: digits 0-9, a context marker, six tool tokens and EOS (18 tokens in total).
V = 18; CTX = 10
MUL, ADD, SUB, ANS, STO, RCL, EOS = 11, 12, 13, 14, 15, 16, 17
tok2str = {**{i: str(i) for i in range(10)}, CTX: "[CTX]", MUL: "[MUL]", ADD: "[ADD]",
           SUB: "[SUB]", ANS: "[ANS]", STO: "[STO]", RCL: "[RCL]", EOS: "[EOS]"}

class ToolEnv:
    def __init__(self, max_steps=7):
        self.max_steps = max_steps

    def sample(self, stage):
        """Draw five digits; return (visible context, target, all digits) for the stage."""
        a, b, c, d, e = [random.randint(0, 9) for _ in range(5)]
        if stage == 0:
            ctx, target = [a, b, c], a * b + c
        elif stage == 1:
            ctx, target = [a, b, c, d], (a * b + c) - d
        else:
            ctx, target = [a, b, c, d, e], (a * b + c) - (d * e)
        return ctx, target, (a, b, c, d, e)

    def step_seq(self, actions, abc, stage):
        """Interpret a tool sequence and return (reward, steps used); scoring stops at [ANS]."""
        a, b, c, d, e = abc
        last, mem, steps, shaped = None, None, 0, 0.0
        goal0 = a * b; goal1 = goal0 + c; goal2 = goal1 - d; goal3 = d * e; goal4 = goal1 - goal3
        target = [goal1, goal2, goal4][stage]      # matches the targets returned by sample()
        for act in actions:
            steps += 1
            if act == MUL:
                last = a * b if last is None else last * (d if stage > 0 else 1)
            elif act == ADD and last is not None:
                last += c
            elif act == SUB and last is not None:
                last -= e if (stage == 2 and mem == "use_d") else (d if stage > 0 else 0)
            elif act == STO:
                mem = "use_d" if stage >= 1 else "ok"
            elif act == RCL and mem is not None:
                if stage == 2 and mem == "use_d" and last is not None:
                    last -= d * e                  # recall the stored d*e term and subtract it
                else:
                    last = last if last else 0
            elif act == ANS:
                correct = (last == target)
                if stage == 0:
                    shaped += 0.25 * (last == goal0) + 0.5 * (last == goal1)
                if stage == 1:
                    shaped += 0.25 * (last == goal0) + 0.5 * (last == goal1) + 0.75 * (last == goal2)
                if stage == 2:
                    shaped += (0.2 * (last == goal0) + 0.4 * (last == goal1)
                               + 0.6 * (last == goal4) + 0.6 * (last == goal3))
                return (1.0 if correct else 0.0) + 0.2 * shaped, steps
            # Any other token (digits, [CTX], [EOS]) is a no-op.
            if steps >= self.max_steps:
                break
        return 0.0, steps                          # no [ANS] issued: no reward
```

This environment defines three stages of increasing difficulty: stage 0 asks for a*b + c, stage 1 subtracts d from that result, and stage 2 subtracts a second product, (a*b + c) - d*e, which requires intermediate storage and recall. The reward is 1.0 for a correct [ANS] plus 0.2 times a shaped bonus that credits intermediate subgoals such as a*b or a*b + c, so even a wrong final answer can earn a small learning signal.
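To make the reward concrete, here is a minimal pure-Python sketch of the stage-1 reward for a single final answer, using the stage-1 target (a*b + c) - d returned by sample() and the shaped weights from step_seq. The helper name stage1_reward and the digits a=3, b=4, c=5, d=2 are hypothetical, chosen only for illustration.

```python
def stage1_reward(answer, a, b, c, d):
    """Stage-1 reward for a final answer: correctness bonus plus 0.2 * shaped subgoal credit."""
    goal0 = a * b          # first subgoal
    goal1 = goal0 + c      # second subgoal
    goal2 = goal1 - d      # stage-1 target, (a*b + c) - d
    shaped = 0.25 * (answer == goal0) + 0.5 * (answer == goal1) + 0.75 * (answer == goal2)
    return (1.0 if answer == goal2 else 0.0) + 0.2 * shaped

# a=3, b=4, c=5, d=2 -> subgoals 12 and 17, target 15
for ans in (15, 17, 12, 99):
    print(ans, round(stage1_reward(ans, 3, 4, 5, 2), 2))
# 15 1.15
# 17 0.1
# 12 0.05
# 99 0.0
```

A correct answer earns about 1.15, while stopping one tool early (17) still earns 0.1: enough to steer exploration toward the right subgoals, but never enough to be mistaken for a correct answer.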
Model architecture: a stage-aware actor-critic
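The policy is an actor-critic built around a shared embedding layer and a small single-layer GRU. The mean context embedding plus a learned stage embedding initializes the GRU hidden state, so the same network can adapt its behavior to the task's complexity. The forward pass then emits a fixed-length sequence of discrete tokens (six by default), feeding each chosen token back as the next input and recording its log-probability, entropy and value estimate.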
```python
class ActorCritic(nn.Module):
    def __init__(self, V, d=96, nstage=3):
        super().__init__()
        self.emb = nn.Embedding(V, d)             # shared token embedding (digits and tools)
        self.stage_emb = nn.Embedding(nstage, d)  # one learned vector per curriculum stage
        self.rnn = nn.GRU(d, d, 1, batch_first=True)
        self.pi = nn.Linear(d, V)                 # policy head over the full vocabulary
        self.v = nn.Linear(d, 1)                  # value head

    def forward(self, ctx, stage, max_len=6, greedy=False):
        B = ctx.shape[0]
        # Per-sample summary: (B, d) mean context embedding + (B, d) stage embedding.
        ce = self.emb(ctx).mean(1) + self.stage_emb(stage)
        h = torch.tanh(ce).unsqueeze(0)           # initial GRU state, shape (1, B, d)
        inp = self.emb(torch.full((B, 1), CTX, dtype=torch.long, device=ctx.device))
        acts, logps, ents, vals = [], [], [], []
        for _ in range(max_len):
            out, h = self.rnn(inp, h)
            logits = self.pi(out[:, -1]); val = self.v(out[:, -1])
            dist = torch.distributions.Categorical(logits=logits)
            a = logits.argmax(-1) if greedy else dist.sample()
            acts.append(a); logps.append(dist.log_prob(a))
            ents.append(dist.entropy()); vals.append(val.squeeze(1))
            inp = self.emb(a.unsqueeze(1))        # feed the chosen token back in
        return torch.stack(acts, 1), torch.stack(logps, 1), torch.stack(ents, 1), torch.stack(vals, 1)
```

Key design choices: a small embedding and a single-layer GRU keep the model compact; an explicit stage embedding lets the same weights handle all three complexities; and the action vocabulary mixes numeric tokens with tool tokens. Digit, [CTX] and [EOS] actions do nothing in step_seq, so part of what the policy must learn is which tokens are actual tools.
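Combining the context summary with the stage embedding is a broadcasting question: both are (B, d) tensors, and an extra unsqueeze(1) on the stage embedding silently produces a (B, B, d) tensor that mixes every sample's context into every row. PyTorch follows NumPy's broadcasting semantics, so the shapes can be checked with a small NumPy sketch; the arrays below are random stand-ins for the embeddings, not the trained model.

```python
import numpy as np

B, d = 4, 8
rng = np.random.default_rng(0)
ctx_mean = rng.normal(size=(B, d))   # stand-in for emb(ctx).mean(1)
stage_vec = rng.normal(size=(B, d))  # stand-in for stage_emb(stage)

# Per-sample sum: row i combines sample i's context with sample i's stage.
ok = ctx_mean + stage_vec
print(ok.shape)                      # (4, 8)

# With an extra axis, (B, d) + (B, 1, d) broadcasts to (B, B, d).
mixed = ctx_mean + stage_vec[:, None, :]
print(mixed.shape)                   # (4, 4, 8)

# Averaging over axis 1 then gives every row the batch-mean context,
# so the per-sample context information is lost.
collapsed = mixed.mean(axis=1)
print(np.allclose(collapsed, ctx_mean.mean(axis=0) + stage_vec))  # True
```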
Training loop, batch handling and RL update
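Training uses an advantage actor-critic (A2C)-style update with entropy regularization. Each six-token rollout is treated as a single episode: its summed log-probability is weighted by the advantage, and the sum of the per-step value estimates serves as the baseline. The code below shows batching, padding, rollout evaluation and the update step.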
```python
env = ToolEnv(); net = ActorCritic(V).to(device)
opt = torch.optim.Adam(net.parameters(), lr=3e-4)

def pad_batch(ctxs):
    """Append [CTX] to each context and right-pad with EOS into a (B, L) LongTensor."""
    L = max(len(c) + 1 for c in ctxs)
    out = torch.full((len(ctxs), L), EOS, dtype=torch.long, device=device)
    for i, c in enumerate(ctxs):
        out[i, :len(c) + 1] = torch.tensor(c + [CTX], device=device)
    return out

def run_batch(stage, batch=128, train=True, greedy=False):
    ctxs, metas = [], []
    for _ in range(batch):
        c, t, abc = env.sample(stage); ctxs.append(c); metas.append((t, abc))
    ctx = pad_batch(ctxs)
    stage_t = torch.full((batch,), stage, device=device, dtype=torch.long)
    acts, logps, ents, vals = net(ctx, stage_t, max_len=6, greedy=greedy)
    # Score each sampled tool sequence in the symbolic environment.
    rewards = [env.step_seq(acts[i].tolist(), metas[i][1], stage)[0] for i in range(batch)]
    R = torch.tensor(rewards, device=device, dtype=torch.float32)
    # A correct answer earns >= 1.0; shaped bonuses alone always stay below 1.0.
    acc = (R >= 1.0).float().mean().item()
    if not train:
        return acc, 0.0
    baseline = vals.sum(1)                      # value estimate of the episode return
    adv = (R - baseline).detach()
    pg = -(logps.sum(1) * adv).mean()           # policy gradient with a learned baseline
    vloss = F.mse_loss(baseline, R)
    ent = -ents.mean()                          # minimizing -entropy encourages exploration
    loss = pg + 0.5 * vloss + 0.01 * ent
    opt.zero_grad(); loss.backward()
    nn.utils.clip_grad_norm_(net.parameters(), 1.0); opt.step()
    return acc, loss.item()
```

The update combines a policy-gradient loss, a value loss weighted by 0.5 and an entropy bonus weighted by 0.01, and clips the gradient norm to 1.0. Padding appends [CTX] to each context and right-pads with [EOS], so variable-length numeric contexts fit into one fixed-size tensor.
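The arithmetic of one update can be traced by hand. The sketch below recomputes the same loss terms in plain Python for a hypothetical batch of two episodes, with made-up returns, baselines, log-probabilities and entropies; only the weights 0.5 and 0.01 come from the training code, and a2c_loss is a hypothetical helper name.

```python
def a2c_loss(returns, baselines, seq_logps, entropies, v_coef=0.5, ent_coef=0.01):
    """Scalar mirror of the batch loss: one summed log-prob and one entropy per episode."""
    n = len(returns)
    adv = [r - b for r, b in zip(returns, baselines)]              # treated as constants
    pg = -sum(lp * a for lp, a in zip(seq_logps, adv)) / n          # policy-gradient term
    vloss = sum((b - r) ** 2 for r, b in zip(returns, baselines)) / n
    ent = -sum(entropies) / len(entropies)                          # negative mean entropy
    return pg + v_coef * vloss + ent_coef * ent, adv

# Episode 1 answered correctly (R = 1.15), episode 2 did not (R = 0.0).
loss, adv = a2c_loss(returns=[1.15, 0.0], baselines=[0.5, 0.5],
                     seq_logps=[-3.0, -4.0], entropies=[2.0, 2.0])
print([round(x, 2) for x in adv], round(loss, 4))   # [0.65, -0.5] 0.1231
```

The positive advantage of the correct episode pushes its log-probability up, the negative advantage of the failed one pushes its log-probability down, and the entropy term slightly favors higher-entropy policies.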
Curriculum training and observed learning dynamics
A simple curriculum schedules stages from easy to hard so the agent bootstraps from short reasoning chains to longer compositional behaviors. Training, with periodic greedy evaluation on all three stages, looks like this:
```python
print("Training…")
stages = [0, 0, 0, 1, 1, 2]          # curriculum: each entry covers 10 updates
for ep in range(1, 61):
    stage = stages[min((ep - 1) // 10, len(stages) - 1)]
    acc, loss = run_batch(stage, batch=192, train=True)
    if ep % 5 == 0:
        with torch.no_grad():
            evals = [run_batch(s, train=False, greedy=True)[0] for s in [0, 1, 2]]
        print(f"ep={ep:02d} stage={stage} acc={acc:.3f} | eval T0={evals[0]:.3f} "
              f"T1={evals[1]:.3f} T2={evals[2]:.3f} loss={loss:.3f}")
```

The schedule spends 30 updates on stage 0, 20 on stage 1 and 10 on stage 2. The periodic evaluation tracks greedy accuracy on every stage; the easiest stage is expected to improve first, with the harder stages following as the agent learns to compose tools and use memory.
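The stage lookup stages[min((ep - 1) // 10, len(stages) - 1)] maps each block of ten episodes to one curriculum entry. This small pure-Python check reuses the schedule from the training loop to count how many updates each stage receives; stage_for is a hypothetical helper name.

```python
from collections import Counter

stages = [0, 0, 0, 1, 1, 2]   # same schedule as the training loop

def stage_for(ep, stages=stages, block=10):
    """Curriculum stage used at 1-indexed episode `ep`."""
    return stages[min((ep - 1) // block, len(stages) - 1)]

schedule = [stage_for(ep) for ep in range(1, 61)]
print(Counter(schedule))                        # Counter({0: 30, 1: 20, 2: 10})
print(schedule[9], schedule[30], schedule[59])  # 0 1 2
```

Stage 2 therefore gets only ten of the sixty updates. Extending the loop past episode 60 keeps training on the last entry, because the min clamps the index.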
Inspecting trajectories and final evaluation
To understand what the model has learned, we run greedy rollouts on fresh samples and print the context, target, action tokens and reward for each.
```python
def explain(stage):
    """Sample one problem, roll out the greedy policy, and report what it did."""
    c, t, abc = env.sample(stage)
    ctx = pad_batch([c]); stage_t = torch.tensor([stage], device=device)
    with torch.no_grad():
        a, _, _, _ = net(ctx, stage_t, greedy=True)
    actions = a[0].tolist()
    r, _ = env.step_seq(actions, abc, stage)
    seq = " ".join(tok2str[x] for x in actions)
    return dict(stage=stage, ctx=c, target=t, actions=seq, reward=round(float(r), 2))

for s in [0, 1, 2]:
    print(f"\nStage {s} samples:")
    for _ in range(5):
        print(explain(s))

with torch.no_grad():
    finals = [run_batch(s, train=False, greedy=True, batch=1000)[0] for s in [0, 1, 2]]
print(f"\nFinal greedy accuracies → T0={finals[0]:.3f}, T1={finals[1]:.3f}, T2={finals[2]:.3f}")
```

These inspections show which plans the policy has actually settled on, for example whether it organizes tokens into effective tool sequences and stores an intermediate term before recalling it on multi-term problems. Tokens after the first [ANS] do not affect the reward, because step_seq returns as soon as it reads [ANS].
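When reading printed rollouts it can help to cut each one at the first [ANS], since later tokens are never scored. trim_at_answer below is a hypothetical helper that does this on already-decoded token strings such as the actions field returned by explain.

```python
def trim_at_answer(tokens, answer="[ANS]"):
    """Keep tokens up to and including the first answer token; drop the rest."""
    if answer in tokens:
        return tokens[: tokens.index(answer) + 1]
    return tokens  # no answer emitted: step_seq scores the whole rollout as 0

rollout = "[MUL] [ADD] [SUB] [ANS] [EOS] 7".split()
print(" ".join(trim_at_answer(rollout)))  # [MUL] [ADD] [SUB] [ANS]
```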
Practical takeaways
- Small, stage-aware recurrent actor-critic models can discover internal planning and memory usage when trained end-to-end with appropriate shaped rewards and a curriculum.
- Explicit symbolic tool tokens and a shaped reward that credits intermediate subgoals help the agent discover compositional programs inside its neural dynamics.
- This approach reduces reliance on external orchestration, enabling a single learned model to both plan and act.
Feel free to run the full code to reproduce the experiments and inspect trajectories on your machine.