pyFDN.train_fdn
- pyFDN.train_fdn(model, mode, *, target=None, criteria=None, sparsity_alpha=0.2, mss_nfft=(256, 512, 1024), max_steps=2000, lr=0.001, optimizer='adam', patience=10, tol=1e-06, device=None, dtype=None, rng=None, log=False, train_dir=None)
Train model in place for the given mode and return a TrainLog. Read the trained result back with pyFDN.extract_build().
- Parameters:
model (flamo Shell) – A trainable model from pyFDN.build_fdn() / trainable_from_build.
mode (str) – "colorless", "match_spectrogram" or "match_mel_spectrogram". "colorless" is single-input/single-output only.
target (np.ndarray, optional) – Reference impulse response for the matching modes (unused for "colorless"). Shape (n_samples,) or (n_samples, n_out), or a 3-D (n_samples, n_out, n_in) IR matrix to fit a full MIMO system.
criteria (list of (criterion, alpha, requires_model), optional) – Replace the default loss list (primary loss + sparsity) with your own.
sparsity_alpha (float) – Weight of the feedback-matrix sparsity penalty (default 0.2; 0 disables).
mss_nfft (tuple of int) – STFT window sizes for the spectrogram modes.
max_steps (int) – Maximum number of gradient steps.
lr (float) – Learning rate.
patience (int) – Plateau patience for the early stop.
optimizer (str) – "adam" (default) or "lbfgs".
tol (float) – Relative-improvement threshold for the plateau early stop.
device (optional) – Torch device (default cpu).
dtype (optional) – Torch dtype (default float32).
rng (int or None) – Integer seed for torch.manual_seed.
log (bool) – If True, log and checkpoint to train_dir.
train_dir (str, optional) – Checkpoint directory (used when log=True).
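To see how the three accepted target shapes relate, here is a hypothetical helper (target_dims is not part of pyFDN) that maps a shape tuple to (n_samples, n_out, n_in). It assumes a 1-D target is a single-input, single-output response and a 2-D target is a single-input response with n_out outputs. That is how the shape names read, but the parameter description above does not state it explicitly.

```python
from typing import Tuple


def target_dims(shape: Tuple[int, ...]) -> Tuple[int, int, int]:
    """Map a target shape to (n_samples, n_out, n_in).

    Hypothetical helper, not pyFDN API. Assumes 1-D and 2-D targets describe
    a single input (n_in = 1), and a 1-D target a single output (n_out = 1).
    """
    if len(shape) == 1:
        return (shape[0], 1, 1)
    if len(shape) == 2:
        return (shape[0], shape[1], 1)
    if len(shape) == 3:
        return (shape[0], shape[1], shape[2])
    raise ValueError(f"target must be 1-, 2- or 3-D, got shape {shape}")


print(target_dims((48000,)))       # single IR
print(target_dims((48000, 2)))     # one IR per output channel
print(target_dims((48000, 2, 4)))  # full MIMO IR matrix
```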
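The interaction between criteria and sparsity_alpha can be pictured as a weighted sum of loss terms. The sketch below is pure Python and hypothetical. It assumes each (criterion, alpha, requires_model) entry contributes alpha times the criterion's value, and that requires_model=True means the criterion is also passed the model. The mse primary loss with weight 1.0, the L1 penalty, and the dict-based model with an "A" matrix are stand-ins, not pyFDN's actual spectrogram loss, sparsity measure, or model type.

```python
from typing import Callable, List, Tuple

# One entry per loss term: (criterion, alpha, requires_model), as documented for criteria.
CriterionEntry = Tuple[Callable[..., float], float, bool]


def mse(y_pred, y_true):
    """Toy primary loss: mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true)


def l1_sparsity(y_pred, y_true, model):
    """Toy sparsity penalty: mean absolute entry of the hypothetical model["A"] matrix."""
    entries = [abs(a) for row in model["A"] for a in row]
    return sum(entries) / len(entries)


def default_criteria(sparsity_alpha: float) -> List[CriterionEntry]:
    """Primary loss with weight 1.0, plus the sparsity term unless sparsity_alpha is 0."""
    criteria: List[CriterionEntry] = [(mse, 1.0, False)]
    if sparsity_alpha > 0:
        criteria.append((l1_sparsity, sparsity_alpha, True))
    return criteria


def total_loss(criteria: List[CriterionEntry], y_pred, y_true, model) -> float:
    """Weighted sum of all criteria; model-aware terms also receive the model."""
    loss = 0.0
    for criterion, alpha, requires_model in criteria:
        if requires_model:
            loss += alpha * criterion(y_pred, y_true, model)
        else:
            loss += alpha * criterion(y_pred, y_true)
    return loss


model = {"A": [[0.5, -0.5], [0.0, 1.0]]}  # hypothetical stand-in for a trainable model
print(total_loss(default_criteria(0.2), [1.0, 2.0], [1.0, 1.0], model))  # 0.5 (MSE) + 0.2 * 0.5 (penalty)
print(total_loss(default_criteria(0.0), [1.0, 2.0], [1.0, 1.0], model))  # sparsity disabled
```

Passing a custom criteria list corresponds to bypassing default_criteria entirely and handing total_loss your own entries.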
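The plateau early stop governed by max_steps, patience and tol can be sketched as follows. This is an illustration under assumptions, not pyFDN's code. The parameter list only says tol is a relative-improvement threshold, so the sketch assumes improvement is measured against the best loss seen so far and that patience counts consecutive non-improving steps. The step callable is a hypothetical stand-in for one optimizer step that returns the current loss.

```python
import math
from typing import Callable, List


def run_with_plateau_stop(
    step: Callable[[int], float],
    max_steps: int = 2000,
    patience: int = 10,
    tol: float = 1e-6,
) -> List[float]:
    """Call step(i) up to max_steps times, stopping early on a plateau.

    A step counts as an improvement only if it lowers the best loss so far by
    more than tol relative to that best loss; after `patience` consecutive
    non-improving steps the loop stops.
    """
    best = math.inf
    stalled = 0
    history: List[float] = []
    for i in range(max_steps):
        loss = step(i)
        history.append(loss)
        if math.isinf(best) or (best - loss) > tol * abs(best):
            best = loss
            stalled = 0
        else:
            stalled += 1
            if stalled >= patience:
                break
    return history


# A loss that falls for five steps and then flattens at 0.5.
flat = run_with_plateau_stop(lambda i: max(1.0 - 0.1 * i, 0.5), patience=3)
print(len(flat))  # 6 improving steps, then 3 stalled steps

# A loss that keeps improving never plateaus, so max_steps is the limit.
steady = run_with_plateau_stop(lambda i: 1.0 / (i + 1), max_steps=50)
print(len(steady))
```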
- Return type: