{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Sampling from and decoding an HMM\n\nThis script shows how to sample points from a Hidden Markov Model (HMM):\nwe use a 4-state model with specified means and covariances.\n\nThe plot shows the sequence of generated observations, with the transitions\nbetween them. As specified by our transition matrix, there are no\ntransitions between components 1 and 3 or between components 2 and 4:\nthe path runs along the sides of the roughly square shape formed by the\nfour means but never along its diagonals.\n\nThen, we fit new models to these samples to recover the input parameters\nand decode the most likely sequence of hidden states.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nimport matplotlib.pyplot as plt\n\nfrom hmmlearn import hmm\n\n# Prepare parameters for a 4-component HMM\n# Initial state probabilities\nstartprob = np.array([0.6, 0.3, 0.1, 0.0])\n# The transition matrix; note that no transitions are possible\n# between components 1 and 3, or between components 2 and 4\ntransmat = np.array([[0.7, 0.2, 0.0, 0.1],\n                     [0.3, 0.5, 0.2, 0.0],\n                     [0.0, 0.3, 0.5, 0.2],\n                     [0.2, 0.0, 0.2, 0.6]])\n# The means of each component\nmeans = np.array([[0.0, 0.0],\n                  [0.0, 11.0],\n                  [9.0, 10.0],\n                  [11.0, -1.0]])\n# The covariance of each component (isotropic, variance 0.5)\ncovars = .5 * np.tile(np.identity(2), (4, 1, 1))\n\n# Build an HMM instance and set parameters\ngen_model = hmm.GaussianHMM(n_components=4, covariance_type=\"full\")\n\n# Instead of fitting the model to data, we directly set its parameters:\n# start probabilities, transition matrix, means and covariances\ngen_model.startprob_ = startprob\ngen_model.transmat_ = transmat\ngen_model.means_ = means\ngen_model.covars_ = covars\n\n# Generate 500 samples: X holds the observations, Z the hidden states\nX, Z = gen_model.sample(500)\n\n# Plot the sampled data\nfig, ax = plt.subplots()\nax.plot(X[:, 0], X[:, 1], \".-\", label=\"observations\", ms=6,\n        mfc=\"orange\", alpha=0.7)\n\n# Label each component at its mean\nfor i, m in enumerate(means):\n    ax.text(m[0], m[1], f'Component {i + 1}',\n            size=17, horizontalalignment='center',\n            bbox=dict(alpha=.7, facecolor='w'))\nax.legend(loc='best')\nfig.show()"
      ]
    },
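    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For intuition, `sample` draws from an HMM by ancestral sampling: it picks the first hidden state from the start probabilities, then repeatedly draws the next state from the current state's row of the transition matrix and emits an observation from that state's Gaussian. The next cell is a minimal NumPy sketch of that procedure under those assumptions, not hmmlearn's own implementation; the helper name `sample_hmm` and the seed are hypothetical. It also counts the empirical transitions, which should be zero wherever the transition matrix has a zero entry (between components 1 and 3, and between components 2 and 4).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef sample_hmm(startprob, transmat, means, covars, n_samples, seed=0):\n    # Hypothetical helper: ancestral sampling from a Gaussian HMM\n    rng = np.random.default_rng(seed)\n    n_states = len(startprob)\n    states = np.empty(n_samples, dtype=int)\n    obs = np.empty((n_samples, means.shape[1]))\n    for t in range(n_samples):\n        # First state from startprob; later states from the transition-matrix\n        # row of the previous state\n        p = startprob if t == 0 else transmat[states[t - 1]]\n        states[t] = rng.choice(n_states, p=p)\n        # Emit an observation from the current state's Gaussian\n        obs[t] = rng.multivariate_normal(means[states[t]], covars[states[t]])\n    return obs, states\n\n\n# Same parameter values as the generating model above\nstartprob = np.array([0.6, 0.3, 0.1, 0.0])\ntransmat = np.array([[0.7, 0.2, 0.0, 0.1],\n                     [0.3, 0.5, 0.2, 0.0],\n                     [0.0, 0.3, 0.5, 0.2],\n                     [0.2, 0.0, 0.2, 0.6]])\nmeans = np.array([[0.0, 0.0], [0.0, 11.0], [9.0, 10.0], [11.0, -1.0]])\ncovars = .5 * np.tile(np.identity(2), (4, 1, 1))\n\nX_sketch, Z_sketch = sample_hmm(startprob, transmat, means, covars, 500)\n\n# counts[i, j] = number of observed steps from state i to state j\n# (0-based, so component 1 is state 0)\ncounts = np.zeros((4, 4), dtype=int)\nnp.add.at(counts, (Z_sketch[:-1], Z_sketch[1:]), 1)\nprint(counts)"
      ]
    },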
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
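        "Now, let's check whether we can recover our parameters. We fit models with\n3, 4 and 5 states, each from 10 different random initializations, on the\nfirst half of the data, and keep the one that scores highest on the\nheld-out second half.\n\n"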
      ]
    },
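    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model selection below ranks candidate models by `score`, the log-likelihood of the held-out half of the data under each fitted model; higher is better, which is why the best model is picked with `np.argmax`. For intuition, the next cell is a minimal NumPy sketch of the forward algorithm, computed in log space, that yields such a log-likelihood for a Gaussian HMM with full covariances. It is not hmmlearn's implementation; the helpers `gaussian_logpdf`, `logsumexp` and `forward_loglik` and the short test sequence are hypothetical, and the result is checked against an explicit sum over every hidden path of a 3-step sequence.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import itertools\n\nimport numpy as np\n\n\ndef gaussian_logpdf(X, mean, cov):\n    # Log density of each row of X under a multivariate normal N(mean, cov)\n    d = X.shape[1]\n    diff = X - mean\n    # Mahalanobis term via a linear solve instead of an explicit inverse\n    maha = np.sum(diff * np.linalg.solve(cov, diff.T).T, axis=1)\n    _, logdet = np.linalg.slogdet(cov)\n    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)\n\n\ndef logsumexp(a, axis):\n    # Stable log(sum(exp(a))) along an axis; an all -inf slice gives -inf\n    m = np.max(a, axis=axis, keepdims=True)\n    m = np.where(np.isfinite(m), m, 0.0)\n    with np.errstate(divide='ignore'):\n        out = np.log(np.sum(np.exp(a - m), axis=axis))\n    return out + np.squeeze(m, axis=axis)\n\n\ndef forward_loglik(X, startprob, transmat, means, covars):\n    # log_b[t, k] = log p(x_t | hidden state k)\n    log_b = np.column_stack([gaussian_logpdf(X, means[k], covars[k])\n                             for k in range(len(startprob))])\n    with np.errstate(divide='ignore'):  # log(0) = -inf marks impossible moves\n        log_pi, log_A = np.log(startprob), np.log(transmat)\n    # log_alpha[k] = log p(x_1, ..., x_t, state_t = k)\n    log_alpha = log_pi + log_b[0]\n    for t in range(1, len(X)):\n        log_alpha = logsumexp(log_alpha[:, None] + log_A, axis=0) + log_b[t]\n    # Marginalize over the final state\n    return logsumexp(log_alpha, axis=0)\n\n\n# Same parameter values as the generating model above\nstartprob = np.array([0.6, 0.3, 0.1, 0.0])\ntransmat = np.array([[0.7, 0.2, 0.0, 0.1],\n                     [0.3, 0.5, 0.2, 0.0],\n                     [0.0, 0.3, 0.5, 0.2],\n                     [0.2, 0.0, 0.2, 0.6]])\nmeans = np.array([[0.0, 0.0], [0.0, 11.0], [9.0, 10.0], [11.0, -1.0]])\ncovars = .5 * np.tile(np.identity(2), (4, 1, 1))\n\n# A hypothetical 3-step sequence close to components 1, 2 and 3\nX_short = means[[0, 1, 2]] + 0.1\nfwd = forward_loglik(X_short, startprob, transmat, means, covars)\n\n# Brute force: sum p(path) * p(X | path) over all 4**3 hidden paths\nlik = np.exp(np.column_stack([gaussian_logpdf(X_short, means[k], covars[k])\n                              for k in range(4)]))\nbrute = 0.0\nfor path in itertools.product(range(4), repeat=len(X_short)):\n    p = startprob[path[0]] * lik[0, path[0]]\n    for t in range(1, len(X_short)):\n        p *= transmat[path[t - 1], path[t]] * lik[t, path[t]]\n    brute += p\nprint(fwd, np.log(brute))"
      ]
    },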
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "scores = []\nmodels = []\nfor n_components in (3, 4, 5):\n    for idx in range(10):\n        # define a hidden Markov model; random_state varies the initialization\n        model = hmm.GaussianHMM(n_components=n_components,\n                                covariance_type='full',\n                                random_state=idx)\n        model.fit(X[:X.shape[0] // 2])  # train on the first half\n        models.append(model)\n        # validate on the second half (log-likelihood, higher is better)\n        scores.append(model.score(X[X.shape[0] // 2:]))\n        print(f'Converged: {model.monitor_.converged}'\n              f'\\tScore: {scores[-1]}')\n\n# get the best model\nmodel = models[np.argmax(scores)]\nn_states = model.n_components\nprint(f'The best model had a score of {max(scores)} and {n_states} '\n      'states')\n\n# use the Viterbi algorithm to predict the most likely sequence of states\n# given the model\nstates = model.predict(X)"
      ]
    },
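    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`predict` returns the single most likely sequence of hidden states; by default hmmlearn computes it with the Viterbi algorithm. The next cell is a minimal log-space NumPy sketch of Viterbi decoding, assuming the per-step observation log-likelihoods are already available as an array `log_b` of shape (n_samples, n_states). The helper `viterbi` and the toy two-state numbers are hypothetical. Unlike the forward algorithm, which sums over paths, Viterbi keeps only the best path ending in each state plus a back-pointer, then traces the pointers backwards; an exhaustive search over all 16 paths of the toy problem confirms the result.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import itertools\n\nimport numpy as np\n\n\ndef viterbi(log_pi, log_A, log_b):\n    # Most likely state path given log start probabilities log_pi, log\n    # transition matrix log_A and observation log-likelihoods log_b[t, k]\n    n_samples, n_states = log_b.shape\n    # delta[k] = best log-probability of any path ending in state k\n    delta = log_pi + log_b[0]\n    backptr = np.zeros((n_samples, n_states), dtype=int)\n    for t in range(1, n_samples):\n        # cand[i, j] = best path ending in i at t - 1, then a move i -> j\n        cand = delta[:, None] + log_A\n        backptr[t] = np.argmax(cand, axis=0)\n        delta = cand[backptr[t], np.arange(n_states)] + log_b[t]\n    # Trace the back-pointers from the best final state\n    path = np.empty(n_samples, dtype=int)\n    path[-1] = np.argmax(delta)\n    for t in range(n_samples - 1, 0, -1):\n        path[t - 1] = backptr[t, path[t]]\n    return path, np.max(delta)\n\n\n# Hypothetical two-state toy problem with four observations\nlog_pi = np.log([0.6, 0.4])\nlog_A = np.log([[0.7, 0.3], [0.4, 0.6]])\nlog_b = np.log([[0.9, 0.2], [0.8, 0.3], [0.1, 0.7], [0.2, 0.9]])\npath, best = viterbi(log_pi, log_A, log_b)\nprint(path, best)\n\n\n# Sanity check: exhaustive search over all 2**4 state paths\ndef path_logprob(p):\n    return (log_pi[p[0]] + log_b[0, p[0]]\n            + sum(log_A[p[t - 1], p[t]] + log_b[t, p[t]]\n                  for t in range(1, len(p))))\n\n\nbest_path = max(itertools.product(range(2), repeat=4), key=path_logprob)\nprint(best_path)"
      ]
    },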
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's plot the recovered states against the generated ones, and compare the\ntwo transition matrices, to get a sense of our model. Assuming the best model\nhas four states, the recovered states follow the same path as the generated\nstates, just with the state identities permuted: hidden-state labels are\narbitrary, so the fitted model may number the components differently from\nthe first figure, but the pattern of transitions is unchanged. The same is\ntrue for the transition matrix, whose rows and columns are permuted in the\nsame way.\n\n"
      ]
    },
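    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because hidden-state labels are arbitrary, comparing the recovered states with the generated ones first requires matching the labels. One simple way, sketched in the next cell under the assumption that both label sets have the same number of states, is to build the confusion matrix between generated and recovered labels and pick the permutation that maximizes agreement. With four states a brute-force search over all 4! = 24 permutations is cheap; `scipy.optimize.linear_sum_assignment` solves the same matching for larger problems. The helper `align_labels` and the toy labels are hypothetical; with this notebook's variables one could call `align_labels(Z, states, 4)` when the best model has four states, and reorder `model.transmat_` with `np.ix_(mapping, mapping)` to compare it entry by entry with the generated matrix.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import itertools\n\nimport numpy as np\n\n\ndef align_labels(true_labels, pred_labels, n_states):\n    # Hypothetical helper: find mapping such that mapping[g] is the recovered\n    # label that best matches generated label g, by brute force over all\n    # permutations (fine for a handful of states)\n    # confusion[g, r] = number of samples generated in g and decoded as r\n    confusion = np.zeros((n_states, n_states), dtype=int)\n    np.add.at(confusion, (true_labels, pred_labels), 1)\n    rows = np.arange(n_states)\n    best = max(itertools.permutations(range(n_states)),\n               key=lambda perm: confusion[rows, list(perm)].sum())\n    mapping = np.array(best)\n    accuracy = confusion[rows, mapping].sum() / len(true_labels)\n    return mapping, accuracy\n\n\n# Hypothetical toy labels: the recovered labels are a relabeling of the\n# generated ones, with one sample decoded wrongly\nZ_toy = np.array([0, 0, 1, 2, 3, 3, 2, 1])\nrelabel = np.array([2, 0, 3, 1])  # generated g -> recovered relabel[g]\nstates_toy = relabel[Z_toy]\nstates_toy[5] = 0  # introduce one decoding error\nmapping, accuracy = align_labels(Z_toy, states_toy, 4)\nprint(mapping, accuracy)\n\n# Reorder a (here random) recovered transition matrix into the generated\n# state order: aligned[g1, g2] = recovered[mapping[g1], mapping[g2]]\nrecovered_transmat = np.random.default_rng(0).dirichlet(np.ones(4), size=4)\naligned_transmat = recovered_transmat[np.ix_(mapping, mapping)]\nprint(aligned_transmat.round(2))"
      ]
    },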
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# plot the recovered states against the generated states\nfig, ax = plt.subplots()\nax.plot(Z, states)\nax.set_title('Recovered vs. generated states')\nax.set_xlabel('Generated State')\nax.set_ylabel('Recovered State')\nfig.show()\n\n# plot the generated and recovered transition matrices side by side\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 5))\nax1.imshow(gen_model.transmat_, aspect='auto', cmap='spring')\nax1.set_title('Generated Transition Matrix')\nax2.imshow(model.transmat_, aspect='auto', cmap='spring')\nax2.set_title('Recovered Transition Matrix')\nfor ax in (ax1, ax2):\n    ax.set_xlabel('State To')\n    ax.set_ylabel('State From')\n\nfig.tight_layout()\nfig.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}