batch_axis_names[0]: 'data'
batch_axis_names[1]: 'expert'
batch_axis_names[2]: 'fsdp'
batch_axis_names[3]: 'seq'
checkpointer.gc_loop_interval_seconds: 60
checkpointer.keep_every_n_steps: 50000
checkpointer.keep_last_n: 3
checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_policy'
checkpointer.save_policy.min_step: 1
checkpointer.save_policy.n: 5000
checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
checkpointer.storage.timeout_secs: 3600
evalers['train'].eval_dtype: 'jax.numpy.bfloat16'
evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
evalers['train'].eval_policy.min_step: 1
evalers['train'].eval_policy.n: 5000
evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
evalers['train'].input.batcher.global_batch_size: 32
evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
evalers['train'].input.batcher.prefetch_buffer_size: -1
evalers['train'].input.is_training: False
evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input'
evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
evalers['train'].input.source.dataset_name: 'c4/en:3.0.1'
evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
evalers['train'].input.source.is_training: False
evalers['train'].input.source.max_sequence_length: 2048
evalers['train'].input.source.replace_newlines_with: '\n'
evalers['train'].input.source.split: 'train[:8192]'
evalers['train'].input.source.train_shuffle_buffer_size: 16384
evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler'
evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
evalers['train'].metric_calculator.model_method: 'forward'
evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
evalers['train'].summary_writer.write_every_n_steps: 1
evalers['validation'].eval_dtype: 'jax.numpy.bfloat16'
evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
evalers['validation'].eval_policy.min_step: 1
evalers['validation'].eval_policy.n: 5000
evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
evalers['validation'].input.batcher.global_batch_size: 32
evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
evalers['validation'].input.batcher.prefetch_buffer_size: -1
evalers['validation'].input.is_training: False
evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input'
evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1'
evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
evalers['validation'].input.source.is_training: False
evalers['validation'].input.source.max_sequence_length: 2048
evalers['validation'].input.source.replace_newlines_with: '\n'
evalers['validation'].input.source.split: 'validation'
evalers['validation'].input.source.train_shuffle_buffer_size: 16384
evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler'
evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
evalers['validation'].metric_calculator.model_method: 'forward'
evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
evalers['validation'].summary_writer.write_every_n_steps: 1
input.batcher.fn: 'axlearn.common.input_tf_data.batch'
input.batcher.global_batch_size: 32
input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
input.batcher.prefetch_buffer_size: -1
input.is_training: True
input.klass: 'axlearn.common.input_tf_data.Input'
input.processor.fn: 'axlearn.common.input_tf_data.identity'
input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1'
input.source.data_mixture_components[0]['weight']: 1.0
input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192
input.source.data_mixture_components[0]['split']: 'train'
input.source.data_mixture_components[0]['info']: ''
input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source'
input.source.max_sequence_length: 2048
input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor'
input.source.preprocessor.max_padding_fraction: 0.5
input.source.preprocessor.shuffle_buffer_size: 8192
input.source.preprocessor.window_size: 128
input.source.replace_newlines_with: '<n>'
input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
klass: 'axlearn.common.trainer.SpmdTrainer'
learner.ema.fn: 'axlearn.common.optimizers.param_ema'
learner.enable_per_variable_summaries: False
learner.klass: 'axlearn.common.learner.Learner'
learner.optimizer.args[0].eps: 1e-08
learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm'
learner.optimizer.args[0].max_norm: 1
learner.optimizer.args[1].b1: 0.9
learner.optimizer.args[1].b2: 0.95
learner.optimizer.args[1].eps: 1e-08
learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer'
learner.optimizer.args[1].learning_rate: 0.0003
learner.optimizer.args[1].update_schedule.alpha: 0.1
learner.optimizer.args[1].update_schedule.begin_value: 0.0
learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup'
learner.optimizer.args[1].update_schedule.max_step: 500000
learner.optimizer.args[1].update_schedule.peak_lr: 1.0
learner.optimizer.args[1].update_schedule.warmup_steps: 2000
learner.optimizer.args[1].weight_decay: 0.1
learner.optimizer.fn: 'axlearn.common.optimizers.chain'
max_step: 500000
mesh_axis_names[0]: 'data'
mesh_axis_names[1]: 'expert'
mesh_axis_names[2]: 'fsdp'
mesh_axis_names[3]: 'seq'
mesh_axis_names[4]: 'model'
mesh_rules[0][0]: 'tpu-v4-(1024|2048)'
mesh_rules[0][1][0]: -1
mesh_rules[0][1][1]: 1
mesh_rules[0][1][2]: 16
mesh_rules[0][1][3]: 1
mesh_rules[0][1][4]: 1
mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)'
mesh_rules[1][1][0]: -1
mesh_rules[1][1][1]: 1
mesh_rules[1][1][2]: 8
mesh_rules[1][1][3]: 1
mesh_rules[1][1][4]: 1
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: -1
mesh_shape[3]: 1
mesh_shape[4]: 1
model.batch_axis_names[0]: 'data'
model.batch_axis_names[1]: 'expert'
model.batch_axis_names[2]: 'fsdp'
model.decoder.attention_mask.klass: 'axlearn.common.attention.CausalAttentionLogitBiasLayer'
model.decoder.dim: 4096
model.decoder.dropout_rate: 0.0
model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.decoder.emb.token_emb.param_partition_spec[0]: None
model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
model.decoder.eos_token_id: 1
model.decoder.klass: 'axlearn.common.decoder.Decoder'
model.decoder.logits_partition_spec[0][0]: 'data'
model.decoder.logits_partition_spec[0][1]: 'expert'
model.decoder.logits_partition_spec[0][2]: 'fsdp'
model.decoder.logits_partition_spec[1]: 'seq'
model.decoder.logits_partition_spec[2]: 'model'
model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.output_norm.eps: 1e-05
model.decoder.output_norm.forward_dtype: None
model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
model.decoder.pad_token_id: 0
model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256
model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 2.6666666666666665
model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
model.decoder.transformer.layer.feed_forward.linear1.bias: False
model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1][0]: 'seq'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.feed_forward.linear2.bias: False
model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1][0]: 'seq'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq'
model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05
model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None
model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
model.decoder.transformer.layer.feed_forward.residual_weight: 1.0
model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'MultiheadAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'MultiheadAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'MultiheadAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'MultiheadAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'MultiheadAttention.o_proj'
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedQKVLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None
model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding'
model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0
model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
model.decoder.transformer.layer.self_attention.attention.num_heads: 32
model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
model.decoder.transformer.layer.self_attention.norm.eps: 1e-05
model.decoder.transformer.layer.self_attention.norm.forward_dtype: None
model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm'
model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.self_attention.structure: 'prenorm'
model.decoder.transformer.num_layers: 32
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
model.decoder.vocab_size: 32768
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
model.param_init.init_by_param_name['.*weight$'].scale: 1.0
model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.seq_axis_names: 'seq'
model.z_loss_scale: 0.0
name: 'gpt_trainer'
prune_empty_state_updates: True
save_input_iterator: False
start_trace_process_indices[0]: 0
summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
summary_writer.write_every_n_steps: 100
train_dtype: 'jax.numpy.bfloat16'