Write Once, Run Everywhere: Ivy for Framework-Agnostic ML, Transpile and Benchmark

Unifying Machine Learning with Ivy

Ivy makes it possible to write deep learning code once and run it across NumPy, PyTorch, TensorFlow, and JAX without changing the implementation. Below are practical demonstrations: a fully framework-agnostic neural network, examples of transpilation and unified APIs, advanced features like Containers and tracing, and performance benchmarks.

Framework-agnostic neural network example

A simple neural network written purely with Ivy operations runs seamlessly on multiple backends. The example below shows environment setup, basic imports, and the model implementation.

!pip install -q ivy tensorflow torch jax jaxlib

import ivy
import numpy as np
import time


print(f"Ivy version: {ivy.__version__}")



class IvyNeuralNetwork:
    """A simple two-layer neural network written purely in Ivy.

    Every operation is an ``ivy.*`` call, so the same instance runs
    unchanged on any backend selected via ``ivy.set_backend`` (NumPy,
    PyTorch, TensorFlow, JAX).
    """

    def __init__(self, input_dim=4, hidden_dim=8, output_dim=3):
        """Initialize weights uniformly in [-0.5, 0.5] and biases at zero.

        Args:
            input_dim: Number of input features.
            hidden_dim: Width of the single hidden layer.
            output_dim: Number of output classes.
        """
        self.w1 = ivy.random_uniform(shape=(input_dim, hidden_dim), low=-0.5, high=0.5)
        self.b1 = ivy.zeros((hidden_dim,))
        self.w2 = ivy.random_uniform(shape=(hidden_dim, output_dim), low=-0.5, high=0.5)
        self.b2 = ivy.zeros((output_dim,))

    def forward(self, x):
        """Forward pass: linear -> ReLU -> linear -> softmax."""
        h = ivy.relu(ivy.matmul(x, self.w1) + self.b1)
        out = ivy.matmul(h, self.w2) + self.b2
        return ivy.softmax(out)

    def train_step(self, x, y, lr=0.01):
        """One step of SGD with manually derived gradients.

        Fixes over the original version: gradients now propagate to the
        first layer as well (previously only w2/b2 were updated, leaving
        the hidden layer frozen at its random initialization), and the
        hidden activations are computed once instead of twice.

        Args:
            x: Batch of inputs, shape (batch, input_dim).
            y: One-hot targets, shape (batch, output_dim).
            lr: Learning rate.

        Returns:
            The scalar cross-entropy loss before the update.
        """
        # Forward pass, keeping the intermediates needed for backprop.
        z1 = ivy.matmul(x, self.w1) + self.b1
        h = ivy.relu(z1)
        out = ivy.matmul(h, self.w2) + self.b2
        pred = ivy.softmax(out)

        # Cross-entropy loss; the epsilon guards against log(0).
        loss = -ivy.mean(ivy.sum(y * ivy.log(pred + 1e-8), axis=-1))

        batch = x.shape[0]
        # Combined softmax + cross-entropy gradient w.r.t. the logits.
        pred_error = pred - y

        # Output-layer gradients.
        dw2 = ivy.matmul(ivy.permute_dims(h, axes=(1, 0)), pred_error) / batch
        db2 = ivy.mean(pred_error, axis=0)

        # Backpropagate into the hidden layer through the ReLU
        # (gradient is zero wherever the pre-activation was <= 0).
        dh = ivy.matmul(pred_error, ivy.permute_dims(self.w2, axes=(1, 0)))
        dz1 = ivy.where(z1 > 0, dh, ivy.zeros_like(dh))
        dw1 = ivy.matmul(ivy.permute_dims(x, axes=(1, 0)), dz1) / batch
        db1 = ivy.mean(dz1, axis=0)

        # SGD update. dz1 was computed from the pre-update w2, so the
        # order of these assignments does not affect correctness.
        self.w1 = self.w1 - lr * dw1
        self.b1 = self.b1 - lr * db1
        self.w2 = self.w2 - lr * dw2
        self.b2 = self.b2 - lr * db2

        return loss



