앞 글에 이어 fairseq 의 examples 의 wav2vec2.0 pretraining 글입니다.
2. wav2vec 2.0 모델 학습시키기
Train a wav2vec 2.0 base model
fairseq-hydra-train \
task.data=/path/to/data \
--config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \
--config-name wav2vec2_base_librispeech
- 위 configuration은 wav2vec 2.0 논문의 Libispeech 데이터 세트에 대해 훈련된 기본 모델
- 입력은 16000 Hz 로 샘플링된 단일 채널이어야 함
- 데이터, 모델 파라미터 설정을 위한 config 정보 필요
--config-dir : config 정보가 있는 폴더
--config-name : config 에 대한 정보가 key, value 로 저장되어 있는 yaml 파일 이름
wav2vec2_base_librispeech.yaml
common:
fp16: true
log_format: json
log_interval: 200
checkpoint:
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: audio_pretraining
data: ???
max_sample_size: 250000
min_sample_size: 32000
normalize: false
dataset:
num_workers: 6
max_tokens: 1400000
skip_invalid_size_inputs_valid_test: true
distributed_training:
distributed_world_size: 64
ddp_backend: legacy_ddp
criterion:
_name: wav2vec
infonce: true
log_keys: ["prob_perplexity","code_perplexity","temp"]
loss_weights: [0.1, 10]
optimization:
max_update: 400000
lr: [0.0005]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: polynomial_decay
warmup_updates: 32000
model:
_name: wav2vec2
quantize_targets: true
final_dim: 256
encoder_layerdrop: 0.05
dropout_input: 0.1
dropout_features: 0.1
feature_grad_mult: 0.1
encoder_embed_dim: 768
- hydra_train.py 이용해 pretraining 진행
hydra_train.py
def cli_main():
try:
from hydra._internal.utils import get_args
cfg_name = get_args().config_name or "config"
except:
logger.warning("Failed to get config name from hydra args")
cfg_name = "config"
hydra_init(cfg_name)
hydra_main()
- cfg_name 은 위 예시에서 'wav2vec2_base_librispeech' (--config-name 의 입력으로 준 파일의 이름을 저장)
- hydra_init(cfg_name) 을 수행하여 파라미터를 설정
- hydra_main() 을 통해 training 시작
hydra_init 함수
def hydra_init(cfg_name="config") -> None:
cs = ConfigStore.instance()
cs.store(name=f"{cfg_name}", node=FairseqConfig)
for k in FairseqConfig.__dataclass_fields__:
v = FairseqConfig.__dataclass_fields__[k].default
try:
cs.store(name=k, node=v)
except BaseException:
logger.error(f"{k} - {v}")
rais
- ConfigStore 에 default 를 등록
- 위 예시로 치면 ConfigStore 에 wav2vec2_base_librispeech 저장
- 즉, 파라미터 설정하는 단계
hydra_train.py
@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
_hydra_main(cfg)
- hydra_main() 이 시작되면, 감싸면서 추가기능을 구현하는 데코레이터인 @hydra.main 에 의해 cfg 변수 선언
task:
_name: audio_pretraining
data: ???
- 그럼 cfg 에서 아까 data : ??? 의 부분이 처음에 wav2vec_manifest.py 에서 입력했던 manifest path 로 업데이트 됨
def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
add_defaults(cfg)
if cfg.common.reset_logging:
reset_logging() # Hydra hijacks logging, fix that
else:
# check if directly called or called through hydra_main
if HydraConfig.initialized():
with open_dict(cfg):
# make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
cfg.job_logging_cfg = OmegaConf.to_container(
HydraConfig.get().job_logging, resolve=True
)
with omegaconf_no_object_check():
cfg = OmegaConf.create(
OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
)
OmegaConf.set_struct(cfg, True)
try:
if cfg.common.profile:
with torch.cuda.profiler.profile():
with torch.autograd.profiler.emit_nvtx():
distributed_utils.call_main(cfg, pre_main, **kwargs)
else:
distributed_utils.call_main(cfg, pre_main, **kwargs)
except BaseException as e:
if not cfg.common.suppress_crashes:
raise
else:
logger.error("Crashed! " + str(e))
# get best val and return - useful for sweepers
try:
best_val = metrics.get_smoothed_value(
"valid", cfg.checkpoint.best_checkpoint_metric
)
except:
best_val = None
if best_val is None:
best_val = float("inf")
return best_val
sudo vim wav2vec_train.sh
Reference
'Spoken Language Processing' 카테고리의 다른 글
wav2vec2.0 pretrained 모델로 디코딩하기 (0) | 2022.08.17 |
---|---|
End-to-End ASR : Attention vs RNN-T (0) | 2022.08.04 |
Fairseq - Wav2vec 2.0 Pretraining (2) Preprocess 전처리하기 (1) | 2022.06.14 |
딥러닝으로 음향모델 모델링 (End-to-end algorithm) (0) | 2022.06.03 |
음향모델의 모델링 - ASR 의 acoustic model (0) | 2022.06.03 |
댓글