{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Real-time feedback for decoding :: Server Side\n",
    "\n",
    "This example demonstrates how to set up a real-time feedback\n",
    "mechanism using `StimServer` and `StimClient`.\n",
    "\n",
    "The idea here is to display future stimuli for the class which\n",
    "is predicted less accurately. This allows on-demand adaptation\n",
    "of the stimuli depending on the needs of the classifier.\n",
    "\n",
    "To run this example, open ipython in two separate terminals.\n",
    "In the first, run `rt_feedback_server.py` and then wait for the\n",
    "message\n",
    "\n",
    "    RtServer: Start\n",
    "\n",
    "Once that appears, run `rt_feedback_client.py` in the other terminal\n",
    "and the feedback script should start.\n",
    "\n",
    "All brain responses are simulated from a fiff file to make it easy\n",
    "to test. However, it should be possible to adapt this script\n",
    "for a real experiment."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Author: Mainak Jas \n",
    "#\n",
    "# License: BSD (3-clause)\n",
    "\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn import preprocessing\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "# train_test_split lives in sklearn.model_selection in current\n",
    "# scikit-learn; the old sklearn.cross_validation module is only kept as\n",
    "# a fallback for installations that predate model_selection.\n",
    "try:\n",
    "    from sklearn.model_selection import train_test_split\n",
    "except ImportError:\n",
    "    from sklearn.cross_validation import train_test_split\n",
    "\n",
    "import mne\n",
    "from mne.datasets import sample\n",
    "from mne.realtime import StimServer\n",
    "from mne.realtime import MockRtClient\n",
    "from mne.decoding import Vectorizer, FilterEstimator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Trigger codes of the two classes being decoded\n",
    "EVENT_LV, EVENT_RV = 3, 4\n",
    "\n",
    "N_TRIALS = 50   # total number of simulated trials\n",
    "N_WARMUP = 10   # trials collected before decoding starts\n",
    "N_PLOTTED = 5   # number of most recent scores shown in the figure\n",
    "\n",
    "# Load fiff file to simulate data\n",
    "data_path = sample.data_path()\n",
    "raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'\n",
    "raw = mne.io.read_raw_fif(raw_fname, preload=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the stimulation server and the feedback loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Instantiating stimulation server\n",
    "\n",
    "# The with statement is necessary to ensure a clean exit\n",
    "with StimServer(port=4218) as stim_server:\n",
    "\n",
    "    # The channels to be used while decoding\n",
    "    picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True,\n",
    "                           stim=True, exclude=raw.info['bads'])\n",
    "\n",
    "    rt_client = MockRtClient(raw)\n",
    "\n",
    "    # Constructing the pipeline for classification\n",
    "    filt = FilterEstimator(raw.info, 1, 40)\n",
    "    scaler = preprocessing.StandardScaler()\n",
    "    vectorizer = Vectorizer()\n",
    "    clf = SVC(C=1, kernel='linear')\n",
    "\n",
    "    concat_classifier = Pipeline([('filter', filt), ('vector', vectorizer),\n",
    "                                  ('scaler', scaler), ('svm', clf)])\n",
    "\n",
    "    stim_server.start(verbose=True)\n",
    "\n",
    "    # Just some initially decided events to be simulated.\n",
    "    # The rest will be decided on the fly.\n",
    "    ev_list = [4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4]\n",
    "\n",
    "    score_c1, score_c2, score_x = [], [], []\n",
    "\n",
    "    # One entry per trial; stacked into arrays only when a fit is needed,\n",
    "    # which avoids re-concatenating the whole data array on every trial.\n",
    "    epochs, labels = [], []\n",
    "\n",
    "    for ii in range(N_TRIALS):\n",
    "\n",
    "        # Tell the stim_client about the next stimuli\n",
    "        stim_server.add_trigger(ev_list[ii])\n",
    "\n",
    "        # Collecting data\n",
    "        epoch = rt_client.get_event_data(event_id=ev_list[ii], tmin=-0.2,\n",
    "                                         tmax=0.5, picks=picks,\n",
    "                                         stim_channel='STI 014')\n",
    "        epochs.append(epoch)\n",
    "        labels.append(ev_list[ii])\n",
    "\n",
    "        if ii > 0:\n",
    "            time.sleep(1)  # simulating the isi\n",
    "\n",
    "        # Start decoding after collecting sufficient data\n",
    "        if ii >= N_WARMUP:\n",
    "            # Now start doing rtfeedback\n",
    "            X = np.array(epochs)\n",
    "            y = np.array(labels)\n",
    "            X_train, X_test, y_train, y_test = train_test_split(\n",
    "                X, y, test_size=0.2, random_state=7)\n",
    "\n",
    "            y_pred = concat_classifier.fit(X_train, y_train).predict(X_test)\n",
    "\n",
    "            # Pin the label order so the matrix is always 2 x 2 (row 0 is\n",
    "            # LV, row 1 is RV), even when the small test split happens to\n",
    "            # contain a single class.\n",
    "            cm = confusion_matrix(y_test, y_pred,\n",
    "                                  labels=[EVENT_LV, EVENT_RV])\n",
    "\n",
    "            # Per-class accuracy in percent: correct predictions divided by\n",
    "            # the number of true samples of that class (the row sums).\n",
    "            # Note that the builtin sum(cm, 1) would add the rows to a start\n",
    "            # value of 1 instead of summing along axis 1. The denominator is\n",
    "            # clipped to 1 so that a class absent from the test split scores\n",
    "            # 0 instead of triggering a division by zero.\n",
    "            n_true = np.maximum(cm.sum(axis=1), 1)\n",
    "            scores = 100. * np.diag(cm) / n_true\n",
    "\n",
    "            score_c1.append(scores[0])\n",
    "            score_c2.append(scores[1])\n",
    "\n",
    "            # do something if one class is decoded better than the other\n",
    "            if score_c1[-1] < score_c2[-1]:\n",
    "                print('We decoded class RV better than class LV')\n",
    "                # adding more LV to future simulated data\n",
    "                ev_list.append(EVENT_LV)\n",
    "            else:\n",
    "                print('We decoded class LV better than class RV')\n",
    "                # adding more RV to future simulated data\n",
    "                ev_list.append(EVENT_RV)\n",
    "\n",
    "            # Clear the figure\n",
    "            plt.clf()\n",
    "\n",
    "            # The x-axis for the plot\n",
    "            score_x.append(ii)\n",
    "\n",
    "            # Now plot the accuracy of the most recent trials. Successive\n",
    "            # plot calls draw on the same axes, so no plt.hold is needed.\n",
    "            plt.plot(score_x[-N_PLOTTED:], score_c1[-N_PLOTTED:])\n",
    "            plt.plot(score_x[-N_PLOTTED:], score_c2[-N_PLOTTED:])\n",
    "            plt.xlabel('Trials')\n",
    "            plt.ylabel('Classification score (% correct)')\n",
    "            plt.title('Real-time feedback')\n",
    "            plt.ylim([0, 100])\n",
    "            plt.xticks(score_x[-N_PLOTTED:])\n",
    "            plt.legend(('LV', 'RV'), loc='upper left')\n",
    "            plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}