dsipts.models.d3vae.diffusion_process module

Authors:

Li, Yan (liyan22021121@gmail.com)

dsipts.models.d3vae.diffusion_process.get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps)
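A minimal NumPy sketch of what a helper with this signature typically computes. It assumes the convention common in DDPM/DDIM code: one beta per diffusion step, where "linear" spaces the betas evenly between beta_start and beta_end and "quad" spaces their square roots evenly. The schedule names and formulas here are assumptions, not confirmed for dsipts.

```python
import numpy as np

def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps):
    # Sketch only: names and formulas follow the common DDPM/DDIM convention,
    # not necessarily the dsipts implementation.
    if beta_schedule == "linear":
        # Evenly spaced betas from beta_start to beta_end (both endpoints included).
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "quad":
        # Evenly spaced square roots, then squared.
        betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5,
                            num_diffusion_timesteps, dtype=np.float64) ** 2
    elif beta_schedule == "const":
        # The same beta at every step.
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas

# Class defaults of GaussianDiffusion: beta_start=0, beta_end=0.1, diff_steps=100.
betas = get_beta_schedule("linear", 0.0, 0.1, 100)
```

With these defaults the linear schedule yields 100 increasing betas from 0 to 0.1.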
dsipts.models.d3vae.diffusion_process.default(val, d)
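The name and arguments match a small fallback helper that is common in diffusion codebases. A hypothetical sketch, assuming it returns val when it is set and otherwise falls back to d, calling d first when it is a callable:

```python
def default(val, d):
    # Hypothetical sketch: keep val unless it is None; otherwise use d.
    # Calling d lazily lets callers defer building an expensive default.
    if val is not None:
        return val
    return d() if callable(d) else d
```

A typical use would be `noise = default(noise, lambda: <draw fresh noise>)`, so that a noise argument left as None gets filled in on demand.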
dsipts.models.d3vae.diffusion_process.extract(a, t, x_shape)
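In DDPM-style code, a helper named extract usually picks one schedule coefficient per batch element (a[t]) and reshapes it so it broadcasts against a tensor of shape x_shape. A sketch under that assumption, using NumPy as a stand-in for the torch tensors the module actually works on:

```python
import numpy as np

def extract(a, t, x_shape):
    # Sketch: one coefficient per sample, reshaped to [B, 1, 1, ...]
    # so it broadcasts over an array of shape x_shape = [B, T, *].
    b = t.shape[0]
    out = a[t]
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

coeffs = np.linspace(0.0, 1.0, 11)   # e.g. a per-step schedule
t = np.array([0, 5, 10])             # one step index per sample, shape [B,]
c = extract(coeffs, t, (3, 24, 4))  # shape (3, 1, 1)
```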
dsipts.models.d3vae.diffusion_process.noise_like(shape, device, repeat=False)
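Judging from the name and the repeat flag, this helper draws standard Gaussian noise of the given shape, optionally sharing one sample across the batch. A hypothetical NumPy sketch under that assumption; here a NumPy random Generator (rng) replaces the torch device argument:

```python
import numpy as np

def noise_like(shape, rng, repeat=False):
    # Hypothetical sketch: `rng` stands in for the torch `device` argument.
    if repeat:
        # Draw a single sample and repeat it along the batch dimension.
        first = rng.standard_normal((1, *shape[1:]))
        return np.repeat(first, shape[0], axis=0)
    # Independent noise for every element.
    return rng.standard_normal(shape)

rng = np.random.default_rng(0)
shared = noise_like((4, 3, 2), rng, repeat=True)
independent = noise_like((4, 3, 2), rng)
```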
class dsipts.models.d3vae.diffusion_process.GaussianDiffusion(bvae, input_size, beta_start=0, beta_end=0.1, diff_steps=100, loss_type='l2', betas=None, scale=0.1, beta_schedule='linear')

Bases: Module

Params:

bvae: The bidirectional VAE model.
beta_start: The start value of the beta schedule.
beta_end: The end value of the beta schedule.
beta_schedule: The kind of beta schedule. It is set to linear here; adjust it as needed.
diff_steps: The maximum number of diffusion steps.
scale: The scale parameter for the target time series.

__init__(bvae, input_size, beta_start=0, beta_end=0.1, diff_steps=100, loss_type='l2', betas=None, scale=0.1, beta_schedule='linear')

q_sample(x_start, t, noise=None)
Diffuse the initial input.

param x_start: [B, T, *]

return: [B, T, *]
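A minimal NumPy sketch, assuming q_sample performs the standard DDPM forward step x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, where alpha_bar_t is the cumulative product of (1 - beta) up to step t. Unlike the method, this function receives the cumulative alphas explicitly instead of reading them from the module; that argument and the rng parameter are assumptions for illustration.

```python
import numpy as np

def q_sample(x_start, t, alphas_cumprod, noise=None, rng=None):
    # Sketch of the standard DDPM forward step:
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    if noise is None:
        rng = rng if rng is not None else np.random.default_rng()
        noise = rng.standard_normal(x_start.shape)
    # Per-sample coefficients reshaped to [B, 1, 1, ...] to broadcast over [B, T, *].
    shape = (t.shape[0],) + (1,) * (x_start.ndim - 1)
    sqrt_ab = np.sqrt(alphas_cumprod[t]).reshape(shape)
    sqrt_one_minus_ab = np.sqrt(1.0 - alphas_cumprod[t]).reshape(shape)
    return sqrt_ab * x_start + sqrt_one_minus_ab * noise

betas = np.linspace(0.0, 0.1, 100)        # linear schedule with the class defaults
alphas_cumprod = np.cumprod(1.0 - betas)  # alpha_bar_t
x_start = np.ones((2, 24, 3))             # [B, T, *]
t = np.array([0, 99])                     # one step index per sample
x_t = q_sample(x_start, t, alphas_cumprod, noise=np.zeros_like(x_start))
```

With zero noise, the sample at step 0 is unchanged (beta_0 = 0), while the sample at the last step is shrunk toward zero.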

q_sample_target(y_target, t, noise=None)
Diffuse the target.

param y_target: [B1, T1, *]

return: (tensor) [B1, T1, *]
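The class takes a separate scale parameter for the target time series. A NumPy sketch, assuming (as in the coupled diffusion of the D3VAE paper) that the target is diffused with the input's betas multiplied by scale, using the same DDPM forward step as the input. The helper target_alphas_cumprod is hypothetical.

```python
import numpy as np

def target_alphas_cumprod(betas, scale):
    # Assumption: the target schedule is the input schedule scaled by `scale`.
    return np.cumprod(1.0 - scale * betas)

def q_sample_target(y_target, t, alphas_cumprod_target, noise):
    # Same forward step as for the input, but with the target's schedule.
    shape = (t.shape[0],) + (1,) * (y_target.ndim - 1)
    a = alphas_cumprod_target[t].reshape(shape)
    return np.sqrt(a) * y_target + np.sqrt(1.0 - a) * noise

betas = np.linspace(0.0, 0.1, 100)
ab_input = np.cumprod(1.0 - betas)
ab_target = target_alphas_cumprod(betas, scale=0.1)  # class default scale=0.1
y_target = np.ones((2, 12, 1))                       # [B1, T1, *]
t = np.array([50, 99])
y_t = q_sample_target(y_target, t, ab_target, noise=np.zeros_like(y_target))
```

With scale < 1, the target keeps more of its signal than the input at the same step t, because every factor (1 - scale * beta) is at least (1 - beta).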

p_losses(x_start, y_target, t, noise=None, noise1=None)

Diffuse the input and the target, then put the diffused input into the BVAE to generate the output.

param x_start: [B, T, *]

param y_target: [B1, T1, *]

param t: [B,]

return y_noisy: the diffused target.

return total_c: the total correlations of the latent variables in the BVAE.

return all_z: all latent variables of the BVAE.
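A hypothetical NumPy outline of this step. It assumes both series are diffused at the same per-sample steps t, each with its own cumulative-alpha schedule (the target schedule again assumed to be the input betas times scale), and that only the diffused input goes through the BVAE. stub_bvae is a placeholder, not the real BVAE, and the returned tuple order is illustrative.

```python
import numpy as np

def diffuse(x, t, alphas_cumprod, rng):
    # Forward diffusion at per-sample steps t, broadcast over [B, T, *].
    shape = (t.shape[0],) + (1,) * (x.ndim - 1)
    a = alphas_cumprod[t].reshape(shape)
    return np.sqrt(a) * x + np.sqrt(1.0 - a) * rng.standard_normal(x.shape)

def p_losses_sketch(bvae, x_start, y_target, t, ab_input, ab_target, rng):
    # Hypothetical outline: shared step indices t ([B,]), separate schedules;
    # only the diffused input is fed to the BVAE.
    x_noisy = diffuse(x_start, t, ab_input, rng)
    y_noisy = diffuse(y_target, t, ab_target, rng)
    output, total_c, all_z = bvae(x_noisy)
    return output, y_noisy, total_c, all_z

def stub_bvae(x_noisy, horizon=12):
    # Placeholder for the BVAE: the last `horizon` steps of the first channel as
    # "output", zero total correlation, and one "latent" array of shape [B, channels].
    z = x_noisy.mean(axis=1)
    return x_noisy[:, -horizon:, :1], 0.0, [z]

rng = np.random.default_rng(0)
betas = np.linspace(0.0, 0.1, 100)
ab_input = np.cumprod(1.0 - betas)
ab_target = np.cumprod(1.0 - 0.1 * betas)   # assumed scale=0.1 coupling
x_start = rng.standard_normal((2, 24, 3))   # [B, T, *]
y_target = rng.standard_normal((2, 12, 1))  # [B1, T1, *]
t = rng.integers(0, 100, size=2)            # [B,]
output, y_noisy, total_c, all_z = p_losses_sketch(
    stub_bvae, x_start, y_target, t, ab_input, ab_target, rng)
```

In training, the BVAE output would then be compared against y_noisy, with total_c entering the objective as a disentanglement term; the exact loss is not shown here.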

log_prob(x_input, y_target, time)