Напиши один раз — запускай везде: Ivy для фреймворк-агностичного ML, транспиляции и бенчмарков

Объединение ML-разработки с Ivy

Ivy позволяет писать код глубокого обучения один раз и запускать его на NumPy, PyTorch, TensorFlow и JAX без изменений в реализации. Ниже — практические демонстрации: полностью фреймворк-агностичная нейронная сеть, примеры транспиляции и единый API, продвинутые возможности вроде Containers и трассировки, а также бенчмарки производительности.

Пример фреймворк-агностичной нейронной сети

Простая нейронная сеть, написанная только с использованием Ivy-операций, работает на нескольких бэкендах. Ниже показаны установка окружения, импорты и реализация модели.

!pip install -q ivy tensorflow torch jax jaxlib

import ivy
import numpy as np
import time


print(f"Ivy version: {ivy.__version__}")



class IvyNeuralNetwork:
   """A simple neural network written purely in Ivy that works with any backend."""
  
   def __init__(self, input_dim=4, hidden_dim=8, output_dim=3):
       self.w1 = ivy.random_uniform(shape=(input_dim, hidden_dim), low=-0.5, high=0.5)
       self.b1 = ivy.zeros((hidden_dim,))
       self.w2 = ivy.random_uniform(shape=(hidden_dim, output_dim), low=-0.5, high=0.5)
       self.b2 = ivy.zeros((output_dim,))
      
   def forward(self, x):
       """Forward pass using pure Ivy operations."""
       h = ivy.matmul(x, self.w1) + self.b1
       h = ivy.relu(h)
      
       out = ivy.matmul(h, self.w2) + self.b2
       return ivy.softmax(out)
  
   def train_step(self, x, y, lr=0.01):
       """Simple training step with manual gradients."""
       pred = self.forward(x)
      
       loss = -ivy.mean(ivy.sum(y * ivy.log(pred + 1e-8), axis=-1))
      
       pred_error = pred - y
      
       h_activated = ivy.relu(ivy.matmul(x, self.w1) + self.b1)
       h_t = ivy.permute_dims(h_activated, axes=(1, 0))
       dw2 = ivy.matmul(h_t, pred_error) / x.shape[0]
       db2 = ivy.mean(pred_error, axis=0)
      
       self.w2 = self.w2 - lr * dw2
       self.b2 = self.b2 - lr * db2
      
       return loss



def demo_framework_agnostic_network():
   """Demonstrate the same network running on different backends."""
   print("\n" + "="*70)
   print("PART 1: Framework-Agnostic Neural Network")
   print("="*70)
  
   X = np.random.randn(100, 4).astype(np.float32)
   y = np.eye(3)[np.random.randint(0, 3, 100)].astype(np.float32)
  
   backends = ['numpy', 'torch', 'tensorflow', 'jax']
   results = {}
  
   for backend in backends:
       try:
           ivy.set_backend(backend)
          
           if backend == 'jax':
               import jax
               jax.config.update('jax_enable_x64', True)
          
           print(f"\n Running with {backend.upper()} backend...")
          
           X_ivy = ivy.array(X)
           y_ivy = ivy.array(y)
          
           net = IvyNeuralNetwork()
          
           start_time = time.time()
           for epoch in range(50):
               loss = net.train_step(X_ivy, y_ivy, lr=0.1)
          
           elapsed = time.time() - start_time
          
           predictions = net.forward(X_ivy)
           accuracy = ivy.mean(
               ivy.astype(ivy.argmax(predictions, axis=-1) == ivy.argmax(y_ivy, axis=-1), 'float32')
           )
          
           results[backend] = {
               'loss': float(ivy.to_numpy(loss)),
               'accuracy': float(ivy.to_numpy(accuracy)),
               'time': elapsed
           }
          
           print(f"   Final Loss: {results[backend]['loss']:.4f}")
           print(f"   Accuracy: {results[backend]['accuracy']:.2%}")
           print(f"   Time: {results[backend]['time']:.3f}s")
          
       except Exception as e:
           print(f"    {backend} error: {str(e)[:80]}")
           results[backend] = None
  
   ivy.unset_backend()
   return results

Демонстрация показывает обучение сети, собранной с использованием Ivy, на NumPy, PyTorch, TensorFlow и JAX, а также фиксирует метрики — loss, accuracy и время выполнения для каждого бэкенда.

Транспиляция и воспроизведение вычислений в разных фреймворках

Ivy поддерживает транспиляционные сценарии, но чаще используется единый API для того, чтобы одну и ту же функцию можно было оценивать на разных бэкендах. Ниже — демонстрационная функция из туториала.


