
6.3 - Loading and Running a Trained Agent

Training a model is only half the process. Once you have a trained agent saved to a file, you need a way to load and run it for evaluation, a process known as inference. This allows you to watch your agent perform, test its capabilities against specific opponents, and verify its learned policy without the overhead of retraining.

This section provides a dedicated, reusable script for running inference on your saved models.

The Inference Workflow

Unlike the training loop, the inference loop does not learn or update the model. It simply uses the model's existing policy to choose actions.

+--------------------------+
| 1. Select & Load Model   |
+--------------------------+
             |
             v
+--------------------------+
| 2. Reset Environment     |
+--------------------------+
             |
             v
+--------------------------+
| 3. Get Observation       |<-----------+
+--------------------------+            |
             |                          |
             v                          |
+--------------------------+            |
| 4. Predict Action        |            |
|    (model.predict)       |            |
+--------------------------+            |
             |                          |
             v                          |
+--------------------------+            |
| 5. Step Environment      |            |
+--------------------------+            |
             |                          |
             v                          |
+--------------------------+            |
| 6. Repeat until Done     |------------+
+--------------------------+
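The workflow can be sketched without StarCraft II or a saved model. The example below is only an illustration: `CountdownEnv`, `FixedPolicy`, and `run_episode` are hypothetical stand-ins, written on the assumption that the environment follows the Gymnasium-style API (`reset()` returns `(obs, info)`, `step()` returns five values) and that the model exposes a `predict()` method returning an `(action, state)` pair.

```python
# Hypothetical stand-ins so the loop runs without StarCraft II or a trained model.
class CountdownEnv:
    """Gymnasium-style toy environment that terminates after a fixed number of steps."""

    def __init__(self, episode_length=5):
        self.episode_length = episode_length
        self.steps_left = episode_length

    def reset(self):
        self.steps_left = self.episode_length
        return self.steps_left, {}  # (observation, info)

    def step(self, action):
        self.steps_left -= 1
        reward = 1.0 if action == 1 else 0.0
        terminated = self.steps_left == 0
        truncated = False
        return self.steps_left, reward, terminated, truncated, {}

    def close(self):
        pass


class FixedPolicy:
    """Mimics the (action, state) return shape of model.predict()."""

    def predict(self, obs, deterministic=True):
        return 1, None  # always choose action 1; no recurrent state


def run_episode(model, env):
    """Run one episode with a fixed policy; nothing is learned or updated."""
    obs, info = env.reset()                      # steps 2 and 3
    terminated = truncated = False
    total_reward, steps = 0.0, 0
    while not (terminated or truncated):         # step 6
        action, _state = model.predict(obs, deterministic=True)      # step 4
        obs, reward, terminated, truncated, info = env.step(action)  # step 5
        total_reward += reward
        steps += 1
    return total_reward, steps


env = CountdownEnv(episode_length=5)
total_reward, steps = run_episode(FixedPolicy(), env)
env.close()
print(f"steps={steps}, total_reward={total_reward}")
```

Accumulating the reward inside the loop is optional, but it gives a single number to compare between runs.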

Using the Learned Policy: model.predict()

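The core of the inference loop is the model.predict() method. It takes the current observation and returns the chosen action together with the model's internal state, which only recurrent policies use. Its deterministic parameter controls how the action is selected. Because deterministic defaults to False, pass deterministic=True explicitly when evaluating.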

| Parameter | Value | Purpose |
|---|---|---|
| deterministic | True | (For Evaluation) Instructs the model to always choose the action with the highest probability. This reflects the agent's "true" learned policy and makes its behavior consistent. |
| deterministic | False | (For Training) Allows the model to sample from its action probabilities. This encourages exploration, which is necessary for learning but not for evaluation. |
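The difference between the two settings can be illustrated with a toy example. The sketch below does not use Stable-Baselines3: `select_action` is a hypothetical helper, and it assumes a discrete action space where the policy outputs one probability per action.

```python
import random


def select_action(probabilities, deterministic, rng=random):
    """Pick an action index from a list of action probabilities."""
    if deterministic:
        # Greedy choice: always the index with the highest probability.
        return max(range(len(probabilities)), key=lambda i: probabilities[i])
    # Stochastic choice: sample an index in proportion to its probability.
    return rng.choices(range(len(probabilities)), weights=probabilities, k=1)[0]


probs = [0.1, 0.7, 0.2]  # made-up policy output for three actions

# Deterministic selection returns the same action every time.
greedy = [select_action(probs, deterministic=True) for _ in range(5)]
print(greedy)

# Sampling mostly picks action 1 but still visits the other actions.
rng = random.Random(0)  # seeded so the run is repeatable
sampled = [select_action(probs, deterministic=False, rng=rng) for _ in range(1000)]
print({action: sampled.count(action) for action in range(len(probs))})
```

With sampling, two runs of the same model can diverge after the first step, which makes evaluation results harder to compare.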

Code: The Reusable run_agent.py Script

This script is your entry point for evaluating any of the agents you have trained.

# run_agent.py

import multiprocessing as mp
import time

from stable_baselines3 import PPO

# Import all possible environments
from worker_bot import WorkerEnv
from macro_bot import MacroEnv
from micro_bot import MicroEnv

# --- Configuration ---
# Use a dictionary to map names to environment classes and model paths.
AGENT_CONFIG = {
    "worker": {
        "env": WorkerEnv,
        "path": "./models/ppo_sc2_worker.zip"
    },
    "macro": {
        "env": MacroEnv,
        "path": "./models/ppo_sc2_macro.zip"
    },
    "micro": {
        "env": MicroEnv,
        "path": "./models/ppo_sc2_micro.zip"
    }
}
# CHANGE THIS VALUE to select which agent to run
SELECTED_AGENT = "micro"


def main():
    """Loads and runs a pre-trained agent for one episode."""

    config = AGENT_CONFIG[SELECTED_AGENT]
    print(f"--- Running evaluation for agent: {SELECTED_AGENT.upper()} ---")

    # --- Step 1: Instantiate the Environment ---
    env = config["env"]()

    # --- Step 2: Load the Trained Model ---
    try:
        model = PPO.load(config["path"])
        print(f"Model loaded successfully from: {config['path']}")
    except FileNotFoundError:
        print(f"ERROR: Model file not found at {config['path']}")
        print("Please run a training script to train and save the model first.")
        env.close()
        return

    # --- Step 3: Run the Inference Loop ---
    obs, info = env.reset()
    terminated = False
    truncated = False

    while not (terminated or truncated):
        # Use the model to predict the next action.
        # deterministic=True ensures the agent plays its best learned policy.
        action, _states = model.predict(obs, deterministic=True)

        # Take the action in the environment.
        obs, reward, terminated, truncated, info = env.step(action)

        # Optional: Add a small delay to make it easier to watch.
        time.sleep(0.01)

    if truncated:
        print("Episode finished due to reaching a limit (truncated).")
    else:
        print("Episode finished normally (terminated).")

    # --- Final Step: Clean up ---
    env.close()


if __name__ == '__main__':
    # Add the multiprocessing guard for Windows compatibility when creating executables.
    mp.freeze_support()
    main()
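Editing SELECTED_AGENT by hand works, but the agent could also be chosen on the command line. The following is a minimal sketch of that idea using the standard-library argparse module. It is not part of the script above: the `--agent` and `--delay` flags, the `AGENT_NAMES` list, and the `parse_args` helper are all hypothetical names chosen for illustration.

```python
import argparse

# Hypothetical: mirrors the keys of AGENT_CONFIG in run_agent.py.
AGENT_NAMES = ["worker", "macro", "micro"]


def parse_args(argv=None):
    """Parse command-line options; argv=None falls back to sys.argv."""
    parser = argparse.ArgumentParser(description="Run a trained agent for evaluation.")
    parser.add_argument(
        "--agent",
        choices=AGENT_NAMES,
        default="micro",
        help="which trained agent to load and run",
    )
    parser.add_argument(
        "--delay",
        type=float,
        default=0.01,
        help="seconds to sleep between steps (0 disables the delay)",
    )
    return parser.parse_args(argv)


# Example: equivalent to running `python run_agent.py --agent macro`.
args = parse_args(["--agent", "macro"])
print(args.agent, args.delay)
```

In `main()`, `AGENT_CONFIG[args.agent]` would then take the place of `AGENT_CONFIG[SELECTED_AGENT]`. Because argparse restricts `--agent` to the listed choices, a misspelled name is rejected with a usage message before any environment is created.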