Reinforcement Learning Models
Reinforcement Learning (RL) is a paradigm of machine learning where an "agent" learns to make decisions by performing actions in an "environment" to maximize a cumulative reward.
Unlike supervised learning, RL models are not trained on a fixed dataset. Instead, they are policies and/or value functions that an agent uses to interact with an environment and learn from the feedback it receives.
xTorch provides implementations of several major RL algorithms, encapsulating the underlying neural network architectures (the policies and value functions) that power them.
All RL models are located under the xt::models namespace and their headers can be found in the <xtorch/models/reinforcement_learning/> directory.
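If you prefer to include only the header you need rather than the umbrella <xtorch/xtorch.h> header used in the examples below, an include along the following lines should work, assuming the headers mirror the directory layout described above (the exact path is an assumption, not verified against the library):
// Assumed per-model include, following the directory layout described above:
// #include <xtorch/models/reinforcement_learning/dqn.h>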
General Usage
RL models are used differently from standard supervised models. Instead of a single forward pass on a batch of data, they are typically used within an "agent-environment loop."
- A Policy Network takes the current state (observation) from the environment and outputs a probability distribution over possible actions.
- A Value Network (or Q-Network) takes the current state and outputs an estimated value for each possible action (the expected future reward).
The examples below show how to instantiate and use these two core components.
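Before the concrete examples, the commented sketch below shows the general shape of that agent-environment loop. The Environment type with its reset() and step() methods, the select_action helper, and the replay buffer are hypothetical placeholders for illustration only; xTorch does not prescribe a particular environment interface.
// Hypothetical agent-environment loop (Environment, reset(), step(), select_action()
// and the replay buffer are placeholders, not part of the xTorch API):
// Environment env;
// auto state = env.reset();                          // initial observation as a tensor
// bool done = false;
// while (!done) {
//     auto action = select_action(model, state);     // e.g., epsilon-greedy over Q-values
//     auto [next_state, reward, terminal] = env.step(action);
//     // store (state, action, reward, next_state, terminal) in a replay buffer and update the model
//     state = next_state;
//     done = terminal;
// }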
Example: Using a Q-Network (for DQN)
#include <xtorch/xtorch.h>
#include <iostream>
int main() {
    torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

    // --- Environment & Model Properties ---
    const int num_observations = 4; // E.g., state space size for CartPole
    const int num_actions = 2;      // E.g., number of possible actions

    // --- Instantiate a DQN Model ---
    // This is the Q-Network that approximates the action-value function.
    xt::models::DQN model(num_observations, num_actions);
    model.to(device);
    model.eval();
    std::cout << "DQN Model (Q-Network) Instantiated." << std::endl;

    // --- Get Action Values for a Given State ---
    // Inference only, so disable gradient tracking.
    torch::NoGradGuard no_grad;

    // Create a dummy observation from the environment
    auto state = torch::randn({1, num_observations}).to(device); // Batch size of 1

    // The model's forward pass returns the Q-value for each action
    auto action_values = model.forward(state);

    // The agent would then use an epsilon-greedy strategy to select an action
    auto best_action = torch::argmax(action_values, /*dim=*/1);

    std::cout << "Action values: " << action_values << std::endl;
    std::cout << "Best action: " << best_action.item<long>() << std::endl;
}

Example: Using a Policy Network (for Actor-Critic)
// --- Instantiate an Actor-Critic Model (like A3C) ---
// Note: A3C often has a shared body with two heads (policy and value).
// xt::models::A3C model(num_observations, num_actions);
// model.to(device);
// auto [policy_logits, value_estimate] = model.forward(state);

// The policy head gives logits, which are converted to a probability distribution
// auto action_probabilities = torch::softmax(policy_logits, /*dim=*/-1);

// Sample an action from the policy. The LibTorch C++ frontend does not provide
// torch::distributions, so torch::multinomial is used to sample from the probabilities.
// auto action = torch::multinomial(action_probabilities, /*num_samples=*/1);

Available Models by Family
Value-Based Methods
These methods learn a value function that estimates the expected return for taking an action in a given state. The policy is often implicit (e.g., "always take the action with the highest value").
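To make this concrete, the commented sketch below shows the one-step temporal-difference (TD) target that a value-based agent such as DQN regresses its Q-network towards. The names target_model, states, actions, rewards, next_states, non_terminal, and gamma are placeholders for this illustration and are not part of the xTorch API.
// One-step TD target for a batch of transitions (all names below are illustrative):
// auto next_q     = target_model.forward(next_states);             // [batch, num_actions]
// auto max_next_q = std::get<0>(next_q.max(/*dim=*/1));            // value of the best next action
// auto td_target  = rewards + gamma * max_next_q * non_terminal;   // zero for terminal transitions
// auto current_q  = model.forward(states).gather(/*dim=*/1, actions).squeeze(1);
// auto loss       = torch::mse_loss(current_q, td_target);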
| Model | Description | Header File |
|---|---|---|
| DQN | Deep Q-Network. A foundational algorithm that uses a deep neural network to approximate the optimal action-value function, Q*. | dqn.h |
| DoubleDQN | An improvement over DQN that decouples action selection from action evaluation to reduce overestimation of Q-values. | double_dqn.h |
| DuelingDQN | An architecture that separates the estimation of state values and action advantages, leading to better policy evaluation. | dueling_dqn.h |
| Rainbow | A combination of six improvements to DQN (Double Q-learning, prioritized replay, dueling networks, multi-step learning, distributional RL, and noisy networks) into a single, high-performing agent. | rainbow.h |
Policy-Based & Actor-Critic Methods
These methods directly learn a policy that maps states to actions. Actor-Critic methods learn both a policy (the actor) and a value function (the critic) simultaneously.
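As a rough illustration, an actor-critic update combines a policy-gradient term for the actor with a regression term for the critic. In the commented sketch below, log_probs, returns_, and values are hypothetical tensors (the log-probabilities of the taken actions, the observed discounted returns, and the critic's value estimates); none of these names come from the xTorch API.
// Illustrative actor-critic loss (all tensor names are placeholders):
// auto advantage   = returns_ - values.detach();          // how much better the outcome was than predicted
// auto actor_loss  = -(log_probs * advantage).mean();     // policy-gradient (actor) term
// auto critic_loss = torch::mse_loss(values, returns_);   // value-regression (critic) term
// auto loss        = actor_loss + 0.5 * critic_loss;      // 0.5 is a common critic weight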
| Model | Description | Header File |
|---|---|---|
| A3C | Asynchronous Advantage Actor-Critic. A classic actor-critic algorithm that trains with many asynchronous parallel workers. | a3c.h |
| PPO | Proximal Policy Optimization. A highly effective and stable actor-critic method, often a default choice for many continuous control problems. | ppo.h |
| DDPG | Deep Deterministic Policy Gradient. An actor-critic, model-free algorithm for continuous action spaces. | ddpg.h |
| TD3 | Twin Delayed DDPG. An improvement over DDPG that addresses Q-value overestimation by using two critic networks. | td3.h |
| SAC | Soft Actor-Critic. An off-policy actor-critic algorithm based on the maximum entropy framework, known for its sample efficiency and stability. | sac.h |
Model-Based & Planning Methods
These methods plan future actions with tree search, using either the known rules of the environment or a learned model of its dynamics.
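For intuition, the commented sketch below shows the simplest possible form of model-based planning: a one-step lookahead over a learned model. Real agents in this family (AlphaZero, MuZero) use much deeper tree search; dynamics_model, reward_model, value_model, and gamma are hypothetical names used only for illustration.
// One-step lookahead with a learned model (all names are placeholders; real agents use tree search):
// double best_value = -std::numeric_limits<double>::infinity();
// int best_action = 0;
// for (int a = 0; a < num_actions; ++a) {
//     auto next_state = dynamics_model.forward(state, a);       // predicted next state
//     auto value = reward_model.forward(state, a)
//                + gamma * value_model.forward(next_state);     // predicted return of action a
//     if (value.item<double>() > best_value) {
//         best_value = value.item<double>();
//         best_action = a;
//     }
// }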
| Model | Description | Header File |
|---|---|---|
| AlphaGo | The pioneering deep RL program that defeated the world champion Go player, combining Monte Carlo tree search with deep neural networks. | alpha_go.h |
| AlphaZero | A more generalized and powerful version of AlphaGo that learns entirely from self-play and mastered Go, chess, and shogi. | alpha_zero.h |
| MuZero | A powerful successor to AlphaZero that achieves superhuman performance by learning a model of the environment and applying tree-based search. | mu_zero.h |
Other Architectures
| Model | Description | Header File |
|---|---|---|
| IMPALA | Importance Weighted Actor-Learner Architecture. A scalable, distributed agent that can be used for both single and multi-task reinforcement learning. | impala.h |