def demo_transpilation():
   """Demonstrate transpiling code from PyTorch to TensorFlow and JAX."""
   print("\n" + "="*70)
   print("PART 2: Framework Transpilation")
   print("="*70)
  
   try:
       import torch
       import tensorflow as tf
      
       def pytorch_computation(x):
           """A simple PyTorch computation."""
           return torch.mean(torch.relu(x * 2.0 + 1.0))
      
       x_torch = torch.randn(10, 5)
      
       print("\n Original PyTorch function:")
       result_torch = pytorch_computation(x_torch)
       print(f"   PyTorch result: {result_torch.item():.6f}")
      
       print("\n Transpilation Demo:")
       print("   Note: ivy.transpile() is powerful but complex.")
       print("   It works best with traced/compiled functions.")
       print("   For simple demonstrations, we'll show the unified API instead.")
      
       print("\n Equivalent computation across frameworks:")
       x_np = x_torch.numpy()
      
       ivy.set_backend('numpy')
       x_ivy = ivy.array(x_np)
       result_np = ivy.mean(ivy.relu(x_ivy * 2.0 + 1.0))
       print(f"   NumPy result: {float(ivy.to_numpy(result_np)):.6f}")
      
       ivy.set_backend('tensorflow')
       x_ivy = ivy.array(x_np)
       result_tf = ivy.mean(ivy.relu(x_ivy * 2.0 + 1.0))
       print(f"   TensorFlow result: {float(ivy.to_numpy(result_tf)):.6f}")
      
       ivy.set_backend('jax')
       import jax
       jax.config.update('jax_enable_x64', True)
       x_ivy = ivy.array(x_np)
       result_jax = ivy.mean(ivy.relu(x_ivy * 2.0 + 1.0))
       print(f"   JAX result: {float(ivy.to_numpy(result_jax)):.6f}")
      
       print(f"\n    All results match within numerical precision!")
      
       ivy.unset_backend()
          
   except Exception as e:
       print(f" Demo error: {str(e)[:80]}")

Вместо того чтобы пытаться транспилировать любые произвольные функции, туториал показывает, как единый API Ivy позволяет воспроизвести вычисления и получить согласованные результаты в разных экосистемах.

Примеры единого API

Ivy предоставляет единый интерфейс операций, так что один и тот же код выполняется одинаково на поддерживаемых бэкендах. Функция ниже демонстрирует набор операций, выполненных на нескольких бэкендах.


def demo_unified_api():
   """Show how Ivy's unified API works across different operations."""
   print("\n" + "="*70)
   print("PART 3: Unified API Across Frameworks")
   print("="*70)
  
   operations = [
       ("Matrix Multiplication", lambda x: ivy.matmul(x, ivy.permute_dims(x, axes=(1, 0)))),
       ("Element-wise Operations", lambda x: ivy.add(ivy.multiply(x, x), 2)),
       ("Reductions", lambda x: ivy.mean(ivy.sum(x, axis=0))),
       ("Neural Net Ops", lambda x: ivy.mean(ivy.relu(x))),
       ("Statistical Ops", lambda x: ivy.std(x)),
       ("Broadcasting", lambda x: ivy.multiply(x, ivy.array([1.0, 2.0, 3.0, 4.0]))),
   ]
  
   X = np.random.randn(5, 4).astype(np.float32)
  
   for op_name, op_func in operations:
       print(f"\n {op_name}:")
      
       for backend in ['numpy', 'torch', 'tensorflow', 'jax']:
           try:
               ivy.set_backend(backend)
              
               if backend == 'jax':
                   import jax
                   jax.config.update('jax_enable_x64', True)
              
               x_ivy = ivy.array(X)
               result = op_func(x_ivy)
               result_np = ivy.to_numpy(result)
              
               if result_np.shape == ():
                   print(f"   {backend:12s}: scalar value = {float(result_np):.4f}")
               else:
                   print(f"   {backend:12s}: shape={result_np.shape}, mean={np.mean(result_np):.4f}")
              
           except Exception as e:
               print(f"   {backend:12s}:  {str(e)[:60]}")
      
       ivy.unset_backend()

Эти примеры подтверждают, что матричные операции, поэлементные вычисления, редукции и трансляции работают согласованно при использовании Ivy.

Продвинутые возможности: Containers, соответствие Array API, сложные графы

Ivy предоставляет вложенные структуры (ivy.Container), соблюдает стиль Array API и позволяет связывать операции в сложные графы. Пример ниже демонстрирует работу с контейнерами и цепочкой операций.


