<?xml version="1.0" encoding="utf-8"?>
<?xml-stylesheet type="text/xsl" href="../assets/xml/rss.xsl" media="all"?><rss version="2.0" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Robot Learning by Example (Posts about flax)</title><link>https://engyasin.github.io/</link><description></description><atom:link href="https://engyasin.github.io/categories/flax.xml" rel="self" type="application/rss+xml"></atom:link><language>en</language><copyright>Contents © 2026 &lt;a href="mailto:yy33@tu-clausthal.de"&gt;Yasin Yousif&lt;/a&gt; </copyright><lastBuildDate>Sat, 23 May 2026 16:22:09 GMT</lastBuildDate><generator>Nikola (getnikola.com)</generator><docs>http://blogs.law.harvard.edu/tech/rss</docs><item><link>https://engyasin.github.io/posts/reinforcement-learning-with-bells-and-whistles/</link><dc:creator>Yasin Yousif</dc:creator><description>&lt;center&gt;
&lt;br&gt;
&lt;img width="100%" src="https://engyasin.github.io/images/rlwithbells/rlwithbells.png"&gt;
&lt;br&gt;
&lt;/center&gt;

&lt;p style="font-size:0.8em"&gt;
*Reinforcement Learning (RL) is a powerful framework for solving sequential decision-making problems in dynamic environments across diverse domains, such as robot control or profit optimization. However, its practical implementation requires navigating a variety of software packages, encompassing deep learning libraries (e.g., TensorFlow, PyTorch, JAX/Flax), environment frameworks (e.g., Gymnasium, NumPy), and hyperparameter optimization techniques and libraries. This post critically evaluates the common PyTorch, Gymnasium, and NumPy RL stack by comparing it to a faster alternative: JAX/Flax for both model training and environment simulation. A Gridworld example is used to compare these packages in terms of training speed and accuracy. Additionally, we complement this example with comprehensive tracking and monitoring of the training process using MLflow, along with thorough hyperparameter optimization via Optuna. The post concludes with a discussion of the results and final recommendations for the optimal use cases of each of these packages.* 
&lt;/p&gt;

&lt;p class="more"&gt;&lt;a href="https://medium.com/data-science-collective/reinforcement-learning-with-bells-and-whistles-2f9d6d0a7ae3"&gt;Read more on Medium ...&lt;/a&gt;&lt;/p&gt;

&lt;!-- Alternative Title: --&gt;

&lt;!-- Gif of learning across time for all envs --&gt;

&lt;!--

&lt;center&gt;
&lt;br&gt;
&lt;img width="100%"src='/images/rlwithbells/rlwithbells.png'&gt;
&lt;br&gt;
Figure 1: The popular eco-system for modular and scalable training of RL agents. 
&lt;/center&gt;


----------

**Table of Contents**

----------

[TOC]

----------



# Introduction &amp; Installation

The typical workflow for applying Reinforcement Learning to optimize an objective involves defining Markov Decision Process (MDP) quantities, such as state and action spaces, actor model (agent), and a reward function [24]. Furthermore, an environment model is 
required for simulating the forward application of our RL agent in model-free algorithms. The training process alternates between collecting experience data (rollouts) and training the agent on that data. Consequently, the runtime of our program 
is influenced by two key components: neural network parameter updates and environment simulation. OpenAI-Gym, initially proposed in [1], and its successor Gymnasium [2] are well-established Python libraries providing a structured approach to 
building RL simulation environments. Popular libraries like TensorFlow and PyTorch are commonly used for training the agent model itself. This post explores JAX [21] and its neural network library, Flax [20], as a promising alternative for 
both simulation and training, aiming to accelerate training and improve optimization.

Our tests on the Gridworld environment indicate that using JAX for environment batching yields significant speedups on GPU hardware while reaching performance (convergence quality) comparable to existing solutions. We also focused on the hyperparameter search 
problem, which is particularly critical in Reinforcement Learning due to its interactive nature. We employed the Optuna [6] implementation of advanced hyperparameter search methods, demonstrating its impact on the results.  All trials and 
experiments were tracked using MLflow [5], providing a detailed overview of key metrics.

Beyond that, each implementation section begins with a concise overview of the package's capabilities and main functions, enabling readers to effectively utilize them in their projects. Finally, the main experimental results are presented and 
discussed, followed by concluding takeaways.


&lt;!-- In the following we will go over all the different libraries, highlighting their practical advantages or disadvantages and their main key difference from their alternatives. Starting with Gymnasium, MLflow, Optuna, and lastly Jax and Flax. We also will show the results of utilizing these libraries on our test program (explained below in Figure 4), focusing on performance (ability to converge optimally) and training time, on our hardware **NVIDIA GeForce RTX 5060 Ti**.--&gt;

&lt;!--

The installation of the packages needed in this post with `pip`-python can be done simply as follows:


&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;pip&lt;span class="w"&gt; &lt;/span&gt;install&lt;span class="w"&gt; &lt;/span&gt;gymnasium
pip&lt;span class="w"&gt; &lt;/span&gt;install&lt;span class="w"&gt; &lt;/span&gt;mlflow
pip&lt;span class="w"&gt; &lt;/span&gt;install&lt;span class="w"&gt; &lt;/span&gt;optuna

&lt;span class="c1"&gt;#replace with your cuda version &lt;/span&gt;
pip&lt;span class="w"&gt; &lt;/span&gt;install&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;jax[cuda12]&amp;quot;&lt;/span&gt;
pip&lt;span class="w"&gt; &lt;/span&gt;install&lt;span class="w"&gt; &lt;/span&gt;flax
&lt;/pre&gt;&lt;/div&gt;




# Gymnasium: Standardize Your Environment 

[Gymnasium](https://gymnasium.farama.org/v0.29.0/)  is an updated version of the popular Gym package, originally developed by OpenAI [1]. It provides a collection of standardized simulated environments with unified interfaces, which are 
regularly updated. This standardization is beneficial for benchmarking different RL algorithms, as well as for improving readability and collaboration. Several other advantages motivate the use of Gym and Gymnasium:

- **Vectorized Environments** (`VectorEnv`): This feature allows running multiple instances of the same environment concurrently, enabling batching of states and actions. This significantly speeds up trajectory rollout and, consequently, the 
training of the RL agent. There are two methods for deploying vectorized environments in Gymnasium: *synchronous* and *asynchronous* environments. A comparison of these two is presented in Table 1 below.


&lt;center&gt;
Table 1: Comparison between Gymnasium vectorization methods `SyncVectorEnv` and `AsyncVectorEnv`

&lt;br&gt;
&lt;br&gt;
&lt;table style="border: 1px solid black" &gt;

&lt;tr &gt;
&lt;th style="border: 1px solid black"&gt;
 `gymnasium.vector.SyncVectorEnv` 
&lt;/th&gt;
&lt;th style="border: 1px solid black"&gt;
 `gymnasium.vector.AsyncVectorEnv` 
&lt;/th&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="border: 1px solid black"&gt; creates all environments in a single process, steps them serially, and batches the outputs (state, reward, done flags) &lt;/td&gt;
&lt;td style="border: 1px solid black"&gt; creates each environment in its own subprocess, so the environments are stepped in parallel &lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="border: 1px solid black"&gt; best used when each environment step is cheap, so that serial stepping is faster than paying the overhead of a separate subprocess per instance. &lt;/td&gt;
&lt;td style="border: 1px solid black"&gt; best used when the environment steps are computationally expensive and there is enough memory for the subprocesses. &lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td colspan='2' style="text-align: center;border: 1px solid black"&gt;
Input to both functions should be a list of creation functions of environments (e.g., using a lambda function).
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr &gt;
&lt;td colspan='2' style="text-align: center;border: 1px solid black"&gt;
For `AsyncVectorEnv`, when the optional keyword argument `shared_memory` is set to True, observations are exchanged with the subprocesses through shared memory instead of being copied through pipes, which can speed up stepping when observations are large.
&lt;/td&gt;
&lt;/tr&gt;
&lt;/table&gt;

&lt;/center&gt;

- **Spaces Objects**: These are used to define the state and action values and distributions. These spaces represent specific constraints. Examples of possible space sets are shown in Figure 2, imported from `gymnasium.spaces`.


&lt;center&gt;
&lt;br&gt;
&lt;img width="100%"src='/images/rlwithbells/gymspaces.png'&gt;
&lt;br&gt;
Figure 2: Gymnasium basic and compound spaces.
&lt;/center&gt;

- **Registry**: Custom environments can be registered within the installation so that they can be instantiated directly like a built-in environment (with `gym.make`). 


- **Wrappers**: The `gymnasium.wrappers` module contains useful classes to *modify* specific environment behavior. Examples include:

    - `ObservationWrapper`: Modifies the observation space.
    - `ActionWrapper`: Modifies the action space.
    - `RewardWrapper`: Modifies the reward function.
    - `TimeLimit`: Used to truncate an episode after a specific number of steps.
    - `Autoreset`: When the environment reaches a terminal state or is truncated, this wrapper resets it on the next call to `.step()`, returning the last observed state.
    - `RecordEpisodeStatistics`: Important for collecting episodic rewards, which indicate the success or failure of a policy during training.

- If your environment is a subclass of `gymnasium.Env`, you benefit from automatic testing using the `gymnasium.utils.env_checker.check_env` function, which performs common tests on the Gym environment methods and their spaces.
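
The idea behind synchronous vectorization can be illustrated without Gymnasium at all. The following is a minimal pure-Python sketch, not Gymnasium's implementation: `ToyEnv` and `ToySyncVector` are hypothetical names, and the toy environment returns only an observation, a reward, and a termination flag. As with the Gymnasium vector classes, the constructor takes a list of environment creation functions.

```python
class ToyEnv:
    """Hypothetical 1-D corridor: the agent starts at 0 and must reach `goal`."""

    def __init__(self, goal=3):
        self.goal = goal
        self.pos = 0

    def reset(self):
        self.pos = 0
        return self.pos

    def step(self, action):
        # action 1 moves right, any other action stays in place
        self.pos += 1 if action == 1 else 0
        terminated = self.pos >= self.goal
        reward = 1.0 if terminated else 0.0
        return self.pos, reward, terminated


class ToySyncVector:
    """Steps several environments serially in one process and batches the outputs."""

    def __init__(self, env_fns):
        # take a list of creation functions, one per environment instance
        self.envs = [fn() for fn in env_fns]

    def reset(self):
        return [env.reset() for env in self.envs]

    def step(self, actions):
        results = [env.step(a) for env, a in zip(self.envs, actions)]
        # transpose the per-environment tuples into per-quantity batches
        obs, rewards, terminated = (list(column) for column in zip(*results))
        return obs, rewards, terminated


vec = ToySyncVector([lambda: ToyEnv(goal=2) for _ in range(3)])
first_obs = vec.reset()
obs, rewards, terminated = vec.step([1, 0, 1])
```

An asynchronous variant would keep the same interface but run each `step` call in its own subprocess, which only pays off when a single step is expensive.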

Additionally, Gymnasium introduces the following changes over Gym:


- **Termination and Truncation**: Instead of the single `done` flag, Gymnasium's `step()` returns separate `terminated` and `truncated` flags. *Termination* occurs naturally when the episode reaches a terminal state (e.g., the goal is reached), while *truncation* happens only 
after a specific number of steps to prevent episodes from running indefinitely. Figure 3 depicts these differences.


&lt;center&gt;
&lt;br&gt;
&lt;img width="80%"src='/images/rlwithbells/termination.png'&gt;
&lt;br&gt;
Figure 3: Difference between terminating (goal achieved) and truncating (time limit reached) a simulated episode.
&lt;/center&gt;

- **Functional Environment Creation**: A new and experimental class for environment creation, `gymnasium.experimental.functional.FuncEnv`, is introduced. It uses a purely functional structure (the environment class is 
stateless) to more closely reflect the formulation of a POMDP (Partially Observable Markov Decision Process). Additionally, this structure facilitates direct compatibility with JAX.
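
The two changes above can be sketched together in plain Python. This sketch does not use Gymnasium's `FuncEnv` API; the corridor task, `GOAL`, `MAX_STEPS`, and all function names are assumptions made for illustration. The environment is a set of pure functions over an immutable state, and termination and truncation are reported separately.

```python
from typing import NamedTuple


class State(NamedTuple):
    pos: int  # agent position on a hypothetical 1-D corridor
    t: int    # number of steps taken so far


GOAL = 3       # assumed goal cell
MAX_STEPS = 5  # assumed time limit


def initial():
    return State(pos=0, t=0)


def transition(state, action):
    # pure function: returns a new state instead of mutating the old one
    return State(pos=state.pos + (1 if action == 1 else 0), t=state.t + 1)


def terminated(state):
    # the task itself has ended: the goal cell was reached
    return state.pos >= GOAL


def truncated(state):
    # the time limit cut the episode short before the goal was reached
    return state.t >= MAX_STEPS and not terminated(state)


def rollout(actions):
    state = initial()
    for action in actions:
        state = transition(state, action)
        if terminated(state) or truncated(state):
            break
    return state, terminated(state), truncated(state)


reached = rollout([1, 1, 1])          # walks straight to the goal
timed_out = rollout([0, 0, 0, 0, 0])  # never moves and hits the time limit
```

Keeping the flags separate matters for value-based methods: a common practice is to bootstrap from the next state's value after a truncation, but not after a termination.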




## Example: Creating a Custom Gym Environment and Training with DQN

This post demonstrates the application of Gym and associated libraries using a custom Gridworld environment called 
*Doors*. In this environment, an agent occupies a cell within a grid and is tasked with navigating towards a goal 
cell by passing through one of three gaps (doors) that divide the grid in two, as illustrated in Figure 4. The figure also depicts state-action configurations.

&lt;center&gt;
&lt;br&gt;
&lt;img width="30%" src='/images/ilpost/gifs/expert.gif'&gt;
&lt;img width="70%" src='/images/ilpost/Doors_.png'&gt;
&lt;br&gt;
Figure 4: The *Doors* environment, introduced in a [previous post](https://www.rlbyexample.net/posts/hands-on-imitation-learning/). The lower image shows an animation of the 
optimal policy for solving this environment.
&lt;/center&gt;

&gt; **Note:** The complete code repository is available [here](https://github.com/engyasin/ilsuvey), with the final 
script integrating all libraries located [here](https://github.com/engyasin/ilsuvey/blob/main/dqn_hopt_flax.py).

Below, we present sections of the environment creation class in Gymnasium:



&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;gymnasium&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="k"&gt;as&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;gym&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;numpy&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="k"&gt;as&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;gymnasium.wrappers&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Autoreset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;RecordEpisodeStatistics&lt;/span&gt;

&lt;span class="c1"&gt;#creating the environment&lt;/span&gt;
&lt;span class="k"&gt;class&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nc"&gt;DoorsGym&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;gridSize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;&lt;span class="n"&gt;nDoors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;render_frames&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gridSize&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gridSize&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nDoors&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nDoors&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;render_frames&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;render_frames&lt;/span&gt;

        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;spaces&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Discrete&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# representing the four states of cells for the entire size of the grid (flattened)&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;spaces&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;MultiDiscrete&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;prod&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gridSize&lt;/span&gt;&lt;span class="p"&gt;))])&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;options&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="k"&gt;pass&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="k"&gt;pass&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;



This environment can then be utilized in a separate script as shown in the following figure:

&lt;center&gt;
&lt;br&gt;
&lt;img width="100%" src='/images/rlwithbells/gymcode.png'&gt;
&lt;br&gt;
Figure 5: Environment registration and creation.
&lt;/center&gt;

Given that the action space for this environment is discrete, we selected the Deep Q-Network (DQN) training algorithm 
[3], based on the CleanRL [4] implementation, to learn the policy.  The subsequent sections detail how to track 
training metrics and plot them against training time using MLflow.


# MLflow: Tracking RL Experiments

[*MLflow*](https://mlflow.org/) is a popular Python library for tracking, versioning, collaborating on, and deploying 
machine learning models. Its primary functionality involves displaying training metrics either on a local server (by 
running `mlflow ui` in a new terminal, with the default port set to 5000) or on an online cloud server such as 
*Databricks*.

**MLflow** organizes training by creating an *experiment* for each machine learning task (for example, cat/dog image 
classification). Within each experiment, multiple *runs* can be defined, representing individual training trials for 
that task (e.g., different ML approaches for the same task). Furthermore, smaller runs can be *nested* within larger 
runs (which we will utilize for our hyperparameter trials, described below).

This structure enables comprehensive saving of all testing parameters and metrics, and provides a unified interface 
for tracking them. Additionally, **MLflow** offers seamless integration with PyTorch, TensorFlow, and Keras, along 
with numerous other functionalities and features that are beyond the scope of this document but can be explored on 
its website [https://mlflow.org/](https://mlflow.org/).

To create a new experiment in **MLflow**, call 
`mlflow.create_experiment('experiment_name')`, which defines a new task for training a machine learning model. To 
continue an existing experiment instead, select it with `mlflow.set_experiment` (which also creates the experiment if it does not exist yet), as follows:


&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;mlflow&lt;/span&gt;
&lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_tracking_uri&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;http://localhost:5000&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_experiment&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;runs/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;experiment_name&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;



Note that you must specify the tracking URI of the server that receives the logged results (in this case, `http://localhost:5000`), 
and that the local server must already be running in a separate terminal, started with the command `mlflow ui`.

Subsequently, you can start a specific `run` within the experiment or create *multiple child runs* nested within the 
parent run by utilizing the `nested` keyword argument. The latter is particularly useful for hyperparameter 
optimization, where each trial can be tracked independently in its own child run. The following code illustrates this 
functionality.




&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;start_run&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;run_name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;main_run&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;nested&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;run&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

    &lt;span class="c1"&gt;# log main parameters here&lt;/span&gt;
    &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_params&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;MainConfigs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;start_run&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nested&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# train here&lt;/span&gt;

        &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_params&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;argsDict&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_metric&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;metric_name&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;metric_value&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;step&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;global_step&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
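        &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set_tag&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;label&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;value&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# tags are key/value pairs&lt;/span&gt;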

        &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_figure&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# matplotlib figure object&lt;/span&gt;
        &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_image&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# numpy array and PIL image&lt;/span&gt;

        &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pytorch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;save_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# saving pytorch model on the server&lt;/span&gt;

        &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_artifacts&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# saving other data types&lt;/span&gt;

    &lt;span class="c1"&gt;# save final model&lt;/span&gt;
    &lt;span class="n"&gt;model_uri&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s1"&gt;&amp;#39;copy from dashboard usually starting with models:/&amp;#39;&lt;/span&gt;
    &lt;span class="n"&gt;model_info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pytorch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pytorch_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;model_uri&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# load the model&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pytorch&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;load_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_uri&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;



Then, in a new browser tab, navigate to the URL `http://localhost:5000` to view all experiments. If you select an 
active experiment, you can track individual runs either as a list or, as illustrated in Figure 6, as a chart 
displaying all tracked metrics.
&lt;center&gt;
&lt;br&gt;
&lt;img width="100%" src='/images/rlwithbells/mlflowcharts.png'&gt;
&lt;br&gt;
Figure 6: The MLflow interface (chart-view) displaying tracked metrics for the active run.
&lt;/center&gt;

# Optuna: Optimizing RL Hyperparameters

Reinforcement Learning (RL) training typically requires the tuning of numerous hyperparameters, exceeding the number 
required for supervised learning counterparts. Therefore, applying efficient hyperparameter optimization methods such as Bayesian optimization [7] or HyperBand [16] is highly beneficial. In the following sections, we will begin by reviewing prominent hyperparameter optimization methods, focusing on their implementation using the `Optuna` package.

In the context of RL, these hyperparameters include the learning rate, episode length, discount 
factor (in the Bellman equation), as well as the agent network depth and architecture.

## Types of Hyperparameter Optimization Methods:
Generally, there are four main branches of hyperparameter optimization methodologies, differing in their complexity 
and approach, as depicted in the following figure.
&lt;center&gt;
&lt;br&gt;
&lt;img width="100%" src='/images/rlwithbells/hyperopt.png'&gt;
&lt;br&gt;
Figure 7: Main search methodologies for optimizing hyperparameters of machine learning models.
&lt;/center&gt;

### Uninformed Methods
These methods represent the simplest approach: candidate configurations are drawn directly from the 
search space without using the results of earlier trials. Depending on their sampling strategy, they can be categorized as:

- **Manual:** Samples are chosen manually by the developer.

- **Uniform:** Samples are spaced evenly across the specified range (as in grid search).

- **Random:** Samples are drawn at random from the specified range.
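
The difference between the uniform and random strategies can be sketched in a few lines of standard-library Python. The helper names and the discount-factor range below are assumptions chosen for illustration only.

```python
import random


def uniform_samples(low, high, n):
    """n evenly spaced values covering [low, high] (requires n >= 2)."""
    step = (high - low) / (n - 1)
    return [low + i * step for i in range(n)]


def random_samples(low, high, n, seed=0):
    """n values drawn independently and uniformly at random from [low, high]."""
    rng = random.Random(seed)
    return [rng.uniform(low, high) for _ in range(n)]


# hypothetical search range for a discount factor
grid = uniform_samples(0.90, 0.99, 4)
draws = random_samples(0.90, 0.99, 4)
```

Neither strategy looks at the scores of earlier trials, which is what separates them from the informed methods described next.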



### Bayesian Optimization Methods
This category of methods utilizes a surrogate model to approximate the **objective function**, i.e., the function that 
maps the training hyperparameters to the learning objective (e.g., accuracy or negative loss). The training 
data for this model consists of the hyperparameters and results of previous training attempts. The new set of hyperparameters to be 
tested is then proposed by maximizing an **acquisition function**, a criterion that scores candidates using the surrogate's predictions and balances exploring uncertain regions against exploiting promising ones.
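
One widely used acquisition function is expected improvement. The sketch below assumes the surrogate returns a Gaussian prediction (mean and standard deviation) for each candidate and that the objective is maximized; the candidate learning rates and their predicted values are invented for illustration and do not come from the experiments in this post.

```python
import math


def expected_improvement(mu, sigma, best):
    """Expected improvement over `best` (maximization) for a Gaussian prediction.

    mu and sigma are the surrogate's predicted mean and standard deviation
    at a candidate point; best is the highest objective value seen so far.
    """
    if sigma == 0.0:
        return max(mu - best, 0.0)
    z = (mu - best) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
    return (mu - best) * cdf + sigma * pdf


# hypothetical surrogate predictions (mean, std) for three candidate learning rates
candidates = {1e-4: (0.60, 0.05), 1e-3: (0.70, 0.01), 1e-2: (0.65, 0.20)}
best_so_far = 0.68
scores = {lr: expected_improvement(m, s, best_so_far) for lr, (m, s) in candidates.items()}
next_lr = max(scores, key=scores.get)
```

In this toy setting the candidate with the largest uncertainty scores highest even though its predicted mean is below the best value so far, which illustrates how the acquisition function trades exploration against exploitation.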

Based on the nature of the surrogate model, Bayesian Optimization (BO) methods [7] can be categorized as follows:

- **Sequential Model-based Algorithm Configuration (SMAC)** [8]: Employs a random forest to approximate the 
objective function, making it suitable for searching categorical and discrete parameters.

- **Sequential Model-based Bayesian Optimization (SMBO)** [9]: Utilizes a Gaussian Process model, suitable for 
continuous hyperparameters.

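- **Tree-structured Parzen Estimators (TPE):** Fits two density models (Parzen, i.e., kernel density, estimators), one over the hyperparameters of well-performing trials and one over the rest, and proposes values that are likely under the first and unlikely under the second. It is suitable for large search spaces encompassing both continuous and discrete parameters, with fast run-time. In `Optuna`, its implementation can also model interactions between different hyperparameters (via the `multivariate` option). Its Optuna class is `optuna.samplers.TPESampler`.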

- **MATIS** [10]: Gaussian Process-based, also utilizing a Gaussian Mixture Model as its acquisition function.


### Heuristic Search

This branch of methods samples the hyperparameters for the next training iteration within the neighborhood of the 
best set of hyperparameters found so far. The definition of this neighborhood significantly impacts search 
performance, leading to various variants:

- **Simulated Annealing (SA)** [11]: Searches for the next sample around the best or next-to-best set of values found so far, aiming to avoid local minima.

- **Genetic Algorithm** [12]: Applies evolution-inspired methods to select the next set of values. This typically 
involves pairing the best samples found for different parameters or mutating them.

- **Particle Swarm Optimization** [13]: This method specifically focuses on continuous hyperparameters.

- **Population-based Training** [14]: This method specializes in neural network optimization, searching for both 
hyperparameters and standard training parameters. It trains a population of models in parallel; periodically, poorly 
performing members copy the weights and hyperparameters of better members and then perturb those hyperparameters before training continues. However, because the hyperparameters change during training, it yields a hyperparameter schedule together with the 
final trained model rather than a single fixed best set of hyperparameters.
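
As an illustration of neighborhood-based search, here is a minimal sketch of textbook simulated annealing over a single hyperparameter. The objective is a made-up function of the learning-rate exponent, and the step size, initial temperature, and cooling rate are arbitrary assumptions, so this shows only the mechanism and is not a tuned implementation.

```python
import math
import random


def toy_objective(lr_exponent):
    """Hypothetical score that peaks at a learning rate of 10**-3."""
    return -(lr_exponent + 3.0) ** 2


def simulated_annealing(objective, start, n_steps=200, temp=1.0, cooling=0.97, seed=0):
    rng = random.Random(seed)
    current, current_score = start, objective(start)
    best, best_score = current, current_score
    for _ in range(n_steps):
        # propose a neighbour of the current point
        candidate = current + rng.gauss(0.0, 0.5)
        score = objective(candidate)
        # always accept improvements; accept worse points with a
        # probability that shrinks as the temperature cools
        if score >= current_score or math.exp((score - current_score) / temp) > rng.random():
            current, current_score = candidate, score
        if current_score > best_score:
            best, best_score = current, current_score
        temp *= cooling
    return best, best_score


best_exponent, best_score = simulated_annealing(toy_objective, start=-6.0)
```

Occasionally accepting a worse neighbour early on is what lets the search leave a local optimum; as the temperature decays, the procedure behaves more and more like greedy local search.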


### Multi-Fidelity Optimization (MFO)

This approach enhances hyperparameter optimization by enabling faster training through early stopping on less 
promising samples, achieved by training on subsets of the data or for a reduced number of epochs (as in Optuna). This is more efficient than full training for all samples, as it avoids spending computational resources on 
evaluating many samples with a low probability of being optimal, while focusing on areas with promising performance. 
MFO methods frame this concept as resource management algorithms. Notably, MFO methods can be directly combined with 
the aforementioned sampling methods, addressing different aspects of the optimization problem. In Optuna, MFO methods are termed **Pruners**, while the sampling methods are called **Samplers**.

Popular MFO methods include:

- **Coarse-to-Fine Pruner:** As the name suggests, this method initiates training with a small number of samples and 
gradually focuses on a more promising subset.

- **Successive Halving (SH)** [15]: This method starts many trials with a small budget, keeps only the best fraction 
of them, and repeats with a larger budget for the surviving trials.

- **HyperBand (HB)** [16]: This method runs SH several times with different trade-offs between the number of candidates and the resources initially allocated to each, called 
*brackets*. This prevents prematurely discarding promising candidates, a potential issue with SH due to shallow training. Its Optuna class is 
`optuna.pruners.HyperbandPruner`.

- **Bayesian Optimization HyperBand (BOHB):** Combining a Bayesian Optimization sampler with a HyperBand pruner often yields improved results, as detailed in [17]. In Optuna, this can be achieved by setting the sampler to TPE and the 
pruner to HB.
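
The resource allocation idea behind Successive Halving can be sketched in a few lines of plain Python. The function `successive_halving`, the reduction factor `eta`, and the toy `train` function (whose score grows with the budget) are assumptions for illustration only; Optuna's pruners work on reported intermediate values instead.

```python
def successive_halving(configs, train, min_budget=1, eta=3):
    """Keep the best 1/eta of the configurations at each rung and multiply the budget by eta."""
    survivors = list(configs)
    budget = min_budget
    rungs = []  # (budget used, number of configurations kept) per rung
    while len(survivors) > 1:
        scores = {cfg: train(cfg, budget) for cfg in survivors}
        n_keep = max(1, len(survivors) // eta)
        survivors = sorted(survivors, key=scores.get, reverse=True)[:n_keep]
        rungs.append((budget, n_keep))
        budget *= eta
    return survivors[0], rungs


def train(log_lr, budget):
    # Hypothetical stand-in for a training run: the score is highest at log10(lr) = -3
    # and approaches its final value as the budget (e.g. epochs) grows.
    quality = 1.0 / (1.0 + (log_lr + 3.0) ** 2)
    return quality * (1.0 - 0.5 ** budget)


configs = [-5.0 + 0.5 * i for i in range(9)]  # nine candidate log10 learning rates
best, rungs = successive_halving(configs, train)
print(best, rungs)
```

With nine candidates and `eta=3`, only three candidates are trained with the larger budget and one survives. In this toy setting the ranking does not change with the budget; when it does, a good candidate can be dropped too early, which is the weakness HyperBand addresses with its brackets.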

## Steps for Hyperparameter Optimization in Optuna

Hyperparameters in RL training programs are numerous and have varying effects on the training process. Therefore, 
manually tuning them requires significant experience and experimentation to identify optimal values. Optuna [6] 
provides a direct and efficient solution for saving effort in practical applications by utilizing automated search 
with well-established implementations. Optuna simplifies this process with clear implementation steps, supporting 
most of the aforementioned methods and offering direct integration with libraries like MLflow, PyTorch, and JAX. 
These steps can be summarized as follows:

1. Define the objective function, which, in the case of RL, returns the average episodic return.
2. Within this objective function, define the ranges and types of hyperparameters to be optimized using the 
`trial.suggest_*` group of methods (e.g., `trial.suggest_float`).
3. Initialize the optimization object (called the *study*) using `create_study()` and define the desired `sampler` 
and `pruner` methods, along with the *direction* (defaulting to minimization).
4. Optionally, save the current training session by passing the `storage` argument (a database URL) and a `study_name` to 
`create_study()`. This allows for resuming from a saved session of trials by passing `load_if_exists=True` 
to the same function.
5. Initiate training using the `.optimize()` method of the previous study object, passing the objective function as a callable and the number of trials.
6. Upon completion of optimization, the best set of parameters can be accessed via `study.best_params`, and the 
trained model can be saved.

It is also worth noting that Optuna includes a visualization module (`optuna.visualization`) whose functions take the optimized study object as input and generate various useful plots, such as those illustrating the most influential hyperparameters on the results. This module requires the installation of the `plotly` package.


The following illustrative code snippet shows how the above steps can be implemented.

## Training Code Structure in Optuna



&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;optuna&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;optuna.samplers&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;TPESampler&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;optuna.pruners&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;HyperbandPruner&lt;/span&gt;

&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;mlflow&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;functools&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;


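&lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nf"&gt;objective&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;argsParams&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# copy so that values suggested in one trial do not leak into the next&lt;/span&gt;
    &lt;span class="n"&gt;argsParams&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;argsParams&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="p"&gt;{})&lt;/span&gt;
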
    &lt;span class="n"&gt;argsParams&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;update&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;num_steps&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;suggest_int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;num_steps&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;17&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;step&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)})&lt;/span&gt;
    &lt;span class="n"&gt;argsParams&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;update&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;learning_rate&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;suggest_float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;learning_rate&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1e-4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1e-1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)})&lt;/span&gt; 
    &lt;span class="c1"&gt;# log=True samples uniformly in the log domain, so lower values are sampled more densely.&lt;/span&gt;
    &lt;span class="n"&gt;argsParams&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;update&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;buffer_size&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;suggest_int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;buffer_size&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt; &lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;48&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;step&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)})&lt;/span&gt;
    &lt;span class="n"&gt;argsParams&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;update&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;batch_size&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;suggest_int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;batch_size&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;step&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;)})&lt;/span&gt;
    &lt;span class="n"&gt;argsParams&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;update&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;train_frequency&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;suggest_int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;train_frequency&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;24&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;step&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)})&lt;/span&gt;
    &lt;span class="n"&gt;argsParams&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;update&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;optimizer_name&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;suggest_categorical&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;optimizer_name&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;Adam&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;quot;SGD&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;])})&lt;/span&gt;

    &lt;span class="c1"&gt;# define network&lt;/span&gt;

    &lt;span class="c1"&gt;# define optimizers&lt;/span&gt;

    &lt;span class="c1"&gt;# training loop&lt;/span&gt;

    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;start_run&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nested&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;run&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

        &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_params&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;argsParams&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# training logic&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;NumberOfEpochs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="c1"&gt;# training logic&lt;/span&gt;


            &lt;span class="c1"&gt;# send the metrics of the current epoch&lt;/span&gt;
            &lt;span class="n"&gt;mean_return&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;infos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;episode&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;r&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;finished&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
            &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_metrics&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;charts/episodic_return&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;mean_return&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                                &lt;span class="s2"&gt;&amp;quot;charts/episodic_length&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;infos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;episode&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;l&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;finished&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
                                &lt;span class="s2"&gt;&amp;quot;charts/epsilon&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="p"&gt;)},&lt;/span&gt;
                                &lt;span class="n"&gt;step&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;global_step&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

            &lt;span class="c1"&gt;# report the intermediate value so that the pruner can judge this trial&lt;/span&gt;
            &lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;report&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean_return&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;step&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="c1"&gt;# break training whenever the sample seems not optimal (early stopping)&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;should_prune&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
                &lt;span class="k"&gt;raise&lt;/span&gt; &lt;span class="n"&gt;optuna&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;TrialPruned&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;


    &lt;span class="c1"&gt;# return average episodic reward (objective)&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;start_run&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;run_name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;run_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;run&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

    &lt;span class="n"&gt;study&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optuna&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;create_study&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sampler&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;TPESampler&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;multivariate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;False&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; 
                                &lt;span class="c1"&gt;# if multivariate is true the sampler can learn the mutual interactions of variables&lt;/span&gt;
                                &lt;span class="n"&gt;pruner&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;HyperbandPruner&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;min_resource&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;240&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_resource&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;max_epochs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reduction_factor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="c1"&gt;#resource represents epochs&lt;/span&gt;
                                &lt;span class="n"&gt;direction&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;maximize&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# optimize expects a callable that takes only the trial, so the extra arguments are bound with partial&lt;/span&gt;
    &lt;span class="n"&gt;objective_func&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;objective&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;argsParams&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;vars&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;argsParams&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;copy&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;study&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;objective_func&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_trials&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# how many trials to run&lt;/span&gt;

    &lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;study&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# the results&lt;/span&gt;

    &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_params&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;study&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# log it with mlflow&lt;/span&gt;

    &lt;span class="c1"&gt;# visualizations require plotly installed&lt;/span&gt;

    &lt;span class="n"&gt;plotly_fig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optuna&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;visualization&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;plot_param_importances&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;study&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;evaluator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; 
    &lt;span class="n"&gt;plotly_fig&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="c1"&gt;# evaluator=None uses optuna.importance.FanovaImportanceEvaluator; optuna.importance.MeanDecreaseImpurityImportanceEvaluator is an alternative&lt;/span&gt;

    &lt;span class="n"&gt;plotly_fig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optuna&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;visualization&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;plot_contour&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;study&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;plotly_fig&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;


    &lt;span class="n"&gt;plotly_fig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optuna&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;visualization&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;plot_optimization_history&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;study&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;plotly_fig&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="c1"&gt;# these images can be viewed in new windows or sent to the MLflow server to view them alongside other parameters&lt;/span&gt;

    &lt;span class="n"&gt;mlflow&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;plotly_fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;artifact_file&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;opt_history.html&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; 
&lt;/pre&gt;&lt;/div&gt;




In our accompanying [code repository](https://github.com/engyasin/ilsuvey), we conducted 40 trials to search for 
optimal hyperparameters, and visualized the results using the following methods:

- **Parameter Importance:**  We used `optuna.visualization.plot_param_importances` to assess the relative importance 
of each hyperparameter on model training performance.  The two most influential parameters were found to be episode 
length and learning rate.

&lt;center&gt;
&lt;br&gt;
&lt;img width="100%" src='/images/rlwithbells/params_importance.png'&gt;
&lt;br&gt;
Figure 8: Hyperparameters' estimated relative importance on model training performance. The episode length and 
learning rate are found to be the most influential parameters.
&lt;/center&gt;

- **Interactive Pairwise Importance:**  We generated 2D heatmaps using `optuna.visualization.plot_contour` to 
visualize the interactive importance of parameter pairs.  Darker regions indicate optimal combinations of these 
parameters.

&lt;center&gt;
&lt;br&gt;
&lt;img width="100%" src='/images/rlwithbells/contours.png'&gt;
&lt;br&gt;
Figure 9: 2D heatmaps illustrating the interactive importance of parameter pairs on model performance.  The darker 
regions highlight the most effective parameter combinations.
&lt;/center&gt;


- **Optimization History:**  We plotted the performance of trials over time using 
`optuna.visualization.plot_optimization_history` to track the progress of the optimization process.

&lt;center&gt;
&lt;br&gt;
&lt;img width="100%" src='/images/rlwithbells/optimization_history.png'&gt;
&lt;br&gt;
Figure 10:  Improvement in trial performance over the course of the search.  The results demonstrate that the 
hyperparameter optimization process progressively identified better parameter sets, leading to improved 
performance.  Further search is expected to continue this upward trend.
&lt;/center&gt;

Finally, these visualizations provide valuable insights into the effective ranges and combinations of 
hyperparameters that yield optimal performance. This information can potentially guide manual enhancements to other configuration components and deepen the understanding of their effects in the optimization process.





&lt;!-- ### Optimizing the hyper parameters of the trained agent --&gt;

&lt;!-- ### Result --&gt;

&lt;!--

# JAX &amp; Flax: Accelerating Environment Rollout and Model Training

A common approach to training reinforcement learning (RL) agents using simulated environments involves utilizing 
PyTorch or TensorFlow for agent training and NumPy/Gym for environment simulation. However, Google's [**JAX**](https://docs.jax.dev/en/latest/index.html) presents an increasingly popular alternative. JAX offers a faster way to perform matrix computations (replacing NumPy) and to train neural network models (replacing PyTorch) by leveraging hardware acceleration on devices like GPUs and TPUs. While JAX can be directly used to update neural network parameters, a targeted JAX-based package like Flax simplifies model 
structuring and algorithm development.

In the following sections, we will highlight key features of JAX, focusing on its NumPy-compatible functionalities. 
We will demonstrate these features later by rewriting the Doors Gym environment in JAX and comparing its runtime 
performance with the original implementation.

- JAX traces Python functions into a statically typed intermediate representation called **jaxpr**, which the **XLA** 
*(Accelerated Linear Algebra)* compiler then compiles for CPUs, GPUs, and TPUs. The compiled code typically executes faster than the original Python code.  Specifically, JAX functions can be compiled by passing them to the `jax.jit()` function or by using the `@jax.jit` decorator directly above their definitions.

- JAX largely replaces NumPy functions with similarly named counterparts, minimizing code modification efforts.  
Typically, replacing `import numpy as np` with `import jax.numpy as np` is sufficient.  However, certain 
considerations are important, as detailed below:


&gt; *Note*: Unlike NumPy, JAX arrays are immutable. Consequently, in-place modification is not possible. Instead, array elements must be updated using operations like:


&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;jax&lt;/span&gt;
&lt;span class="n"&gt;arr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;numpy&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;arr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;arr&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;at&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;add&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# equivalent to arr[1] += 2 in NumPy&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;



&gt; *Note*: JAX arrays do not raise an `IndexError` when accessing elements outside their bounds; instead, the index is 
clamped to the valid range, so reading past the end returns the last element in the array.

&gt; *Note*: JAX defaults to `float32` precision, unlike NumPy's `float64`.

&gt; *Note*: JAX provides alternative implementations of SciPy functions through the `jax.scipy` module.

The following code shows an example of a JAX-compatible function compiled with `jit` and measures its runtime:



&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;jax&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;time&lt;/span&gt;

&lt;span class="n"&gt;arr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;numpy&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;35&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# 7x5 array&lt;/span&gt;

&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;JAX running on : &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;arr&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nd"&gt;@jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;jit&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nf"&gt;ATA&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# run in IPython :&lt;/span&gt;
&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="n"&gt;timeit&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt; &lt;span class="n"&gt;ATA&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;arr&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;block_until_ready&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;




- JAX's `jax.vmap()` function (or the `@jax.vmap` decorator) enables automatic vectorization of functions, 
facilitating parallel processing of multiple inputs. Instead of iterating through each input individually, you can 
pass them as a *batch* to achieve significant speed improvements over standard Python and NumPy code.  The input and output are effectively stacked and concatenated, adding a new dimension to their matrices (the placement of this dimension is configurable). We demonstrate that this approach is also faster than the Gymnasium environment 
vectorization methods.

- JAX also supports parallelization across multiple devices. This functionality is implemented similarly to `vmap`, using either `jax.pmap()` or the `@jax.pmap` decorator.

&gt; *Note*: JAX execution is, by default, asynchronous. This means that a function call may return before its output has been computed. To wait until the computation is complete, append `.block_until_ready()` to the function call.

- Beyond compilation with XLA, JAX effectively calculates gradients through **automatic differentiation** 
(*autodiff*) of all variable calculations. This is particularly beneficial for accelerating the training of neural 
networks.

- Control statements (*for, while, if, switch*) are known performance bottlenecks in Python, and those that depend on 
traced values cannot be used directly inside `jit`-compiled functions. In JAX, these can be replaced with functional equivalents as follows:



&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;jax&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;jax&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;lax&lt;/span&gt;

&lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cond&lt;/span&gt; &lt;span class="c1"&gt;# if&lt;/span&gt;
&lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;switch&lt;/span&gt; &lt;span class="c1"&gt;# switch, case&lt;/span&gt;
&lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;while_loop&lt;/span&gt; &lt;span class="c1"&gt;# while&lt;/span&gt;
&lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fori_loop&lt;/span&gt; &lt;span class="c1"&gt;# for&lt;/span&gt;

&lt;span class="c1"&gt;# example for fori_loop&lt;/span&gt;
&lt;span class="nd"&gt;@jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;jit&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nf"&gt;main&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nf"&gt;for_loop_body&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;accumulator&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;accumulator&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;

        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;accumulator&lt;/span&gt;

    &lt;span class="n"&gt;accumulator&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;
    &lt;span class="n"&gt;init_val&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;accumulator&lt;/span&gt;
    &lt;span class="n"&gt;start_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;end_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;

    &lt;span class="n"&gt;final_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fori_loop&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;start_i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;end_i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;for_loop_body&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;init_val&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;final_value&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
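
The conditional and `while` equivalents follow the same pattern. The sketch below uses a toy Collatz step counter, chosen only for illustration; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def halve_or_triple(x):
    # lax.cond(pred, true_fun, false_fun, *operands) replaces an if/else.
    return lax.cond(x % 2 == 0, lambda v: v // 2, lambda v: 3 * v + 1, x)

@jax.jit
def collatz_steps(n):
    # lax.while_loop(cond_fun, body_fun, init_val) replaces a while loop.
    # The carried value is a (current number, step counter) tuple.
    def not_done(carry):
        value, _ = carry
        return value != 1

    def body(carry):
        value, steps = carry
        return halve_or_triple(value), steps + 1

    start = jnp.asarray(n, dtype=jnp.int32)
    _, steps = lax.while_loop(not_done, body, (start, jnp.int32(0)))
    return steps

print(collatz_steps(6))
```

Both branches of `lax.cond` and the carried value of `lax.while_loop` must keep the same shape and dtype, which is why the counter is created explicitly as an `int32`.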



&gt; *Note*: For code to be compiled or vectorized correctly in JAX, it must be purely *functional*, that is, free of side effects. 
Object-oriented code that mutates its own attributes cannot be compiled reliably. Classes can still be used as long as they are stateless: their attributes must either be absent or be treated as static values that never change after construction. Any attribute that is modified is, by definition, part of the state and must instead be passed in and returned explicitly.

&gt; *Note*: This functional code restriction should not be viewed as a limitation. In fact, functional code is commonly considered more readable and better structured.
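
To illustrate the stateless-class pattern, here is a minimal sketch in which the state is passed in and returned rather than stored on the object. The `Counter` class is hypothetical and only serves as an example.

```python
from functools import partial

import jax
import jax.numpy as jnp

class Counter:
    """Stateless helper: step_size is fixed at construction and never modified."""

    def __init__(self, step_size=1):
        self.step_size = step_size  # static configuration, not state

    # self is marked static, so JAX treats it as a compile-time constant.
    @partial(jax.jit, static_argnums=(0,))
    def step(self, count):
        # The state (count) comes in as an argument and goes out as a return value.
        return count + self.step_size

counter = Counter(step_size=2)
count = jnp.int32(0)
for _ in range(3):
    count = counter.step(count)  # the caller owns the state and threads it through
print(count)
```

Because `self` is static, changing `step_size` after the first call would not affect the already compiled function, which is why such attributes should stay constant.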

- The following code snippet presents an example of our Doors environment converted to a *stateless* class, while 
remaining compatible with Gymnasium.  Specific new functions are explained in the comments.



&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;gymnasium&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="k"&gt;as&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;gym&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;cv2&lt;/span&gt;

&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;functools&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;partial&lt;/span&gt;

&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;jax&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;jax&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;jit&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;jax.numpy&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="k"&gt;as&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;jax&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;lax&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pmap&lt;/span&gt;


&lt;span class="k"&gt;class&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nc"&gt;DoorsEnvJax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;gridSize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;&lt;span class="n"&gt;nDoors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nb"&gt;super&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="fm"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="c1"&gt;# Static variables - not to be changed: otherwise an error is thrown.&lt;/span&gt;
        &lt;span class="n"&gt;EnvConfig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{}&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gridSize&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gridSize&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nDoors&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nDoors&lt;/span&gt;

        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;spaces&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Discrete&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;spaces&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;MultiDiscrete&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gridSize&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gridSize&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])])&lt;/span&gt;

        &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions_vocal&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]])&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;astype&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


    &lt;span class="nd"&gt;@partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;jit&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;static_argnums&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt; &lt;span class="c1"&gt;# treat the first argument (self) as static, i.e. do not trace it&lt;/span&gt;
    &lt;span class="nd"&gt;@partial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vmap&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;in_axes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="c1"&gt;# vectorize along the first dimension (order 0) of all inputs except the first (None)&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;env_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env_state&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env_state&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;agent_location&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;agent_location&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;goal_location&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;goal_location&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;episodic_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;episode&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;r&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;timestep&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;episode&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;l&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;max_steps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;num_steps&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;


        &lt;span class="n"&gt;movement&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions_vocal&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;new_location&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;clip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;agent_location&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="n"&gt;movement&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gridSize&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="kc"&gt;False&lt;/span&gt;
        &lt;span class="n"&gt;truncated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_steps&lt;/span&gt;&lt;span class="o"&gt;&amp;lt;=&lt;/span&gt;&lt;span class="n"&gt;timestep&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bool_&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; 
        &lt;span class="n"&gt;past_position&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;agent_location&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;copy&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="c1"&gt;# check if wall (2)&lt;/span&gt;

        &lt;span class="n"&gt;cell_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;at&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;new_location&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;get&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# array elements are read with .at[...].get()&lt;/span&gt;

        &lt;span class="n"&gt;possible_moves&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logical_or&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cell_state&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cell_state&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# conditions should be performed by jax functions&lt;/span&gt;

        &lt;span class="c1"&gt;# boolean masking can be done with jax.numpy.where (jax.numpy is imported here as np)&lt;/span&gt;
        &lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;possible_moves&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="c1"&gt;# boolean mask array&lt;/span&gt;
                &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;at&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;agent_location&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;at&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;new_location&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;set&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="c1"&gt;# value if True&lt;/span&gt;
                &lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="c1"&gt;# value if False&lt;/span&gt;
                 &lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;agent_location&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;new_location&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;copy&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cell_state&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; 

        &lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_get_reward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;past_position&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;agent_location&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;goal_location&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;info&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_get_info&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;agent_location&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;goal_location&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

        &lt;span class="c1"&gt;# automatic reset&lt;/span&gt;
        &lt;span class="n"&gt;new_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logical_or&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
                 &lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="kc"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,:])[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="o"&gt;...&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="c1"&gt;# to remove vector dimension&lt;/span&gt;
                &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;copy&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

        &lt;span class="n"&gt;info&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;update&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="s2"&gt;&amp;quot;new_state&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;new_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                     &lt;span class="s2"&gt;&amp;quot;episode&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:{&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;r&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;episodic_reward&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="s1"&gt;&amp;#39;l&amp;#39;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;timestep&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
                     &lt;span class="s2"&gt;&amp;quot;agent_location&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;new_state&lt;/span&gt;&lt;span class="o"&gt;==&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
                     &lt;span class="s2"&gt;&amp;quot;goal_location&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;new_state&lt;/span&gt;&lt;span class="o"&gt;==&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))})&lt;/span&gt;

        &lt;span class="c1"&gt;# Random keys should be used only once. Therefore we generate a new one each step.&lt;/span&gt;
        &lt;span class="n"&gt;new_key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,:]&lt;/span&gt;

        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;new_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;new_key&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;




As illustrated in the preceding example, the environment class is inherently vectorized, enabling the parallel 
execution of multiple environments by passing matrices of actions stacked along the first dimension. The batch is 
initialized in the `.reset()` function from a corresponding set of random keys, one per environment. Specifically:


&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;   &lt;span class="n"&gt;key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;PRNGKey&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
   &lt;span class="n"&gt;NUM_ENVS&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;24&lt;/span&gt; &lt;span class="c1"&gt;# vmap&lt;/span&gt;
   &lt;span class="n"&gt;keys&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;NUM_ENVS&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# generate new keys from existing ones.&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;
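
As a rough sketch of how one key per environment can drive a batched reset, the following maps a single-environment reset over the split keys. The `reset_one` function is a hypothetical stand-in, not the Doors implementation.

```python
import jax
import jax.numpy as jnp
from jax import random

NUM_ENVS = 24
GRID = 15

# Hypothetical reset for one environment: place the agent (1) in a random cell.
def reset_one(key):
    flat_index = random.randint(key, (), 0, GRID * GRID)
    state = jnp.zeros(GRID * GRID, dtype=jnp.int32).at[flat_index].set(1)
    return state.reshape(GRID, GRID)

keys = random.split(random.PRNGKey(0), NUM_ENVS)  # one key per environment
states = jax.vmap(reset_one)(keys)                # shape (NUM_ENVS, GRID, GRID)
print(states.shape)
```

Since every environment receives its own key, the initial states differ across the batch while remaining reproducible from the single seed.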



This vectorization proved significantly advantageous in our experiments. To confirm this, we measured the runtime for a range of DOORS environment counts using four methods: JAX, Gymnasium synchronous vectorization, Gymnasium asynchronous vectorization, and JAX with the loop over steps also accelerated (Python loops being a common performance bottleneck). The runtime results for these methods are presented in Figure 11 below.

&lt;center&gt;
&lt;br&gt;
&lt;img width="100%" src='/images/rlwithbells/runtimeEnvs.png'&gt;
&lt;br&gt;
Figure 11: Comparing runtime of different vectorization methods. JAX demonstrates resilience to increasing 
environment counts, maintaining performance up to 500 environments.  Accelerating the for loop resulted in 
exceptionally fast performance, achieving a runtime of only 0.07 seconds.
&lt;/center&gt;

**JAX-based environments exhibit no noticeable runtime degradation with increasing environment instances (up to the 500 tested here).** This 
observation is particularly noteworthy, as it allows for scaling up environment counts and accelerating the rollout 
phase in various reinforcement learning algorithms. The complete results and plotting script are available in the 
`display.py` script within the accompanying repository, so the measurements can be reproduced on other hardware. Furthermore, the synchronous execution method consistently outperformed the asynchronous version, likely because the environment stepping operations in DOORS are relatively simple, so the overhead of spawning and communicating with numerous subprocesses outweighs the benefit of running them in parallel.
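
One way to accelerate the loop over steps is to compile it together with the batched step, as sketched below with `lax.fori_loop`. The step function, the random policy, and the constants are assumptions made for illustration and do not reproduce the benchmark code.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

NUM_ENVS = 500
NUM_STEPS = 100
MOVES = jnp.array([[0, 0], [0, 1], [1, 0], [0, -1], [-1, 0]])

# Hypothetical batched step: random actions, positions clipped to a 15x15 grid.
def step(i, carry):
    locations, key = carry
    key, subkey = random.split(key)
    actions = random.randint(subkey, (NUM_ENVS,), 0, 5)
    locations = jnp.clip(locations + MOVES[actions], 0, 14)
    return locations, key

@jax.jit
def rollout(key):
    start = jnp.zeros((NUM_ENVS, 2), dtype=jnp.int32)
    # The whole loop is compiled, so Python is not re-entered at every step.
    locations, _ = lax.fori_loop(0, NUM_STEPS, step, (start, key))
    return locations

final_locations = rollout(random.PRNGKey(0)).block_until_ready()
print(final_locations.shape)
```

When timing such a rollout, the first call includes compilation, so it is usually measured separately from subsequent calls.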

&lt;!--

## FLAX

FLAX [20] is a specialized library built upon JAX for constructing and training neural networks. It is often favored over PyTorch or TensorFlow for deep learning due to JAX's inherent speed and improved readability.

In addition to FLAX, we leverage another JAX-based library, `optax` [21], to facilitate composable gradient 
transformations within JAX while defining the model and training state in FLAX.

FLAX's neural network classes inherit from `flax.linen.Module`. The forward pass of a network is defined within its `__call__()` function, annotated with `@flax.linen.compact`. This design results in an object-oriented network 
creation interface that remains stateless and compatible with Just-In-Time (JIT) compilation.

The following code illustrates the definition of a neural network in Flax and its application to a random input, a 
necessary step for initializing its parameters. It's important to note that these parameters are required inputs for model inference via the `.apply()` method, as the class is stateless.



&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;jax&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;flax&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;linen&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;


&lt;span class="k"&gt;class&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nc"&gt;MLP&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nd"&gt;@nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;compact&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="fm"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;activation&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;swish&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dense&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;


&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;MLP&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;main_key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;PRNGKey&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;key1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;main_key&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;random_data&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key1&lt;/span&gt;&lt;span class="p"&gt;,(&lt;/span&gt;&lt;span class="mi"&gt;28&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;28&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;params&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_data&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;apply&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_data&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nb"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tabulate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;random_data&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="c1"&gt;# shows the model structure&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;



A key advantage of Flax is that its layers broadcast over leading dimensions: `nn.Dense`, for example, acts on the last axis of its input and treats every other axis as a batch axis. 
A batched forward pass therefore needs no explicit `jax.vmap` call, and the batch dimension is conventionally the first one.

Following the definition of the network, we define the optimizer using `optax` and the training state class to manage the training process, as shown below:


&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;&lt;span class="kn"&gt;from&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;flax.training&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;train_state&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nn"&gt;optax&lt;/span&gt;

&lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;train_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;TrainState&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;create&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;apply_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;apply&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;tx&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sgd&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;learning_rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;momentum&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nd"&gt;@jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;jit&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nf"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nf"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inputData&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;train_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;apply_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inputData&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
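        &lt;span class="n"&gt;log_preds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logsumexp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="kc"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# log-softmax over the class axis&lt;/span&gt;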

        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;jnp&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;log_preds&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;loss_value&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;grads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_and_grad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;train_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;train_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;train_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;apply_gradients&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;grads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;grads&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;train_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss_value&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;




Using the preceding code, we can update the model's parameters based on the calculated loss. To persist the trained 
Flax model, we utilize the following code:


&lt;div class="code"&gt;&lt;pre class="code literal-block"&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nb"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;quot;wb&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
   &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;write&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;flax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;serialization&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;to_bytes&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="c1"&gt;# The code above serializes the parameters held by the training state to bytes. To load them again, use:&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nb"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="s2"&gt;&amp;quot;rb&amp;quot;&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
   &lt;span class="c1"&gt;# TrainState is immutable, so build an updated copy with replace()&lt;/span&gt;
   &lt;span class="n"&gt;q_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;replace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;flax&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;serialization&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;from_bytes&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q_state&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="o"&gt;.&lt;/span&gt;&lt;span class="n"&gt;read&lt;/span&gt;&lt;span class="p"&gt;()))&lt;/span&gt;
&lt;/pre&gt;&lt;/div&gt;



&gt; *Note*: The `orbax` library provides a higher-level abstraction for automatically saving Flax models.

&lt;!-- TODO: check if we can log flax model to mlflow and load them? --&gt;

&lt;!--### Making the environment functional for JAX vectorization --&gt;

&lt;!-- ### Training DQN with Flax --&gt;

&lt;!-- ### Result : Speed comparison with Pytorch--&gt;
&lt;!--

## Final Take-away

Table 2 presents the performance (the step-wise mean reward over the last 2000 of the 5e5 training episodes) and the training runtime for the following variations of the training program:

- PyTorch with Gym (synchronized environments) [available here](https://github.com/engyasin/ilsurvey/blob/main/dqn_hopt.py)
- Flax with Gym (synchronized environments) [available here](https://github.com/engyasin/ilsurvey/blob/main/dqn_hopt_flax.py)
- Flax with Gym and JAX automatic vectorization on GPU [available here](https://github.com/engyasin/ilsurvey/blob/main/dqn_hopt_flax_jax.py)
- PyTorch with Gym and JAX automatic vectorization on GPU [available here](https://github.com/engyasin/ilsurvey/blob/main/dqn_hopt_torch_jax.py)

**Note:** These results were obtained using an **NVIDIA GeForce RTX 5060 Ti** GPU and an **AMD Ryzen 5 7600X 6-Core Processor** for the CPU. Each of the first and last tests was run with 40 trials, while the hyperparameters for the other tests were derived from the best-performing configuration of the final trial.

&lt;center&gt;
&lt;br&gt;

Table 2: Performance and runtime of training a DQN agent to solve the DOORS environment using three different combinations of packages (JAX, Flax, and PyTorch)

&lt;br&gt;
&lt;table style="border: 1px solid black" &gt;

&lt;tr &gt;
&lt;th style="border: 1px solid black"&gt;
&lt;/th&gt;
&lt;th style="border: 1px solid black"&gt;
 PyTorch for DQN
&lt;/th&gt;
&lt;th style="border: 1px solid black"&gt;
 Flax for DQN
&lt;/th&gt;
&lt;th style="border: 1px solid black"&gt;
 Flax for DQN and JAX for Env
&lt;/th&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="border: 1px solid black"&gt; Rolling Reward&lt;/td&gt;
&lt;td style="border: 1px solid black"&gt; &lt;strong&gt;0.73 &lt;/strong&gt;&lt;/td&gt;
&lt;td style="border: 1px solid black"&gt; 0.64&lt;/td&gt;
&lt;td style="border: 1px solid black"&gt; 0.71 &lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="border: 1px solid black"&gt; Training Time&lt;/td&gt;
&lt;td style="border: 1px solid black"&gt; 22.5 min &lt;/td&gt;
&lt;td style="border: 1px solid black"&gt;  22.8 min &lt;/td&gt;
&lt;td style="border: 1px solid black"&gt; &lt;strong&gt;2.3 min&lt;/strong&gt; &lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="border: 1px solid black"&gt; Training Curves &lt;/td&gt;
&lt;td style="border: 1px solid black"&gt; &lt;img width="100%" src='/images/rlwithbells/PytorchNumpy.png'&gt; &lt;/td&gt;
&lt;td style="border: 1px solid black"&gt; &lt;img width="100%" src='/images/rlwithbells/FLAXNumpy.png'&gt; &lt;/td&gt;
&lt;td style="border: 1px solid black"&gt; &lt;strong&gt;&lt;img width="100%" src='/images/rlwithbells/jax_flax.png'&gt;&lt;/strong&gt; &lt;/td&gt;
&lt;/tr&gt;

&lt;/table&gt;
&lt;br&gt;
&lt;/center&gt;



The results in Table 2 indicate that hyperparameter optimization was crucial for achieving strong performance with 
PyTorch, yielding a final reward of 0.73 after 40 trials. The other implementations utilizing JAX and Flax achieved comparable but slightly lower results, potentially due to the inherent randomness associated with initializations in these frameworks. Increasing the number of search trials may yield further improvements across all methods. It is also important to note that in off-policy methods like DQN, a larger buffer size is beneficial for maximizing the speedup gained from saved experiences; otherwise, performance cannot benefit from the fast environment rollout.

The most significant gain was in training time, attributable to replacing the standard 
NumPy operations within the DOORS environment with JAX-accelerated, vectorized functional code. Because the environment is written as stateless functions, many copies of it can be stepped in parallel on the GPU at little additional cost per step. We exploited this by increasing the number of environments by a factor of 16 in the JAX-based implementation, which gave a speedup of approximately 10 times on our hardware (22.5 min versus 2.3 min in Table 2). We anticipate further speedups with an even larger number of environments. The remaining settings and hyperparameter ranges were the same across all three tested setups.
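
The idea of a stateless, vectorized environment can be sketched as follows. This is not the DOORS environment: the grid size, goal cell, reward, and the `step` function are hypothetical and serve only to show how `jax.vmap` and `jax.jit` turn a single-environment step into a batched one.

```python
import jax
import jax.numpy as jnp

GRID_SIZE = 5                                           # assumed grid width and height
GOAL = jnp.array([4, 4])                                # assumed goal cell
MOVES = jnp.array([[-1, 0], [1, 0], [0, -1], [0, 1]])   # up, down, left, right


def step(pos, action):
    """Stateless step: returns the next position, the reward, and a done flag."""
    new_pos = jnp.clip(pos + MOVES[action], 0, GRID_SIZE - 1)
    done = jnp.all(new_pos == GOAL)
    reward = jnp.where(done, 1.0, 0.0)
    return new_pos, reward, done


# vmap maps `step` over the leading axis; jit compiles the batched version once.
batched_step = jax.jit(jax.vmap(step))

num_envs = 16
positions = jnp.zeros((num_envs, 2), dtype=jnp.int32)  # every env starts at (0, 0)
actions = jnp.full((num_envs,), 3, dtype=jnp.int32)    # all envs move right ...
actions = actions.at[0].set(1)                         # ... except env 0, which moves down

positions, rewards, dones = batched_step(positions, actions)
print(positions[:2])
```

Since the environment state is passed in and returned explicitly, changing `num_envs` requires no change to `step` itself, which is what makes scaling the number of environments cheap to try.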



With these findings, we conclude by offering recommendations on when and why to utilize each of the discussed packages:

1. **Gymnasium:** If the goal is to create novel environments and facilitate sharing and collaboration with the broader research community, Gymnasium is a suitable choice.

2. **MLflow:** For comprehensive tracking of training metrics and parameters, complete visualization of hyperparameters, and streamlined deployment, MLflow provides a direct and effective solution.

3. **Optuna:** When dealing with complex models possessing numerous hyperparameters that are difficult to tune manually (a common scenario in Reinforcement Learning), Optuna offers implementations of advanced hyperparameter search algorithms with seamless integration with MLflow.

4. **JAX:** If environment simulation is computationally expensive and represents a bottleneck in training runtime, vectorizing the environment using JAX on GPU or TPU devices can yield significant speedups, enabling faster sampling 
of larger batches.

5. **Flax:** As a JAX-based library, Flax benefits from accelerated gradient calculations, potentially leading to performance gains on specialized hardware. However, this benefit may be diminished for smaller models and datasets, as observed in our results where PyTorch performance was comparable. Flax is particularly advantageous when dealing with large observation spaces, such as those containing images or videos with numerous trainable parameters.

Therefore, a thorough examination of the training pipeline is recommended to identify the computational bottleneck, especially in model-free Reinforcement Learning, which alternates between two phases: rollout generation and model parameter updates. For the former, we suggest leveraging accelerated JAX matrix operations, and for the latter, we recommend Flax together with JAX's autodiff and Optax's optimizers.

## Additional Libraries Leveraging JAX

To avoid reinventing the wheel when switching to a JAX implementation, it is useful to explore clean open-source JAX projects for Reinforcement Learning or environment simulation that can be imported and modified as needed.

### Brax

[Brax](https://github.com/google/brax) [22], a JAX-based reimplementation of MuJoCo developed by Google, demonstrates significant speedups over standard MuJoCo and includes implementations of SAC and PPO RL algorithms.  (For an 
introduction to MuJoCo, see [our post here](https://www.rlbyexample.net/posts/immerse-yourself-in-reinforcement-learning-and-robotics-with-mujoco/)).

### Dopamine

[Dopamine](https://github.com/google/dopamine) [23], another Google-developed package, provides a JAX implementation of a variety of RL algorithms for researchers, facilitating rapid training and testing across diverse environments.




# References

1. Brockman, G., Cheung, V., Pettersson, L., Schneider, J., Schulman, J., Tang, J., &amp; Zaremba, W. (2016). Openai gym. arXiv preprint arXiv:1606.01540.
2. Towers, M., Kwiatkowski, A., Terry, J., Balis, J. U., De Cola, G., Deleu, T., ... &amp; Younis, O. G. (2024). Gymnasium: A standard interface for reinforcement learning environments. arXiv preprint arXiv:2407.17032.
3. Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D., &amp; Riedmiller, M. (2013). Playing atari with deep reinforcement learning. arXiv preprint arXiv:1312.5602.
4. Huang, S., Dossa, R. F. J., Ye, C., Braga, J., Chakraborty, D., Mehta, K., &amp; Araújo, J. G. (2022). Cleanrl: High-quality single-file implementations of deep reinforcement learning algorithms. Journal of Machine Learning Research, 23(274), 1-18.
5. Zaharia, M., Chen, A., Davidson, A., Ghodsi, A., Hong, S. A., Konwinski, A., ... &amp; Zumar, C. (2018). Accelerating the machine learning lifecycle with MLflow. IEEE Data Eng. Bull., 41(4), 39-45.
6. Akiba, T., Sano, S., Yanase, T., Ohta, T., &amp; Koyama, M. (2019, July). Optuna: A next-generation hyperparameter optimization framework. In Proceedings of the 25th ACM SIGKDD international conference on knowledge discovery &amp; data mining (pp. 2623-2631).
7. Frazier, P. I. (2018). A tutorial on Bayesian optimization. arXiv preprint arXiv:1807.02811.
8. Bergstra, J., Bardenet, R., Bengio, Y., &amp; Kégl, B. (2011). Algorithms for hyper-parameter optimization. Advances in neural information processing systems, 24.
9. Hutter, F., Hoos, H. H., &amp; Leyton-Brown, K. (2011, January). Sequential model-based optimization for general algorithm configuration. In International conference on learning and intelligent optimization (pp. 507-523). Berlin, Heidelberg: Springer Berlin Heidelberg.
10. Li, Z. L., Liang, C. J. M., He, W., Zhu, L., Dai, W., Jiang, J., &amp; Sun, G. (2018). Metis: Robustly tuning tail latencies of cloud systems. In 2018 USENIX Annual Technical Conference (USENIX ATC 18) (pp. 981-992).
11. Kirkpatrick, S., Gelatt Jr, C. D., &amp; Vecchi, M. P. (1983). Optimization by simulated annealing. Science, 220(4598), 671-680.
12. Di Francescomarino, C., Dumas, M., Federici, M., Ghidini, C., Maggi, F. M., Rizzi, W., &amp; Simonetto, L. (2018). Genetic algorithms for hyperparameter optimization in predictive business process monitoring. Information Systems, 74, 67-83.
13. Kennedy, J., &amp; Eberhart, R. (1995, November). Particle swarm optimization. In Proceedings of ICNN'95-international conference on neural networks (Vol. 4, pp. 1942-1948). IEEE.
14. Jaderberg, M., Dalibard, V., Osindero, S., Czarnecki, W. M., Donahue, J., Razavi, A., ... &amp; Kavukcuoglu, K. (2017). Population based training of neural networks. arXiv preprint arXiv:1711.09846.
15. Pietruszka, M., Borchmann, Ł., &amp; Graliński, F. (2021, May). Successive halving top-k operator. In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 35, No. 18, pp. 15869-15870).
16. Li, L., Jamieson, K., DeSalvo, G., Rostamizadeh, A., &amp; Talwalkar, A. (2018). Hyperband: A novel bandit-based approach to hyperparameter optimization. Journal of Machine Learning Research, 18(185), 1-52.
17. Falkner, S., Klein, A., &amp; Hutter, F. (2018, July). BOHB: Robust and efficient hyperparameter optimization at scale. In International conference on machine learning (pp. 1437-1446). PMLR.
18. Imambi, S., Prakash, K. B., &amp; Kanagachidambaresan, G. R. (2021). PyTorch. In Programming with TensorFlow: solution for edge computing applications (pp. 87-104). Cham: Springer International Publishing.
19. Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., ... &amp; Zhang, Q. (2021). JAX: Autograd and xla. Astrophysics Source Code Library, ascl-2111.
20. Heek, J., Levskaya, A., Oliver, A., Ritter, M., Rondepierre, B., Steiner, A., &amp; Van Zee, M. (2020). Flax: A neural network library and ecosystem for JAX. Version 0.3, 3, 14-26.
21. DeepMind, Babuschkin, I., Baumli, K., Bell, A., Bhupatiraju, S., Bruce, J., ... &amp; Viola, F. (2020). The DeepMind JAX Ecosystem. https://github.com/google-deepmind
22. Freeman, C. D., Frey, E., Raichuk, A., Girgin, S., Mordatch, I., &amp; Bachem, O. (2021). Brax--a differentiable physics engine for large scale rigid body simulation. arXiv preprint arXiv:2106.13281.
23. Castro, P. S., Moitra, S., Gelada, C., Kumar, S., &amp; Bellemare, M. G. (2018). Dopamine: A research framework for deep reinforcement learning. arXiv preprint arXiv:1812.06110.
24. Sutton, R. S., &amp; Barto, A. G. (1998). Reinforcement learning: An introduction (Vol. 1, No. 1, pp. 9-11). Cambridge: MIT press.


&lt;!-- TODO: Turn Code into images --&gt;
&lt;!-- TODO: Make other images --&gt;
&lt;!-- TODO: Add experiments results --&gt;
&lt;!-- TODO: Mention the hardware --&gt;
&lt;!-- TODO: Add references --&gt;

&lt;!-- TODO: Edit by yourself then LLM then yourself --&gt;
&lt;!-- TODO: More on Tree-structured Parzen Estimators --&gt;</description><category>flax</category><category>gym</category><category>hyperparameter optimization</category><category>jax</category><category>logging</category><category>reinforcement learning</category><guid>https://engyasin.github.io/posts/reinforcement-learning-with-bells-and-whistles/</guid><pubDate>Sun, 17 Aug 2025 17:34:35 GMT</pubDate></item></channel></rss>