batch_axis_names: 'data'
checkpointer.gc_loop_interval_seconds: 60
checkpointer.keep_every_n_steps: 12510
checkpointer.keep_last_n: 1
checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_policy'
checkpointer.save_policy.min_step: 1
checkpointer.save_policy.n: 12510
checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
checkpointer.storage.timeout_secs: 3600
evalers['eval_train'].eval_dtype: 'jax.numpy.bfloat16'
evalers['eval_train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
evalers['eval_train'].eval_policy.min_step: 1
evalers['eval_train'].eval_policy.n: 12510
evalers['eval_train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
evalers['eval_train'].input.batcher.global_batch_size: 80
evalers['eval_train'].input.batcher.pad_example_fn: 'axlearn.vision.input_image.pad_with_negative_labels'
evalers['eval_train'].input.klass: 'axlearn.vision.input_image.ImagenetInput'
evalers['eval_train'].input.processor.fn: 'axlearn.vision.input_image._process_example'
evalers['eval_train'].input.processor.image_size[0]: 224
evalers['eval_train'].input.processor.image_size[1]: 224
evalers['eval_train'].input.processor.input_key: 'image'
evalers['eval_train'].input.processor.num_parallel_calls: 1024
evalers['eval_train'].input.processor.randaug_magnitude: 10
evalers['eval_train'].input.processor.randaug_num_layers: 2
evalers['eval_train'].input.source.dataset_name: 'imagenet2012'
evalers['eval_train'].input.source.download: False
evalers['eval_train'].input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset'
evalers['eval_train'].input.source.read_config.decode_parallelism: 128
evalers['eval_train'].input.source.read_config.fn: 'axlearn.common.input_tf_data.tfds_read_config'
evalers['eval_train'].input.source.read_config.read_parallelism: 1
evalers['eval_train'].input.source.split: 'train[:50000]'
evalers['eval_train'].input.source.train_shuffle_buffer_size: 0
evalers['eval_train'].input.source.train_shuffle_files: False
evalers['eval_train'].klass: 'axlearn.common.evaler.SpmdEvaler'
evalers['eval_train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
evalers['eval_train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
evalers['eval_train'].metric_calculator.model_method: 'forward'
evalers['eval_train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
evalers['eval_train'].summary_writer.write_every_n_steps: 1
evalers['eval_validation'].eval_dtype: 'jax.numpy.bfloat16'
evalers['eval_validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
evalers['eval_validation'].eval_policy.min_step: 1
evalers['eval_validation'].eval_policy.n: 12510
evalers['eval_validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
evalers['eval_validation'].input.batcher.global_batch_size: 80
evalers['eval_validation'].input.batcher.pad_example_fn: 'axlearn.vision.input_image.pad_with_negative_labels'
evalers['eval_validation'].input.klass: 'axlearn.vision.input_image.ImagenetInput'
evalers['eval_validation'].input.processor.fn: 'axlearn.vision.input_image._process_example'
evalers['eval_validation'].input.processor.image_size[0]: 224
evalers['eval_validation'].input.processor.image_size[1]: 224
evalers['eval_validation'].input.processor.input_key: 'image'
evalers['eval_validation'].input.processor.num_parallel_calls: 1024
evalers['eval_validation'].input.processor.randaug_magnitude: 10
evalers['eval_validation'].input.processor.randaug_num_layers: 2
evalers['eval_validation'].input.source.dataset_name: 'imagenet2012'
evalers['eval_validation'].input.source.download: False
evalers['eval_validation'].input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset'
evalers['eval_validation'].input.source.read_config.decode_parallelism: 128
evalers['eval_validation'].input.source.read_config.fn: 'axlearn.common.input_tf_data.tfds_read_config'
evalers['eval_validation'].input.source.read_config.read_parallelism: 1
evalers['eval_validation'].input.source.split: 'validation'
evalers['eval_validation'].input.source.train_shuffle_buffer_size: 0
evalers['eval_validation'].input.source.train_shuffle_files: False
evalers['eval_validation'].klass: 'axlearn.common.evaler.SpmdEvaler'
evalers['eval_validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
evalers['eval_validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
evalers['eval_validation'].metric_calculator.model_method: 'forward'
evalers['eval_validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
evalers['eval_validation'].summary_writer.write_every_n_steps: 1
input.batcher.fn: 'axlearn.common.input_tf_data.batch'
input.batcher.global_batch_size: 1024
input.batcher.pad_example_fn: 'axlearn.vision.input_image.pad_with_negative_labels'
input.klass: 'axlearn.vision.input_image.ImagenetInput'
input.processor.fn: 'axlearn.vision.input_image._process_example'
input.processor.image_size[0]: 224
input.processor.image_size[1]: 224
input.processor.input_key: 'image'
input.processor.num_parallel_calls: 1024
input.processor.randaug_magnitude: 10
input.processor.randaug_num_layers: 2
input.source.dataset_name: 'imagenet2012'
input.source.download: False
input.source.fn: 'axlearn.common.input_tf_data.tfds_dataset'
input.source.read_config.decode_parallelism: 128
input.source.read_config.fn: 'axlearn.common.input_tf_data.tfds_read_config'
input.source.read_config.read_parallelism: 1
input.source.split: 'train'
input.source.train_shuffle_buffer_size: 1024
klass: 'axlearn.common.trainer.SpmdTrainer'
learner.ema.fn: 'axlearn.common.optimizers.param_ema'
learner.enable_per_variable_summaries: False
learner.klass: 'axlearn.common.learner.Learner'
learner.optimizer.decouple_weight_decay: False
learner.optimizer.fn: 'axlearn.common.optimizers.sgd_optimizer'
learner.optimizer.learning_rate.alpha: 0.0
learner.optimizer.learning_rate.begin_value: 0.0
learner.optimizer.learning_rate.fn: 'axlearn.common.schedule.cosine_with_linear_warmup'
learner.optimizer.learning_rate.max_step: 112590
learner.optimizer.learning_rate.peak_lr: 0.4
learner.optimizer.learning_rate.warmup_steps: 6255
learner.optimizer.momentum: 0.9
learner.optimizer.weight_decay: 0.0001
learner.optimizer.weight_decay_per_param_scale.default_scale: 1.0
learner.optimizer.weight_decay_per_param_scale.description: 'weight_decay_scale'
learner.optimizer.weight_decay_per_param_scale.fn: 'axlearn.common.optimizers.per_param_scale_by_path'
learner.optimizer.weight_decay_per_param_scale.scale_by_path[0][0]: '.*norm.*'
learner.optimizer.weight_decay_per_param_scale.scale_by_path[0][1]: 0
max_step: 112590
model.backbone.dtype: 'jax.numpy.float32'
model.backbone.hidden_dim: 64
model.backbone.klass: 'axlearn.vision.resnet.ResNet'
model.backbone.num_blocks_per_stage[0]: 3
model.backbone.num_blocks_per_stage[1]: 4
model.backbone.num_blocks_per_stage[2]: 6
model.backbone.num_blocks_per_stage[3]: 3
model.backbone.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
model.backbone.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
model.backbone.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730951
model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.backbone.stage.block.activation: 'nn.relu'
model.backbone.stage.block.conv.bias: False
model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stage.block.conv.num_input_dim_groups: 1
model.backbone.stage.block.conv.padding[0][0]: 1
model.backbone.stage.block.conv.padding[0][1]: 1
model.backbone.stage.block.conv.padding[1][0]: 1
model.backbone.stage.block.conv.padding[1][1]: 1
model.backbone.stage.block.conv.param_partition_spec[0]: None
model.backbone.stage.block.conv.param_partition_spec[1]: None
model.backbone.stage.block.conv.param_partition_spec[2]: None
model.backbone.stage.block.conv.param_partition_spec[3]: 'model'
model.backbone.stage.block.conv.strides[0]: 1
model.backbone.stage.block.conv.strides[1]: 1
model.backbone.stage.block.conv.window[0]: 3
model.backbone.stage.block.conv.window[1]: 3
model.backbone.stage.block.downsample.downsample_op: 'conv'
model.backbone.stage.block.downsample.klass: 'axlearn.vision.resnet.Downsample'
model.backbone.stage.block.downsample.norm.decay: 0.9
model.backbone.stage.block.downsample.norm.eps: 1e-05
model.backbone.stage.block.downsample.norm.forward_dtype: 'jax.numpy.float32'
model.backbone.stage.block.downsample.norm.klass: 'axlearn.common.layers.BatchNorm'
model.backbone.stage.block.downsample.param_partition_spec[0]: None
model.backbone.stage.block.downsample.param_partition_spec[1]: None
model.backbone.stage.block.downsample.param_partition_spec[2]: None
model.backbone.stage.block.downsample.param_partition_spec[3]: 'model'
model.backbone.stage.block.downsample.stride: 1
model.backbone.stage.block.klass: 'axlearn.vision.resnet.Bottleneck'
model.backbone.stage.block.norm.decay: 0.9
model.backbone.stage.block.norm.eps: 1e-05
model.backbone.stage.block.norm.forward_dtype: 'jax.numpy.float32'
model.backbone.stage.block.norm.klass: 'axlearn.common.layers.BatchNorm'
model.backbone.stage.block.squeeze_excitation.activation: 'nn.relu'
model.backbone.stage.block.squeeze_excitation.gating: 'nn.sigmoid'
model.backbone.stage.block.squeeze_excitation.klass: 'axlearn.common.layers.SqueezeExcitation'
model.backbone.stage.block.squeeze_excitation.param_partition_spec[0]: None
model.backbone.stage.block.squeeze_excitation.param_partition_spec[1]: None
model.backbone.stage.block.squeeze_excitation.param_partition_spec[2]: None
model.backbone.stage.block.squeeze_excitation.param_partition_spec[3]: 'model'
model.backbone.stage.block.squeeze_excitation.se_ratio: 0.0
model.backbone.stage.block.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
model.backbone.stage.block.stochastic_depth.mode: 'row'
model.backbone.stage.block.stride: 1
model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
model.backbone.stage.stride: 1
model.backbone.stem.activation: 'nn.relu'
model.backbone.stem.conv.bias: False
model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stem.conv.num_input_dim_groups: 1
model.backbone.stem.conv.padding[0][0]: 3
model.backbone.stem.conv.padding[0][1]: 3
model.backbone.stem.conv.padding[1][0]: 3
model.backbone.stem.conv.padding[1][1]: 3
model.backbone.stem.conv.param_partition_spec[0]: None
model.backbone.stem.conv.param_partition_spec[1]: None
model.backbone.stem.conv.param_partition_spec[2]: None
model.backbone.stem.conv.param_partition_spec[3]: 'model'
model.backbone.stem.conv.strides[0]: 2
model.backbone.stem.conv.strides[1]: 2
model.backbone.stem.conv.window[0]: 7
model.backbone.stem.conv.window[1]: 7
model.backbone.stem.input_dim: 3
model.backbone.stem.klass: 'axlearn.vision.resnet.StemV0'
model.backbone.stem.norm.decay: 0.9
model.backbone.stem.norm.eps: 1e-05
model.backbone.stem.norm.forward_dtype: 'jax.numpy.float32'
model.backbone.stem.norm.klass: 'axlearn.common.layers.BatchNorm'
model.classifier.bias: True
model.classifier.klass: 'axlearn.common.layers.Linear'
model.classifier.param_partition_spec[0]: 'model'
model.classifier.param_partition_spec[1]: None
model.dropout.klass: 'axlearn.common.layers.Dropout'
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.vision.image_classification.ImageClassificationModel'
model.metric.klass: 'axlearn.common.layers.ClassificationMetric'
model.metric.label_smoothing: 0.0
model.num_classes: 1000
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
model.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
model.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730951
model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
name: 'resnet_imagenet'
prune_empty_state_updates: True
save_input_iterator: False
start_trace_process_indices[0]: 0
summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
summary_writer.write_every_n_steps: 100
train_dtype: 'jax.numpy.bfloat16'