decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/output_norm/scale: constant(1.0)
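
The `fan_in` and `fan_out` terms in the entries above can be checked by hand from each entry's `shape` and `FanAxes`. The sketch below is illustrative only and is not the code that produced this listing. The helper `fan_sizes` and the short entry names are hypothetical, and it assumes every dimension is either an input or an output axis (no batch or receptive-field dimensions), which holds for the shapes listed here.

```python
from math import prod


def _as_tuple(axis):
    """Normalize an axis spec (an int or a tuple of ints) to a tuple."""
    return (axis,) if isinstance(axis, int) else tuple(axis)


def fan_sizes(shape, in_axis, out_axis):
    """Return (fan_in, fan_out) as products of the selected dimensions.

    Hypothetical helper: assumes no batch axes and no extra dimensions
    beyond those named by in_axis and out_axis.
    """
    fan_in = prod(shape[a] for a in _as_tuple(in_axis))
    fan_out = prod(shape[a] for a in _as_tuple(out_axis))
    return fan_in, fan_out


# (shape, in_axis, out_axis) copied from the entries in the listing.
entries = {
    "token_emb": ((32, 8), -2, -1),
    "qkv_proj": ((8, 4, 2), 0, (1, 2)),
    "o_proj": ((8, 4, 2), (1, 2), 0),
    "linear1_0": ((8, 32), -2, -1),
    "linear2": ((32, 8), -2, -1),
}

fans = {name: fan_sizes(*spec) for name, spec in entries.items()}

for name, (fan_in, fan_out) in fans.items():
    print(f"{name}: fan_in={fan_in}, fan_out={fan_out}")
```

Under these assumptions, `qkv_proj` and `o_proj` both have fan_in = 8 even though their input axes differ, because the (4, 2) axes multiply to 8. For `token_emb`, the listed initializer uses fan_out, which is 8 for shape [32, 8].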