Implementando Deep Q-Learning (DQN) do Zero Usando RLax JAX Haiku e Optax para Treinar um Agente de Reforço de Aprendizagem CartPole

Implementando Deep Q-Learning (DQN) do Zero Usando RLax JAX Haiku e Optax para Treinar um Agente de Reforço de Aprendizagem CartPole

Neste tutorial, implementamos um agente de aprendizado por reforço usando RLax, uma biblioteca orientada à pesquisa desenvolvida pelo Google DeepMind para construir algoritmos de aprendizado por reforço com JAX. Combinamos RLax com JAX, Haiku e Optax para construir um agente Deep Q-Learning (DQN) que aprende a resolver o ambiente CartPole. Em vez de usar um framework de RL totalmente empacotado, […] The post Implementando

Neste tutorial, implementamos um agente de aprendizado por reforço usando RLax, uma biblioteca orientada à pesquisa desenvolvida pelo Google DeepMind para construir algoritmos de aprendizado por reforço com JAX. Combinamos RLax com JAX, Haiku e Optax para construir um agente Deep Q-Learning (DQN) que aprende a resolver o ambiente CartPole. Em vez de usar um framework de RL totalmente empacotado, montamos o pipeline de treinamento nós mesmos para que possamos entender claramente como os componentes centrais do aprendizado por reforço interagem. Definimos a rede neural, construímos um buffer de replay, calculamos os erros de diferença temporal com RLax e treinamos o agente usando otimização baseada em gradiente. Além disso, nos concentramos em entender como o RLax fornece primitivos de RL reutilizáveis que podem ser integrados em pipelines de aprendizado por reforço personalizados. Usamos JAX para computação numérica eficiente, Haiku para modelagem de rede neural e Optax para otimização. Copiar Código Copiado Use um navegador diferente!pip -q install "jax[cpu]" dm-haiku optax rlax gymnasium matplotlib numpy import os os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" import random import time from dataclasses import dataclass from collections import deque import gymnasium as gym import haiku as hk import jax import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np import optax import rlax seed = 42 random.seed(seed) np.random.seed(seed) env = gym.make("CartPole-v1") eval_env = gym.make("CartPole-v1") obs_dim = env.observation_space.shape[0] num_actions = env.action_space.n def q_network(x): mlp = hk.Sequential([ hk.Linear(128), jax.nn.relu, hk.Linear(128), jax.nn.relu, hk.Linear(num_actions), ]) return mlp(x) q_net = hk.without_apply_rng(hk.transform(q_network)) dummy_obs = jnp.zeros((1, obs_dim), dtype=jnp.float32) rng = jax.random.PRNGKey(seed) params = q_net.init(rng, dummy_obs) target_params = params optimizer = optax.chain( optax.clip_by_global_norm(10.0), optax.adam(3e-4), ) opt_state = optimizer.init(params) Instalamos as bibliotecas necessárias e importamos todos os módulos necessários para o pipeline de aprendizado por reforço. Inicializamos o ambiente, definimos a arquitetura da rede neural usando Haiku e configuramos a Q-network que prevê os valores das ações. Também inicializamos os parâmetros da rede e da rede alvo, bem como o otimizador a ser usado durante o treinamento. Copiar Código Copiado Usar um navegador diferente!@dataclass class Transition: obs: np.ndarray action: int reward: float discount: float next_obs: np.ndarray done: fl

aprendizado por reforçoRLaxDeep Q-Learning