{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Scikit-Learn GAIuS™ Pipeline Example" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pprint\n", "import warnings\n", "\n", "from ia.gaius.experimental.sklearn import GAIuSClassifier, GDFTransformer\n", "from ia.gaius.manager import AgentManager\n", "\n", "from sklearn.datasets import fetch_openml\n", "from sklearn.feature_selection import VarianceThreshold\n", "from sklearn.metrics import classification_report\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import KBinsDiscretizer, StandardScaler" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fetch the MNIST data in OpenML format. Each row corresponds to a single MNIST image." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "mnist = fetch_openml('mnist_784', version=1, parser='auto')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "X = mnist.data\n", "y = mnist.target\n", "\n", "feature_names = mnist.feature_names" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Rows 0-59,999 form the training set; all remaining rows are held out for testing\n", "X_train, X_test = X[:60000], X[60000:]\n", "y_train, y_test = y[:60000], y[60000:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Clear all agents on the system." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "am = AgentManager()\n", "am.kill_all_agents()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define a pipeline to:\n", "\n", " - center and scale the MNIST data\n", " - eliminate features with low variance\n", " - bin the data into integer bins\n", " - convert to a GDF sequence\n", " - ingest into the Cognitive Processor" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "gaius_pipeline = Pipeline([('scaler', StandardScaler()),\n", " ('variance_threshold', 
VarianceThreshold(threshold=0.005)),\n", " ('discretizer', KBinsDiscretizer(n_bins=32, encode='ordinal')),\n", " ('gdfer', GDFTransformer(as_vector=True)),\n", " ('cp_classifier', GAIuSClassifier(recall_threshold=0.1, max_predictions=5, near_vector_count=3, pred_as_int=False))])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'P1': {'AUTOLEARN': False,\n", " 'HYPOTHESIZED': False,\n", " 'PREDICT': True,\n", " 'SLEEPING': False,\n", " 'SNAPSHOT': False,\n", " 'emotives': {},\n", " 'last_learned_model_name': '',\n", " 'models_kb': '{KB| objects: 0}',\n", " 'name': 'P1',\n", " 'num_observe_call': 0,\n", " 'size_WM': 0,\n", " 'target': '',\n", " 'time': 0,\n", " 'vector_dimensionality': -1,\n", " 'vectors_kb': '{KB| objects: 0}'}}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Look the classifier step up by name rather than by position in the pipeline\n", "gaius_pipeline.named_steps['cp_classifier'].agent.show_status()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Suppress every UserWarning raised from this point on\n", "warnings.filterwarnings(\"ignore\", category=UserWarning)\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "982ac3efc68a495a80eb9f5b59c6ba39", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10000 [00:00#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: 
black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 
div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}
Pipeline(steps=[('scaler', StandardScaler()),\n",
       "                ('variance_threshold', VarianceThreshold(threshold=0.005)),\n",
       "                ('discretizer', KBinsDiscretizer(encode='ordinal', n_bins=32)),\n",
       "                ('gdfer', GDFTransformer(as_vector=True)),\n",
       "                ('cp_classifier',\n",
       "                 GAIuSClassifier(max_predictions=5, near_vector_count=3,\n",
       "                                 pred_as_int=False))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('scaler', StandardScaler()),\n", " ('variance_threshold', VarianceThreshold(threshold=0.005)),\n", " ('discretizer', KBinsDiscretizer(encode='ordinal', n_bins=32)),\n", " ('gdfer', GDFTransformer(as_vector=True)),\n", " ('cp_classifier',\n", " GAIuSClassifier(max_predictions=5, near_vector_count=3,\n", " pred_as_int=False))])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gaius_pipeline.fit(X_train[:10000], y_train[:10000])" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0c2fe4c32af34251966defe1dc3c917a", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1000 [00:00