Descriptor-Conditioned Reinforcement Learning MAP-Elites (DCRL-ME)

To create an instance of DCRL-ME, one needs to use an instance of MAP-Elites with the DCRLMEEmitter, detailed below.

Bases: MultiEmitter

Source code in qdax/core/emitters/dcrl_me_emitter.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class DCRLMEEmitter(MultiEmitter):
    """Multi-emitter that pairs a ``DCRLEmitter`` with a ``MixingEmitter``.

    The two sub-emitters are handed to ``MultiEmitter`` in the order
    ``(DCRLEmitter, MixingEmitter)``.
    """

    def __init__(
        self,
        config: DCRLMEConfig,
        policy_network: nn.Module,
        actor_network: nn.Module,
        env: QDEnv,
        variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]],
        selector: Optional[Selector] = None,
    ) -> None:
        """Build both sub-emitters and register them with the parent class.

        Args:
            config: DCRL-ME configuration. A fixed subset of its fields is
                copied into a ``DCRLConfig``; ``ga_batch_size`` sets the
                batch size of the ``MixingEmitter``.
            policy_network: forwarded unchanged to the ``DCRLEmitter``.
            actor_network: forwarded unchanged to the ``DCRLEmitter``.
            env: forwarded to the ``DCRLEmitter`` and kept on the instance.
            variation_fn: variation operator given to the ``MixingEmitter``
                and kept on the instance.
            selector: optional selector passed to both sub-emitters
                (defaults to ``None``).
        """
        self._config = config
        self._env = env
        self._variation_fn = variation_fn

        # Fields that DCRLMEConfig and DCRLConfig share by name; each one is
        # read from `config` and forwarded as the same-named keyword argument.
        forwarded_fields = (
            "dcrl_batch_size",
            "ai_batch_size",
            "lengthscale",
            "critic_hidden_layer_size",
            "num_critic_training_steps",
            "num_pg_training_steps",
            "batch_size",
            "replay_buffer_size",
            "discount",
            "reward_scaling",
            "critic_learning_rate",
            "actor_learning_rate",
            "policy_learning_rate",
            "noise_clip",
            "policy_noise",
            "soft_tau_update",
            "policy_delay",
        )
        rl_settings = DCRLConfig(
            **{field: getattr(config, field) for field in forwarded_fields}
        )

        # Quality emitter.
        quality_emitter = DCRLEmitter(
            config=rl_settings,
            policy_network=policy_network,
            actor_network=actor_network,
            env=env,
            selector=selector,
        )

        def _identity_mutation(x, r):  # type: ignore[no-untyped-def]
            # Mutation is a pass-through: with variation_percentage=1.0 the
            # GA emitter is configured to rely on `variation_fn` only.
            return (x, r)

        # GA emitter.
        genetic_emitter = MixingEmitter(
            mutation_fn=_identity_mutation,
            variation_fn=variation_fn,
            variation_percentage=1.0,
            batch_size=config.ga_batch_size,
            selector=selector,
        )

        sub_emitters = (quality_emitter, genetic_emitter)
        super().__init__(emitters=sub_emitters)