본문 바로가기
Spoken Language Processing

Fairseq - Wav2vec 2.0 Pretraining (3) pretraining 시키기

by 햇농nongnong 2022. 6. 14.

앞 글에 이어 fairseq 의 examples 의 wav2vec2.0 pretraining 글입니다.

 

 

2. wav2vec 2.0 모델 학습시키기

 

Train a wav2vec 2.0 base model
fairseq-hydra-train \
	task.data=/path/to/data \
    --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_base_librispeech

 

  • 위 configuration은 wav2vec 2.0 논문의 Libispeech 데이터 세트에 대해 훈련된 기본 모델
  • 입력은 16000 Hz 로 샘플링된 단일 채널이어야 함
  • 데이터, 모델 파라미터 설정을 위한 config 정보 필요
    --config-dir : config 정보가 있는 폴더
    --config-name : config 에 대한 정보가 key, value 로 저장되어 있는 yaml 파일 이름

wav2vec2_base_librispeech.yaml
common:
  fp16: true
  log_format: json
  log_interval: 200

checkpoint:
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

task:
  _name: audio_pretraining
  data: ???
  max_sample_size: 250000
  min_sample_size: 32000
  normalize: false

dataset:
  num_workers: 6
  max_tokens: 1400000
  skip_invalid_size_inputs_valid_test: true

distributed_training:
  distributed_world_size: 64
  ddp_backend: legacy_ddp

criterion:
  _name: wav2vec
  infonce: true
  log_keys: ["prob_perplexity","code_perplexity","temp"]
  loss_weights: [0.1, 10]

optimization:
  max_update: 400000
  lr: [0.0005]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 32000

model:
  _name: wav2vec2
  quantize_targets: true
  final_dim: 256
  encoder_layerdrop: 0.05
  dropout_input: 0.1
  dropout_features: 0.1
  feature_grad_mult: 0.1
  encoder_embed_dim: 768

 

  • hydra_train.py  이용해 pretraining 진행
hydra_train.py
def cli_main():
    try:
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    hydra_main()
  • cfg_name 은 위 예시에서 'wav2vec2_base_librispeech' (--config-name 의 입력으로 준 파일의 이름을 저장)
  • hydra_init(cfg_name) 을 수행하여 파라미터를 설정
  • hydra_main() 을 통해 training 시작

 

hydra_init 함수
def hydra_init(cfg_name="config") -> None:

    cs = ConfigStore.instance()
    cs.store(name=f"{cfg_name}", node=FairseqConfig)

    for k in FairseqConfig.__dataclass_fields__:
        v = FairseqConfig.__dataclass_fields__[k].default
        try:
            cs.store(name=k, node=v)
        except BaseException:
            logger.error(f"{k} - {v}")
            rais
  • ConfigStore 에 default 를 등록
  • 위 예시로 치면 ConfigStore 에 wav2vec2_base_librispeech 저장
  • 즉, 파라미터 설정하는 단계

 

hydra_train.py
@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
    _hydra_main(cfg)
  • hydra_main() 이 시작되면, 감싸면서 추가기능을 구현하는 데코레이터인 @hydra.main 에 의해 cfg 변수 선언
task:
    _name: audio_pretraining
    data: ???

- 그럼 cfg 에서 아까 data : ??? 의 부분이 처음에 wav2vec_manifest.py 에서 입력했던 manifest path 로 업데이트 됨

 

def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
    add_defaults(cfg)

    if cfg.common.reset_logging:
        reset_logging()  # Hydra hijacks logging, fix that
    else:
        # check if directly called or called through hydra_main
        if HydraConfig.initialized():
            with open_dict(cfg):
                # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
                cfg.job_logging_cfg = OmegaConf.to_container(
                    HydraConfig.get().job_logging, resolve=True
                )

    with omegaconf_no_object_check():
        cfg = OmegaConf.create(
            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
        )
    OmegaConf.set_struct(cfg, True)

    try:
        if cfg.common.profile:
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, pre_main, **kwargs)
        else:
            distributed_utils.call_main(cfg, pre_main, **kwargs)
    except BaseException as e:
        if not cfg.common.suppress_crashes:
            raise
        else:
            logger.error("Crashed! " + str(e))

    # get best val and return - useful for sweepers
    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric
        )
    except:
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val

 

sudo vim wav2vec_train.sh

 

 Reference

댓글