{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "27342e42",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gpt as g"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b01c9af2",
   "metadata": {},
   "source": [
    "## Lecture 5: A first look at the machine learning module"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "436fabb5",
   "metadata": {},
   "source": [
    "We will create two dense layers of 12 neurons, placed on a $4^3$ grid with nearest-neighbor interactions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6400be05",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPT :       2.074188 s : Initializing gpt.random(test,vectorized_ranlux24_389_64) took 0.00079298 s\n"
     ]
    }
   ],
   "source": [
    "grid = g.grid([4, 4, 4], g.double)\n",
    "rng = g.random(\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "21811a47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# test layers based on 12 densely connected neurons per layer that live on a nearest-neighbor 4^3 grid\n",
    "# for now real weights only\n",
    "n_dense = 12\n",
    "n_depth = 2\n",
    "n_training = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4a8aab51",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPT :     113.933457 s : Cost: 412.39705036936016\n",
      "GPT :     114.062864 s : gradient_descent: iteration 0: f(x) = 3.364174161809646e+02, |df|/sqrt(dof) = 4.623279e-02, step = 0.5\n",
      "GPT :     115.204867 s : gradient_descent: iteration 10: f(x) = 1.940783390620649e+02, |df|/sqrt(dof) = 2.401684e-02, step = 0.5\n",
      "GPT :     116.441977 s : gradient_descent: iteration 20: f(x) = 1.583727855257709e+02, |df|/sqrt(dof) = 2.159283e-02, step = 0.5\n",
      "GPT :     117.630582 s : gradient_descent: iteration 30: f(x) = 1.422574106843941e+02, |df|/sqrt(dof) = 2.097117e-02, step = 0.5\n",
      "GPT :     118.686322 s : gradient_descent: NOT converged in 40 iterations;  |df|/sqrt(dof) = 1.953533e-02 / 1.000000e-07\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# data type of input layer\n",
    "ot_i = g.ot_vector_real_additive_group(n_dense)\n",
    "\n",
    "# data type of weights\n",
    "ot_w = g.ot_matrix_real_additive_group(n_dense)\n",
    "\n",
    "n = g.ml.model.sequence([g.ml.layer.nearest_neighbor(grid, ot_i, ot_w)] * n_depth)\n",
    "\n",
    "W = n.random_weights(rng)\n",
    "\n",
    "training_input = [rng.uniform_real(g.lattice(grid, ot_i)) for i in range(n_training)]\n",
    "training_output = [rng.uniform_real(g.lattice(grid, ot_i)) for i in range(n_training)]\n",
    "\n",
    "c = n.cost(training_input, training_output)\n",
    "g.message(\"Cost:\", c(W))\n",
    "\n",
    "ls0 = g.algorithms.optimize.line_search_none\n",
    "# ls2 = g.algorithms.optimize.line_search_quadratic\n",
    "# pr = g.algorithms.optimize.polak_ribiere\n",
    "# opt = g.algorithms.optimize.non_linear_cg(\n",
    "#     maxiter=40, eps=1e-7, step=1e-1, line_search=ls2, beta=pr\n",
    "# )\n",
    "\n",
    "opt = g.algorithms.optimize.gradient_descent(\n",
    "    maxiter=40, eps=1e-7, step=0.5, line_search=ls0\n",
    ")\n",
    "\n",
    "# Train network\n",
    "opt(c)(W, W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a09afbf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
