pyFDN.train_fdn

pyFDN.train_fdn(model, mode, *, target=None, criteria=None, sparsity_alpha=0.2, mss_nfft=(256, 512, 1024), max_steps=2000, lr=0.001, optimizer='adam', patience=10, tol=1e-06, device=None, dtype=None, rng=None, log=False, train_dir=None)

Train model in place for the given mode and return a TrainLog.

Read the trained result back with pyFDN.extract_build().

Parameters:
  • model (flamo Shell) – A trainable model from pyFDN.build_fdn() / trainable_from_build.

  • mode (str) – "colorless", "match_spectrogram" or "match_mel_spectrogram". "colorless" is single-input/single-output only.

  • target (np.ndarray, optional) – Reference impulse response for the matching modes (unused for colorless). Shape (n_samples,) or (n_samples, n_out), or a 3-D (n_samples, n_out, n_in) IR matrix to fit a full MIMO system.

  • criteria (list of (criterion, alpha, requires_model), optional) – Replace the default loss list (primary loss + sparsity) with your own.

  • sparsity_alpha (float) – Weight of the feedback-matrix sparsity penalty (default 0.2; 0 disables).

  • mss_nfft (tuple of int) – STFT window sizes for the spectrogram modes.

  • max_steps (int) – Maximum number of gradient steps (default 2000).

  • lr (float) – Learning rate (default 0.001).

  • patience (int) – Plateau patience for the early stop (default 10); see tol.

  • optimizer (str) – "adam" (default) or "lbfgs".

  • tol (float) – Relative-improvement threshold for the plateau early stop.

  • device (optional) – Torch device (default cpu).

  • dtype (optional) – Torch dtype (default float32).

  • rng (int or None) – Integer seed for torch.manual_seed.

  • log (bool) – If True, write logs and checkpoints to train_dir.

  • train_dir (str, optional) – Checkpoint directory (used when log=True).

Return type:

TrainLog
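The target parameter accepts three layouts that differ only in how many output and input channels they carry. As a minimal sketch (not pyFDN's own code), the hypothetical helper as_ir_matrix below brings all three to a common (n_samples, n_out, n_in) array, assuming that 1-D and 2-D targets describe a single input channel.

```python
import numpy as np


def as_ir_matrix(target: np.ndarray) -> np.ndarray:
    """Hypothetical helper: view a target IR as (n_samples, n_out, n_in)."""
    target = np.asarray(target)
    if target.ndim == 1:
        # (n_samples,) -> one output, one input (SISO)
        return target[:, None, None]
    if target.ndim == 2:
        # (n_samples, n_out) -> assume a single input channel
        return target[:, :, None]
    if target.ndim == 3:
        # Already a full MIMO IR matrix
        return target
    raise ValueError(f"target must be 1-D, 2-D or 3-D, got shape {target.shape}")


# Usage: every accepted layout ends up 3-D
print(as_ir_matrix(np.zeros(48000)).shape)          # (48000, 1, 1)
print(as_ir_matrix(np.zeros((48000, 2))).shape)     # (48000, 2, 1)
print(as_ir_matrix(np.zeros((48000, 2, 2))).shape)  # (48000, 2, 2)
```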
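Each criteria entry is a (criterion, alpha, requires_model) tuple. One plausible reading, sketched below in plain Python, is that the total loss is the alpha-weighted sum of all entries, where entries with requires_model=True are evaluated on the model itself (as a feedback-matrix sparsity penalty would be) and the others compare output with target. This dispatch rule, the combine_criteria name and the toy mse and l1_sparsity functions are illustrative assumptions; pyFDN's actual call convention and sparsity penalty may differ.

```python
from typing import Any, Callable, List, Sequence, Tuple

# (criterion, alpha, requires_model), mirroring the documented tuple layout
Criterion = Tuple[Callable[..., float], float, bool]


def combine_criteria(criteria: Sequence[Criterion], output: Any, target: Any, model: Any) -> float:
    """Hypothetical: alpha-weighted sum of every criterion in the list."""
    total = 0.0
    for criterion, alpha, requires_model in criteria:
        if alpha == 0:
            continue  # a zero weight disables the term entirely
        # Assumption: model-level terms receive the model, the rest compare output vs target
        value = criterion(model) if requires_model else criterion(output, target)
        total += alpha * value
    return total


def mse(output: List[float], target: List[float]) -> float:
    """Toy primary loss: mean squared error between two sequences."""
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(target)


def l1_sparsity(model: dict) -> float:
    """Toy stand-in for a sparsity penalty: mean absolute feedback-matrix entry."""
    entries = [abs(v) for row in model["feedback"] for v in row]
    return sum(entries) / len(entries)


model = {"feedback": [[0.0, 0.5], [-0.5, 0.0]]}
loss = combine_criteria([(mse, 1.0, False), (l1_sparsity, 0.2, True)], [1.0, 2.0], [1.0, 0.0], model)
```

Here the primary term is mse = 2.0 and the penalty is 0.25, so the weighted total is 2.0 + 0.2 * 0.25 = 2.05; setting the penalty weight to 0 leaves only the primary term.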
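mss_nfft lists several STFT sizes so the spectrogram modes compare the model output and the target at more than one time-frequency resolution. The sketch below is a generic multi-resolution magnitude-spectrogram distance in NumPy, assuming a Hann window, a hop of nfft // 4 and a mean absolute difference of magnitudes; pyFDN's actual loss runs in torch and may differ in windowing, hop, normalization and log scaling. The function names stft_mag and multiscale_spectral_loss are hypothetical.

```python
import numpy as np


def stft_mag(x: np.ndarray, nfft: int) -> np.ndarray:
    """Magnitude STFT (frames x bins) with a Hann window and hop nfft // 4."""
    x = np.asarray(x, dtype=float)
    hop = nfft // 4
    window = np.hanning(nfft)
    # Zero-pad so the signal fills a whole number of frames (at least one)
    n_frames = 1 + max(0, int(np.ceil((len(x) - nfft) / hop)))
    padded = np.zeros((n_frames - 1) * hop + nfft)
    padded[: len(x)] = x
    frames = np.stack([padded[i * hop : i * hop + nfft] * window for i in range(n_frames)])
    return np.abs(np.fft.rfft(frames, axis=-1))


def multiscale_spectral_loss(y, target, nffts=(256, 512, 1024)) -> float:
    """Mean L1 distance between magnitude spectrograms, averaged over FFT sizes.

    Assumes y and target have the same length, so their spectrograms align.
    """
    losses = [np.mean(np.abs(stft_mag(y, n) - stft_mag(target, n))) for n in nffts]
    return float(np.mean(losses))


rng = np.random.default_rng(0)
t = rng.standard_normal(4096)
print(multiscale_spectral_loss(t, t))              # 0.0
print(multiscale_spectral_loss(0.5 * t, t) > 0.0)  # True
```

Short windows give finer time resolution and long windows finer frequency resolution, so averaging over several sizes penalizes errors in both the decay envelope and narrow spectral peaks.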
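max_steps, patience and tol together decide when training stops. The plain-Python loop below is one plausible reading of a relative-improvement plateau stop: a step counts as an improvement only if it lowers the best loss so far by more than tol times that best loss, and training ends after patience consecutive steps without one, or at max_steps. The function name train_with_plateau_stop and the exact counting rule are assumptions, not pyFDN's implementation.

```python
from typing import Callable, List


def train_with_plateau_stop(
    step: Callable[[], float],
    max_steps: int = 2000,
    patience: int = 10,
    tol: float = 1e-6,
) -> List[float]:
    """Call step() until max_steps or a plateau; return the loss history."""
    history: List[float] = []
    best = float("inf")
    stale = 0  # consecutive steps without a sufficient relative improvement
    for _ in range(max_steps):
        loss = step()
        history.append(loss)
        # First step always sets the baseline; later steps must beat it by tol * |best|
        if best == float("inf") or (best - loss) > tol * abs(best):
            best = loss
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                break
    return history


# A loss that halves for five steps and then stays flat
losses = iter([1.0, 0.5, 0.25, 0.125, 0.0625] + [0.0625] * 100)
hist = train_with_plateau_stop(lambda: next(losses), patience=3)
print(len(hist))  # 8
```

The history holds five improving steps followed by three stale ones, at which point the patience of 3 is exhausted.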