def demo_framework_agnostic_network():
    """Train and evaluate the same Ivy network on every available backend."""
    header = "=" * 70
    print("\n" + header)
    print("PART 1: Framework-Agnostic Neural Network")
    print(header)

    # Shared NumPy dataset: 100 samples, 4 features, one-hot labels for 3 classes.
    features = np.random.randn(100, 4).astype(np.float32)
    labels = np.eye(3)[np.random.randint(0, 3, 100)].astype(np.float32)

    results = {}

    for name in ['numpy', 'torch', 'tensorflow', 'jax']:
        try:
            ivy.set_backend(name)

            if name == 'jax':
                import jax
                jax.config.update('jax_enable_x64', True)

            print(f"\n Running with {name.upper()} backend...")

            x_arr = ivy.array(features)
            y_arr = ivy.array(labels)

            model = IvyNeuralNetwork()

            t0 = time.time()
            for _ in range(50):
                loss = model.train_step(x_arr, y_arr, lr=0.1)
            elapsed = time.time() - t0

            # Accuracy = fraction of argmax predictions matching argmax targets.
            preds = model.forward(x_arr)
            matches = ivy.argmax(preds, axis=-1) == ivy.argmax(y_arr, axis=-1)
            accuracy = ivy.mean(ivy.astype(matches, 'float32'))

            results[name] = {
                'loss': float(ivy.to_numpy(loss)),
                'accuracy': float(ivy.to_numpy(accuracy)),
                'time': elapsed,
            }

            print(f"   Final Loss: {results[name]['loss']:.4f}")
            print(f"   Accuracy: {results[name]['accuracy']:.2%}")
            print(f"   Time: {results[name]['time']:.3f}s")

        except Exception as e:
            print(f"    {name} error: {str(e)[:80]}")
            results[name] = None

    ivy.unset_backend()
    return results

The demo builds and trains the network entirely with Ivy and then runs it on NumPy, PyTorch, TensorFlow, and JAX. Results show loss, accuracy, and elapsed time per backend—highlighting Ivy’s portability.

Transpilation and reproducing computations across frameworks

Ivy supports transpilation workflows and, more commonly, provides a unified API so the same computation can be expressed once and evaluated on different backends. Below is the demonstration function used in the tutorial.


def demo_transpilation():
    """Reproduce a PyTorch computation on other backends via Ivy's unified API."""
    bar = "=" * 70
    print("\n" + bar)
    print("PART 2: Framework Transpilation")
    print(bar)

    try:
        import torch
        import tensorflow as tf  # imported up front so a missing TF fails early

        def pytorch_computation(x):
            """A simple PyTorch computation."""
            return torch.mean(torch.relu(x * 2.0 + 1.0))

        sample = torch.randn(10, 5)

        print("\n Original PyTorch function:")
        torch_out = pytorch_computation(sample)
        print(f"   PyTorch result: {torch_out.item():.6f}")

        print("\n Transpilation Demo:")
        print("   Note: ivy.transpile() is powerful but complex.")
        print("   It works best with traced/compiled functions.")
        print("   For simple demonstrations, we'll show the unified API instead.")

        print("\n Equivalent computation across frameworks:")
        sample_np = sample.numpy()

        # Same expression, evaluated once per backend.
        for label, name in [('NumPy', 'numpy'), ('TensorFlow', 'tensorflow'), ('JAX', 'jax')]:
            ivy.set_backend(name)
            if name == 'jax':
                import jax
                jax.config.update('jax_enable_x64', True)
            value = ivy.mean(ivy.relu(ivy.array(sample_np) * 2.0 + 1.0))
            print(f"   {label} result: {float(ivy.to_numpy(value)):.6f}")

        print("\n    All results match within numerical precision!")

        ivy.unset_backend()

    except Exception as e:
        print(f" Demo error: {str(e)[:80]}")

Rather than directly transpiling every function, the tutorial emphasizes using Ivy’s unified API to reproduce computations across different backends and confirms results agree up to numerical precision.

Examples of the unified API

Ivy exposes a consistent set of operations so the same code runs across supported frameworks. The following function shows a range of operations executed across backends.


