class DCRLMEEmitter(MultiEmitter):
    """Multi-emitter that pairs a ``DCRLEmitter`` with a GA ``MixingEmitter``.

    The emitters are handed to ``MultiEmitter`` in the order
    ``(dcrl_emitter, ga_emitter)``.
    """

    def __init__(
        self,
        config: DCRLMEConfig,
        policy_network: nn.Module,
        actor_network: nn.Module,
        env: QDEnv,
        variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]],
        selector: Optional[Selector] = None,
    ) -> None:
        """Build both sub-emitters and register them with ``MultiEmitter``.

        Args:
            config: Combined configuration. Its DCRL-related fields are copied
                into a ``DCRLConfig``, and ``ga_batch_size`` sets the batch
                size of the GA emitter.
            policy_network: Network forwarded to the ``DCRLEmitter``.
            actor_network: Network forwarded to the ``DCRLEmitter``.
            env: Environment forwarded to the ``DCRLEmitter``; also kept on
                ``self._env``.
            variation_fn: Variation operator forwarded to the GA
                ``MixingEmitter``; also kept on ``self._variation_fn``.
            selector: Optional selector given to both sub-emitters.
        """
        # Keep references to the constructor inputs on the instance.
        self._config = config
        self._env = env
        self._variation_fn = variation_fn

        # Each DCRLConfig field is copied verbatim from the same-named
        # field of the combined config (same names, same order).
        dcrl_field_names = (
            "dcrl_batch_size",
            "ai_batch_size",
            "lengthscale",
            "critic_hidden_layer_size",
            "num_critic_training_steps",
            "num_pg_training_steps",
            "batch_size",
            "replay_buffer_size",
            "discount",
            "reward_scaling",
            "critic_learning_rate",
            "actor_learning_rate",
            "policy_learning_rate",
            "noise_clip",
            "policy_noise",
            "soft_tau_update",
            "policy_delay",
        )
        dcrl_kwargs = {name: getattr(config, name) for name in dcrl_field_names}
        dcrl_config = DCRLConfig(**dcrl_kwargs)

        # Quality emitter: driven by the DCRL configuration and networks.
        quality_emitter = DCRLEmitter(
            config=dcrl_config,
            policy_network=policy_network,
            actor_network=actor_network,
            env=env,
            selector=selector,
        )

        def passthrough_mutation(x, r):
            # Returns its inputs unchanged. With variation_percentage=1.0
            # below, presumably MixingEmitter never routes offspring through
            # mutation_fn — confirm against MixingEmitter.
            return x, r

        # GA emitter: offspring produced by the user-supplied variation_fn.
        genetic_emitter = MixingEmitter(
            mutation_fn=passthrough_mutation,
            variation_fn=variation_fn,
            variation_percentage=1.0,
            batch_size=config.ga_batch_size,
            selector=selector,
        )

        super().__init__(emitters=(quality_emitter, genetic_emitter))