
Change dtype of prior parameter of sts.Seasonal from float32 to float64 #2003

@CleverEskimo


Hi all,

I have a question about changing the dtype of a prior in sts.Seasonal. I created several components, and their parameters and priors look like this:

Parameter: local_linear_trend/_slope_scale
Prior: tfp.distributions.LogNormal("slope_scale_prior", batch_shape=[], event_shape=[], dtype=float64)
--------------------------------------------------------------------------------------------------------------------------------------------
Parameter: month_of_year/_drift_scale
Prior: tfp.distributions.LogNormal("LogNormal", batch_shape=[], event_shape=[], dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------
Parameter: day_of_week/_drift_scale
Prior: tfp.distributions.LogNormal("LogNormal", batch_shape=[], event_shape=[], dtype=float64)

As you can see, the prior for month_of_year/_drift_scale has dtype float32, while the other priors are float64. This mismatch does not seem to be allowed during training, because it raises the following exception:

ValueError: ConstrainedSeasonalStateSpaceModel, type=<dtype: 'float32'>, must be of the same type (<dtype: 'float64'>) as LocalLinearTrendStateSpaceModel.

Is there a way to change this prior's dtype from float32 to float64?