def demo_unified_api():
    """Run a battery of common operations on every backend via Ivy's unified API."""
    bar = "=" * 70
    print("\n" + bar)
    print("PART 3: Unified API Across Frameworks")
    print(bar)

    # (label, callable) pairs exercising distinct parts of the API surface.
    ops = [
        ("Matrix Multiplication", lambda x: ivy.matmul(x, ivy.permute_dims(x, axes=(1, 0)))),
        ("Element-wise Operations", lambda x: ivy.add(ivy.multiply(x, x), 2)),
        ("Reductions", lambda x: ivy.mean(ivy.sum(x, axis=0))),
        ("Neural Net Ops", lambda x: ivy.mean(ivy.relu(x))),
        ("Statistical Ops", lambda x: ivy.std(x)),
        ("Broadcasting", lambda x: ivy.multiply(x, ivy.array([1.0, 2.0, 3.0, 4.0]))),
    ]

    data = np.random.randn(5, 4).astype(np.float32)

    for title, fn in ops:
        print(f"\n {title}:")

        for name in ['numpy', 'torch', 'tensorflow', 'jax']:
            try:
                ivy.set_backend(name)

                if name == 'jax':
                    import jax
                    jax.config.update('jax_enable_x64', True)

                out = ivy.to_numpy(fn(ivy.array(data)))

                if out.shape == ():
                    print(f"   {name:12s}: scalar value = {float(out):.4f}")
                else:
                    print(f"   {name:12s}: shape={out.shape}, mean={np.mean(out):.4f}")

            except Exception as e:
                print(f"   {name:12s}:  {str(e)[:60]}")

        # NOTE(review): this pops only the most recent backend per operation,
        # matching the original placement.
        ivy.unset_backend()

This demonstrates how matrix ops, element-wise math, reductions, and broadcasting behave consistently across backends when using Ivy.

Advanced features: Containers, API compliance, complex graphs

Ivy provides nested data structures (ivy.Container), operations that follow the Array API standard, and the ability to chain operations into more complex computation graphs. The example below shows how containers and chained operations are exercised.


def demo_advanced_features():
    """Tour Ivy Containers, Array-API-standard ops, and chained operation graphs."""
    bar = "=" * 70
    print("\n" + bar)
    print("PART 4: Advanced Ivy Features")
    print(bar)

    # --- Nested parameter trees via ivy.Container ----------------------------
    print("\n Ivy Containers - Nested Data Structures:")
    try:
        ivy.set_backend('torch')

        params = ivy.Container({
            'layer1': {'weights': ivy.random_uniform(shape=(4, 8)), 'bias': ivy.zeros((8,))},
            'layer2': {'weights': ivy.random_uniform(shape=(8, 3)), 'bias': ivy.zeros((3,))}
        })

        print(f"   Container keys: {list(params.keys())}")
        print(f"   Layer1 weight shape: {params['layer1']['weights'].shape}")
        print(f"   Layer2 bias shape: {params['layer2']['bias'].shape}")

        # cont_map applies the callable to every leaf tensor in the tree;
        # the second argument (the key chain) is ignored here.
        doubled = params.cont_map(lambda x, _: x * 2.0)
        print("    Applied scaling to all tensors in container")

    except Exception as e:
        print(f"    Container demo: {str(e)[:80]}")

    # --- Array API standard compliance check, one backend at a time ----------
    print("\n Array API Standard Compliance:")
    working = []
    for name in ['numpy', 'torch', 'tensorflow', 'jax']:
        try:
            ivy.set_backend(name)

            if name == 'jax':
                import jax
                jax.config.update('jax_enable_x64', True)

            a = ivy.array([1.0, 2.0, 3.0])
            b = ivy.array([4.0, 5.0, 6.0])
            ivy.sqrt(ivy.square(a) + ivy.square(b))  # element-wise L2 norm
            print(f"   {name:12s}: L2 norm operations work ")
            working.append(name)
        except Exception as e:
            print(f"   {name:12s}: {str(e)[:50]}")

    print(f"\n   Successfully tested {len(working)} backends")

    # --- A small multi-step computation graph --------------------------------
    print("\n Complex Multi-step Operations:")
    try:
        ivy.set_backend('torch')

        data = ivy.random_uniform(shape=(10, 5), low=0, high=1)
        gram = ivy.matmul(data, ivy.permute_dims(data, axes=(1, 0)))
        out = ivy.mean(ivy.relu(gram), axis=0)

        print("   Chained operations (matmul → relu → mean)")
        print(f"   Input shape: (10, 5), Output shape: {out.shape}")
        print("    Complex operation graph executed successfully")

    except Exception as e:
        print(f"    {str(e)[:80]}")

    ivy.unset_backend()

