Um Guia de Codificação para Implementar Solucionadores Avançados de Equações Diferenciais, Simulações Estocásticas e Equações Diferenciais Ordinárias Neurais Usando Diffrax e JAX
Neste tutorial, exploramos como resolver equações diferenciais e construir modelos de equações diferenciais neurais usando a biblioteca Diffrax. Começamos configurando um ambiente computacional limpo e instalando as bibliotecas de computação científica necessárias, como JAX, Diffrax, Equinox e Optax. Em seguida, demonstramos como resolver equações diferenciais ordinárias usando solucionadores adaptativos […] A po
Neste tutorial, exploramos como resolver equações diferenciais e construir modelos de equações diferenciais neurais usando a biblioteca Diffrax. Começamos configurando um ambiente computacional limpo e instalando as bibliotecas de computação científica necessárias, como JAX, Diffrax, Equinox e Optax. Em seguida, demonstramos como resolver equações diferenciais ordinárias usando solucionadores adaptativos e realizar interpolação densa para consultar soluções em pontos de tempo arbitrários. Conforme avançamos, investigamos capacidades mais avançadas do Diffrax, incluindo a resolução de sistemas dinâmicos clássicos, o trabalho com estados baseados em PyTree e a execução de simulações em lote usando os recursos de vetorização do JAX. Também simulamos equações diferenciais estocásticas e geramos dados de um sistema dinâmico que serão usados posteriormente para treinar um modelo de equação diferencial ordinária neural. Copiar Código Copiado Usar um navegador diferente import os, sys, subprocess, importlib, pathlib SENTINEL = "/tmp/diffrax_colab_ready_v3" def _run(cmd): subprocess.check_call(cmd) def _need_install(): try: import numpy import jax import diffrax import equinox import optax import matplotlib return False except Exception: return True if not os.path.exists(SENTINEL) or _need_install(): _run([sys.executable, "-m", "pip", "uninstall", "-y", "numpy", "jax", "jaxlib", "diffrax", "equinox", "optax"]) _run([sys.executable, "-m", "pip", "install", "-q", "--upgrade", "pip"]) _run([ sys.executable, "-m", "pip", "install", "-q", "numpy==1.26.4", "jax[cpu]==0.4.38", "jaxlib==0.4.38", "diffrax", "equinox", "optax", "matplotlib" ]) pathlib.Path(SENTINEL).write_text("ready") print("Pacotes instalados. O tempo de execução será reiniciado agora. Após reconectar, execute esta mesma célula novamente.") os._exit(0) import time import math import numpy as np import jax import jax.numpy as jnp import jax.random as jr import diffrax import equinox as eqx import optax import matplotlib.pyplot as plt print("NumPy:", np.version) print("JAX:", jax.version) print("Backend:", jax.default_backend()) def logistic(t, y, args): r, k = args return r * y * (1 - y / k) t0, t1 = 0.0, 10.0 ts = jnp.linspace(t0, t1, 300) y0 = jnp.array(0.4) args = (2.0, 5.0) sol_logistic = diffrax.diffeqsolve( diffrax.ODETerm(logistic), diffrax.Tsit5(), t0=t0, t1=t1, dt0=0.05, y0=y0, args=args, saveat=diffrax.SaveAt(ts=ts, dense=True), stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8), max_steps=100000, ) query_ts = jn