def demo_advanced_features():
   """Demonstrate advanced Ivy features."""
   print("\n" + "="*70)
   print("PART 4: Advanced Ivy Features")
   print("="*70)
  
   print("\n Ivy Containers - Nested Data Structures:")
   try:
       ivy.set_backend('torch')
      
       container = ivy.Container({
           'layer1': {'weights': ivy.random_uniform(shape=(4, 8)), 'bias': ivy.zeros((8,))},
           'layer2': {'weights': ivy.random_uniform(shape=(8, 3)), 'bias': ivy.zeros((3,))}
       })
      
       print(f"   Container keys: {list(container.keys())}")
       print(f"   Layer1 weight shape: {container['layer1']['weights'].shape}")
       print(f"   Layer2 bias shape: {container['layer2']['bias'].shape}")
      
       def scale_fn(x, _):
           return x * 2.0
      
       scaled_container = container.cont_map(scale_fn)
       print(f"    Applied scaling to all tensors in container")
      
   except Exception as e:
       print(f"    Container demo: {str(e)[:80]}")
  
   print("\n Array API Standard Compliance:")
   backends_tested = []
   for backend in ['numpy', 'torch', 'tensorflow', 'jax']:
       try:
           ivy.set_backend(backend)
          
           if backend == 'jax':
               import jax
               jax.config.update('jax_enable_x64', True)
          
           x = ivy.array([1.0, 2.0, 3.0])
           y = ivy.array([4.0, 5.0, 6.0])
          
           result = ivy.sqrt(ivy.square(x) + ivy.square(y))
           print(f"   {backend:12s}: L2 norm operations work ")
           backends_tested.append(backend)
       except Exception as e:
           print(f"   {backend:12s}: {str(e)[:50]}")
  
   print(f"\n   Successfully tested {len(backends_tested)} backends")
  
   print("\n Complex Multi-step Operations:")
   try:
       ivy.set_backend('torch')
      
       x = ivy.random_uniform(shape=(10, 5), low=0, high=1)
      
       result = ivy.mean(
           ivy.relu(
               ivy.matmul(x, ivy.permute_dims(x, axes=(1, 0)))
           ),
           axis=0
       )
      
       print(f"   Chained operations (matmul → relu → mean)")
       print(f"   Input shape: (10, 5), Output shape: {result.shape}")
       print(f"    Complex operation graph executed successfully")
      
   except Exception as e:
       print(f"    {str(e)[:80]}")
  
   ivy.unset_backend()

Эти возможности помогают структурировать параметры моделей, гарантировать совместимость API и выполнять сложные последовательности операций на разных бэкендах.

Бенчмаркинг производительности

Ivy также позволяет сравнивать производительность одинаковых нагрузок на разных фреймворках, чтобы выбрать оптимальный бэкенд. Ниже — утилиты для бенчмарка и демонстрация производительности.


def benchmark_operation(op_func, x, iterations=50):
   """Benchmark an operation."""
   start = time.time()
   for _ in range(iterations):
       result = op_func(x)
   return time.time() - start



def demo_performance():
   """Compare performance across backends."""
   print("\n" + "="*70)
   print("PART 5: Performance Benchmarking")
   print("="*70)
  
   X = np.random.randn(100, 100).astype(np.float32)
  
   def complex_operation(x):
       """A more complex computation."""
       z = ivy.matmul(x, ivy.permute_dims(x, axes=(1, 0)))
       z = ivy.relu(z)
       z = ivy.mean(z, axis=0)
       return ivy.sum(z)
  
   print("\n Benchmarking matrix operations (50 iterations):")
   print("   Operation: matmul → relu → mean → sum")
  
   for backend in ['numpy', 'torch', 'tensorflow', 'jax']:
       try:
           ivy.set_backend(backend)
          
           if backend == 'jax':
               import jax
               jax.config.update('jax_enable_x64', True)
          
           x_ivy = ivy.array(X)
          
           _ = complex_operation(x_ivy)
          
           elapsed = benchmark_operation(complex_operation, x_ivy, iterations=50)
          
           print(f"   {backend:12s}: {elapsed:.4f}s ({elapsed/50*1000:.2f}ms per op)")
          
       except Exception as e:
           print(f"   {backend:12s}:  {str(e)[:60]}")
  
   ivy.unset_backend()

Запуск полного скрипта

Скрипт туториала объединяет все демонстрации: обучение на нескольких бэкендах, примеры транспиляции/API, продвинутые возможности и бенчмарки.


if __name__ == "__main__":
   print("""
   ╔════════════════════════════════════════════════════════════════════╗
   ║          Advanced Ivy Tutorial - Framework-Agnostic ML             ║
   ║                  Write Once, Run Everywhere!                       ║
   ╚════════════════════════════════════════════════════════════════════╝
   """)
  
   results = demo_framework_agnostic_network()
   demo_transpilation()
   demo_unified_api()
   demo_advanced_features()
   demo_performance()
  
   print("\n" + "="*70)
   print(" Tutorial Complete!")
   print("="*70)
   print("\n Key Takeaways:")
   print("   1. Ivy enables writing ML code once that runs on any framework")
   print("   2. Same operations work identically across NumPy, PyTorch, TF, JAX")
   print("   3. Unified API provides consistent operations across backends")
   print("   4. Switch backends dynamically for optimal performance")
   print("   5. Containers help manage complex nested model structures")
   print("\n Next Steps:")
   print("   - Build your own framework-agnostic models")
   print("   - Use ivy.Container for managing model parameters")
   print("   - Explore ivy.trace_graph() for computation graph optimization")
   print("   - Try different backends to find optimal performance")
   print("   - Check docs at: https://docs.ivy.dev/")
   print("="*70)

Этот материал показывает, как Ivy абстрагирует отличия фреймворков, сохраняя производительность и числовую повторяемость, что делает переносимый ML-код реальностью.