These features make it easy to structure models, ensure API compatibility across ecosystems, and compose multi-step computations that can be executed on different backends.

Performance benchmarking across backends

Ivy also lets you benchmark identical workloads across frameworks, which helps determine the best backend for a particular workload. Below are benchmarking utilities and the performance demo.


def benchmark_operation(op_func, x, iterations=50):
    """Time repeated executions of an operation.

    Args:
        op_func: Callable applied to ``x`` once per iteration.
        x: Input passed to ``op_func``.
        iterations: Number of timed repetitions.

    Returns:
        Total elapsed wall time in seconds for all iterations.
    """
    # perf_counter is monotonic and higher-resolution than time.time(),
    # making it the appropriate clock for benchmarking.
    start = time.perf_counter()
    for _ in range(iterations):
        op_func(x)  # result intentionally discarded; only timing matters
    return time.perf_counter() - start



def demo_performance():
    """Benchmark one identical Ivy workload on every backend."""
    bar = "=" * 70
    print("\n" + bar)
    print("PART 5: Performance Benchmarking")
    print(bar)

    data = np.random.randn(100, 100).astype(np.float32)

    def complex_operation(x):
        """Gram matrix -> ReLU -> column means -> scalar sum."""
        gram = ivy.relu(ivy.matmul(x, ivy.permute_dims(x, axes=(1, 0))))
        return ivy.sum(ivy.mean(gram, axis=0))

    print("\n Benchmarking matrix operations (50 iterations):")
    print("   Operation: matmul → relu → mean → sum")

    for name in ['numpy', 'torch', 'tensorflow', 'jax']:
        try:
            ivy.set_backend(name)

            if name == 'jax':
                import jax
                jax.config.update('jax_enable_x64', True)

            x_arr = ivy.array(data)

            # Warm-up call so any lazy initialization is excluded from timing.
            _ = complex_operation(x_arr)

            elapsed = benchmark_operation(complex_operation, x_arr, iterations=50)
            print(f"   {name:12s}: {elapsed:.4f}s ({elapsed/50*1000:.2f}ms per op)")

        except Exception as e:
            print(f"   {name:12s}:  {str(e)[:60]}")

    ivy.unset_backend()

Running the full tutorial script

The tutorial script ties together all demos: multi-backend training, transpilation/unified API examples, advanced features, and benchmarking.


if __name__ == "__main__":
   print("""
   ╔════════════════════════════════════════════════════════════════════╗
   ║          Advanced Ivy Tutorial - Framework-Agnostic ML             ║
   ║                  Write Once, Run Everywhere!                       ║
   ╚════════════════════════════════════════════════════════════════════╝
   """)
  
   results = demo_framework_agnostic_network()
   demo_transpilation()
   demo_unified_api()
   demo_advanced_features()
   demo_performance()
  
   print("\n" + "="*70)
   print(" Tutorial Complete!")
   print("="*70)
   print("\n Key Takeaways:")
   print("   1. Ivy enables writing ML code once that runs on any framework")
   print("   2. Same operations work identically across NumPy, PyTorch, TF, JAX")
   print("   3. Unified API provides consistent operations across backends")
   print("   4. Switch backends dynamically for optimal performance")
   print("   5. Containers help manage complex nested model structures")
   print("\n Next Steps:")
   print("   - Build your own framework-agnostic models")
   print("   - Use ivy.Container for managing model parameters")
   print("   - Explore ivy.trace_graph() for computation graph optimization")
   print("   - Try different backends to find optimal performance")
   print("   - Check docs at: https://docs.ivy.dev/")
   print("="*70)

Together these examples show how Ivy abstracts framework differences while preserving performance and numerical behavior, enabling truly portable ML code.