<?xml version="1.0" encoding="UTF-8"?>
<rss  xmlns:atom="http://www.w3.org/2005/Atom" 
      xmlns:media="http://search.yahoo.com/mrss/" 
      xmlns:content="http://purl.org/rss/1.0/modules/content/" 
      xmlns:dc="http://purl.org/dc/elements/1.1/" 
      version="2.0">
<channel>
<title>Home</title>
<link>https://shonczinner.github.io/</link>
<atom:link href="https://shonczinner.github.io/index.xml" rel="self" type="application/rss+xml"/>
<description>Shon Czinner&#39;s blog about AI, statistics, quantitative finance, economics, and more</description>
<generator>quarto-1.9.37</generator>
<lastBuildDate>Mon, 04 May 2026 00:00:00 GMT</lastBuildDate>
<item>
  <title>The 90-year-old idea behind JEPA models: Canonical Correlation Analysis (CCA)</title>
  <dc:creator>Shon Czinner</dc:creator>
  <link>https://shonczinner.github.io/posts/embedding-prediction/</link>
  <description><![CDATA[ 




<section id="introduction" class="level1">
<h1>Introduction</h1>
<blockquote class="blockquote">
<p>Concepts of correlation and regression may be applied not only to ordinary one-dimensional variates but also to variates of two or more dimensions.</p>
</blockquote>
<p>This is the first sentence from the paper “Relations Between Two Sets of Variates” <span class="citation" data-cites="hotelling1936">(Hotelling 1936)</span> by statistician and economist Harold Hotelling. This paper introduced Canonical Correlation Analysis (CCA). In modern terminology, “CCA is used to find a common signal among two large matrices” <span class="citation" data-cites="bykhovskaya2025">(Bykhovskaya and Gorin 2025)</span>.</p>
<p>In JEPA, the objective is the same, except that the second data matrix is simply a different view of the data in the first (e.g.&nbsp;via data augmentation or spatial or temporal proximity). One of the recent papers to acknowledge this connection states that “JEPA-based models implicitly perform a non-linear generalization of Canonical Correlation Analysis” <span class="citation" data-cites="huang2026">(Huang 2026)</span>.</p>
<p>CCA’s connection to JEPA is relevant to Schmidhuber’s dispute with Yann LeCun over <a href="https://people.idsia.ch/~juergen/who-invented-jepa.html">who invented JEPA</a>. Personally, I think Hotelling deserves the credit for the idea of maximizing correlation in embedding space.</p>
<p>Of course, the CCA model has many differences from JEPA.</p>
<p>For one, CCA does not enforce a shared encoder. But the biggest difference is that CCA is linear. Non-linear neural variants of CCA have been studied, with the earliest use of the term “Deep CCA” appearing in <span class="citation" data-cites="andrew2013">(Andrew et al. 2013)</span>.</p>
<p>Connecting JEPA models back to their CCA roots is genuinely useful. Another Deep CCA paper <span class="citation" data-cites="benton2017">(Benton et al. 2017)</span> relaxed the assumption of two sets of variables to an arbitrary number, based on a generalization of CCA proposed in 1961 <span class="citation" data-cites="horst1961">(Horst 1961)</span>. Conceivably, JEPAs could be expanded to handle more than two views as well.</p>
</section>
<section id="cca-vs.-jepa-overview" class="level1">
<h1>CCA vs.&nbsp;JEPA Overview</h1>
<section id="cca" class="level2">
<h2 class="anchored" data-anchor-id="cca">CCA</h2>
<p>Suppose we have zero-mean matrices <img src="https://latex.codecogs.com/png.latex?X=(x_1,...,x_n)%5ET%5Cin%20%5Cmathbb%20R%5E%7Bn%5Ctimes%20d_x%7D"> and <img src="https://latex.codecogs.com/png.latex?Y=(y_1,...,y_n)%5ET%5Cin%5Cmathbb%20R%5E%7Bn%5Ctimes%20d_y%7D">.</p>
<p>Let <img src="https://latex.codecogs.com/png.latex?k%5Cleq%20%5Cmin(d_x,d_y,%20n)"> and <img src="https://latex.codecogs.com/png.latex?A%5Cin%20%20%5Cmathbb%20R%5E%7Bd_x%5Ctimes%20k%7D"> and <img src="https://latex.codecogs.com/png.latex?B%5Cin%20%20%5Cmathbb%20R%5E%7Bd_y%5Ctimes%20k%7D"> so that <img src="https://latex.codecogs.com/png.latex?XA=z_x%5Cin%5Cmathbb%20R%5E%7Bn%20%5Ctimes%20k%7D"> and <img src="https://latex.codecogs.com/png.latex?YB=z_y%5Cin%5Cmathbb%20R%5E%7Bn%20%5Ctimes%20k%7D">.</p>
<p>CCA solves the following maximization problem,</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Cmax_%7BA,B%7D%20%5Ctext%7Btr%7D%5Cleft(%5Cfrac%7B1%7D%7Bn%7Dz_x%5ETz_y%5Cright)%20"> <img src="https://latex.codecogs.com/png.latex?%5Ctext%7Bs.t%7D"> <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B1%7D%7Bn%7Dz_x%5ETz_x=%5Cfrac%7B1%7D%7Bn%7Dz_y%5ETz_y=I"></p>
<p>This maximizes the trace of the cross-correlation matrix while constraining each set of embeddings to be whitened: every embedding dimension has unit variance, and the dimensions are mutually uncorrelated.</p>
<p>Similar to the equivalence between maximizing variance and minimizing reconstruction error in PCA, there is a relationship between the trace of the cross-correlation matrix and the embedding prediction error,</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi=1%7D%5En%20%7C%7Cz_x%5E%7B(i)%7D-z_y%5E%7B(i)%7D%7C%7C%5E2=%5Cfrac%7B1%7D%7Bn%7D%7C%7Cz_x-z_y%7C%7C_F%5E2=%20%5Cfrac%7B1%7D%7Bn%7D%5Ctext%7Btr%7D(z_x%5ETz_x)%20+%20%5Cfrac%7B1%7D%7Bn%7D%5Ctext%7Btr%7D(z_y%5ETz_y)%20-%20%5Cfrac%7B2%7D%7Bn%7D%5Ctext%7Btr%7D(z_x%5ETz_y)"> And due to the whitening constraints, <img src="https://latex.codecogs.com/png.latex?=2k-%20%5Cfrac%7B2%7D%7Bn%7D%5Ctext%7Btr%7D(z_x%5ETz_y)"></p>
<p>So maximizing the trace of the cross-correlation under the whitening constraints is equivalent to minimizing the MSE of the embedding representations. Therefore we can write CCA as,</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Cmin_%7BA,B%7D%20%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bi=1%7D%5En%20%7C%7Cz_x%5E%7B(i)%7D-z_y%5E%7B(i)%7D%7C%7C%5E2"> <img src="https://latex.codecogs.com/png.latex?%5Ctext%7Bs.t%7D"> <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B1%7D%7Bn%7Dz_x%5ETz_x=%5Cfrac%7B1%7D%7Bn%7Dz_y%5ETz_y=I"></p>
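<p>The classical solution can be computed in closed form from an SVD of the whitened cross-covariance matrix. Below is a minimal numpy sketch (the data and dimensions are synthetic, chosen only for illustration) that also checks the whitening constraints and the MSE identity above numerically:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_x, d_y, k = 500, 5, 4, 3

# Two views sharing a common latent signal, plus independent noise
latent = rng.standard_normal((n, k))
X = latent @ rng.standard_normal((k, d_x)) + 0.5 * rng.standard_normal((n, d_x))
Y = latent @ rng.standard_normal((k, d_y)) + 0.5 * rng.standard_normal((n, d_y))
X -= X.mean(axis=0)
Y -= Y.mean(axis=0)

def inv_sqrt(S):
    """Inverse matrix square root of a symmetric positive-definite matrix."""
    w, V = np.linalg.eigh(S)
    return V @ np.diag(w ** -0.5) @ V.T

Sxx, Syy, Sxy = X.T @ X / n, Y.T @ Y / n, X.T @ Y / n

# CCA: SVD of the whitened cross-covariance
U, s, Vt = np.linalg.svd(inv_sqrt(Sxx) @ Sxy @ inv_sqrt(Syy))
A = inv_sqrt(Sxx) @ U[:, :k]    # d_x x k
B = inv_sqrt(Syy) @ Vt[:k].T    # d_y x k
z_x, z_y = X @ A, Y @ B

# Whitening constraints hold: (1/n) z^T z = I
assert np.allclose(z_x.T @ z_x / n, np.eye(k), atol=1e-6)
assert np.allclose(z_y.T @ z_y / n, np.eye(k), atol=1e-6)

# MSE identity: (1/n)||z_x - z_y||_F^2 = 2k - (2/n) tr(z_x^T z_y)
mse = np.sum((z_x - z_y) ** 2) / n
assert np.isclose(mse, 2 * k - 2 * np.trace(z_x.T @ z_y) / n)
```

<p>The singular values of the whitened cross-covariance are exactly the canonical correlations, so truncating the SVD at <img src="https://latex.codecogs.com/png.latex?k"> components maximizes the trace objective above.</p>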
</section>
<section id="jepa" class="level2">
<h2 class="anchored" data-anchor-id="jepa">JEPA</h2>
<p>Adopting the previous notation, JEPA requires <img src="https://latex.codecogs.com/png.latex?d_x=d_y=d"> because both views pass through a shared encoder (the joint embedding). In JEPA, we have the encoder <img src="https://latex.codecogs.com/png.latex?f_%5Ctheta:%5Cmathbb%20R%5E%7Bd%7D%5Crightarrow%20%5Cmathbb%20R%5Ek">, and the predictor <img src="https://latex.codecogs.com/png.latex?g_%5Cvarphi:%5Cmathbb%20R%5E%7Bk%7D%5Crightarrow%20%5Cmathbb%20R%5Ek">.</p>
<p>Let <img src="https://latex.codecogs.com/png.latex?z_x%5E%7B(i)%7D=g_%5Cvarphi(f_%5Ctheta(x_i))">, <img src="https://latex.codecogs.com/png.latex?z_y%5E%7B(i)%7D=f_%5Ctheta(y_i)">.</p>
<p>Then we solve,</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Cmin_%7B%5Ctheta,%5Cvarphi%7D%5Cfrac%7B1%7D%7Bn%7D%20%5Csum_%7Bi=1%7D%5En%20%7C%7Cz_x%5E%7B(i)%7D-z_y%5E%7B(i)%7D%7C%7C%5E2"></p>
<p>Note the similarity in the objective function, but also the lack of whitening constraints. Without them, the objective admits representational and dimensional collapse: a trivial solution to the above problem is the constant embedding <img src="https://latex.codecogs.com/png.latex?z_x%5E%7B(i)%7D=z_y%5E%7B(i)%7D=c">.</p>
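<p>To make the collapse concrete, here is a tiny synthetic numpy illustration (purely for intuition): constant embeddings drive the prediction MSE to zero while the embedding covariance degenerates to zero.</p>

```python
import numpy as np

n, k = 200, 4
c = np.arange(1.0, k + 1.0)        # any fixed vector

# Collapsed solution: every sample maps to the same embedding c
z_x = np.tile(c, (n, 1))
z_y = np.tile(c, (n, 1))

# The JEPA objective is perfectly minimized...
mse = np.mean(np.sum((z_x - z_y) ** 2, axis=1))
assert mse == 0.0

# ...but the whitening constraint (1/n) z^T z = I is nowhere near satisfied:
centered = z_x - z_x.mean(axis=0)
cov = centered.T @ centered / n
assert np.allclose(cov, 0.0)        # zero variance in every dimension
```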
<p>As discussed in my <a href="../../posts/sigreg-sketched-isotropic-gaussian-regularization/">previous blog post</a>, SIGReg <span class="citation" data-cites="balestriero2025">(Balestriero and LeCun 2025)</span> fixes this problem by encouraging the embeddings <img src="https://latex.codecogs.com/png.latex?z_x"> and <img src="https://latex.codecogs.com/png.latex?z_y"> to follow an isotropic (i.e.&nbsp;unit-variance, uncorrelated) Gaussian distribution. As a result, it encourages</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B1%7D%7Bn%7Dz_x%5ETz_x=%5Cfrac%7B1%7D%7Bn%7Dz_y%5ETz_y=I"></p>
</section>
</section>
<section id="conclusion" class="level1">
<h1>Conclusion</h1>
<p>As I mentioned in the introduction, Schmidhuber has debated <a href="https://people.idsia.ch/~juergen/who-invented-jepa.html">who invented JEPA</a> and said this about LeCun,</p>
<blockquote class="blockquote">
<p>Dr.&nbsp;LeCun’s heavily promoted Joint Embedding Predictive Architecture (JEPA) is the heart of his new company. However, the core ideas are not original to LeCun. Instead, JEPA is essentially identical to our 1992 Predictability Maximization system.</p>
</blockquote>
<p>Schmidhuber references Yann LeCun’s response,</p>
<blockquote class="blockquote">
<p>JEPA is merely a name for a general concept. The question is, and has always been, how do you make it work (particularly how do you prevent it from collapsing), and how do you make it work at scale with SOTA results on non-toy problems. That’s the hard part. Ideas are a dime a dozen. Making them work is what the community will give you credit for.</p>
</blockquote>
<p>Do I agree with LeCun? Yes and no.</p>
<p>Yes, because of course you will get credit for making things work, and ideas are indeed arguably “a dime a dozen”.</p>
<p>No, because the thread of citations is important for progress. If important citations are missed, whether intentionally or not, the correct thing to do is simply to add them; we’re all the better for it. The connection that JEPA models have to CCA is informative.</p>
<p>My opinion is that JEPA/Predictability Maximization models are architectural enhancements layered on top of CCA, non-linearity chief among them.</p>
<p>Ultimately, these models all have the same objective function introduced by CCA: find the transformations that result in maximal correlation between sets of multidimensional data.</p>



</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-bibliography"><h2 class="anchored quarto-appendix-heading">References</h2><div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-andrew2013" class="csl-entry">
Andrew, Galen, Raman Arora, Jeff Bilmes, and Karen Livescu. 2013. <span>“Deep Canonical Correlation Analysis.”</span> <em>International Conference on Machine Learning</em>, 1247–55. <a href="https://proceedings.mlr.press/v28/andrew13.html">https://proceedings.mlr.press/v28/andrew13.html</a>.
</div>
<div id="ref-balestriero2025" class="csl-entry">
Balestriero, Randall, and Yann LeCun. 2025. <em>LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics</em>. <a href="https://arxiv.org/abs/2511.08544">https://arxiv.org/abs/2511.08544</a>.
</div>
<div id="ref-benton2017" class="csl-entry">
Benton, Adrian, Huda Khayrallah, Biman Gujral, Dee Ann Reisinger, Sheng Zhang, and Raman Arora. 2017. <em>Deep Generalized Canonical Correlation Analysis</em>. <a href="https://arxiv.org/abs/1702.02519">https://arxiv.org/abs/1702.02519</a>.
</div>
<div id="ref-bykhovskaya2025" class="csl-entry">
Bykhovskaya, Anna, and Vadim Gorin. 2025. <em>Canonical Correlation Analysis: Review</em>. <a href="https://arxiv.org/abs/2411.15625">https://arxiv.org/abs/2411.15625</a>.
</div>
<div id="ref-horst1961" class="csl-entry">
Horst, Paul. 1961. <span>“Generalized Canonical Correlations and Their Application to Experimental Data.”</span> <em>Journal of Clinical Psychology</em>.
</div>
<div id="ref-hotelling1936" class="csl-entry">
Hotelling, Harold. 1936. <span>“Relations Between Two Sets of Variates.”</span> <em>Biometrika</em> 28 (3/4): 321–77. <a href="http://www.jstor.org/stable/2333955">http://www.jstor.org/stable/2333955</a>.
</div>
<div id="ref-huang2026" class="csl-entry">
Huang, Yongchao. 2026. <em>VJEPA: Variational Joint Embedding Predictive Architectures as Probabilistic World Models</em>. <a href="https://arxiv.org/abs/2601.14354">https://arxiv.org/abs/2601.14354</a>.
</div>
</div></section></div> ]]></description>
  <category>ai</category>
  <category>jepa</category>
  <guid>https://shonczinner.github.io/posts/embedding-prediction/</guid>
  <pubDate>Mon, 04 May 2026 00:00:00 GMT</pubDate>
  <media:content url="https://shonczinner.github.io/posts/embedding-prediction/noisy_matrix_with_marginals.png" medium="image" type="image/png" height="146" width="144"/>
</item>
<item>
  <title>A Small JEPA Word Embedding Model</title>
  <dc:creator>Shon Czinner</dc:creator>
  <link>https://shonczinner.github.io/posts/small-jepa-language-model/</link>
  <description><![CDATA[ 




<p>After my <a href="../../posts/sigreg-sketched-isotropic-gaussian-regularization/">prior blog post about SIGReg</a>, I figured I’d train a small Joint-Embedding Predictive Architecture (JEPA) model to demonstrate it.</p>
<p>The paper “LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels” <span class="citation" data-cites="maes2026">(Maes et al. 2026)</span> suggested significantly reducing the complexity of JEPA models by removing stop-gradients and the exponential-moving-average encoder. This was in the context of world models and planning.</p>
<p>In this case, I’m applying JEPA to the task of creating word embeddings. Prior methodologies include Word2vec <span class="citation" data-cites="mikolov2013">(Mikolov et al. 2013)</span> which uses a log-linear model and negative sampling, MLP next-word prediction <span class="citation" data-cites="bengio2000">(Bengio et al. 2000)</span>, applying CCA to small context windows <span class="citation" data-cites="dhillon2011">(Dhillon et al. 2011)</span>, and training an autoencoder on small context windows <span class="citation" data-cites="shao2025">(Shao et al. 2025)</span>.</p>
<p>We’ll train a linear JEPA model with SIGReg on a small Shakespeare dataset to show that it learns informative embeddings. In other words, we’ll train a shared encoder that maps each word to an embedding, and a linear predictor that predicts the target word’s embedding from the context word’s embedding. It would be easy to extend this methodology to non-linear encoders and to contexts larger than single words.</p>
<section id="overview" class="level1">
<h1>Overview</h1>
<p>First we’ll take our dataset and convert it into tokens. For example,</p>
<pre><code>["to", "be", ",", "or", "not", "to", "be"] -&gt; [1, 2, 3, 4, 5, 1, 2]</code></pre>
<p>Then we create the dataset where we have context/target pairs. So in this case that would look like,</p>
<pre><code>Context 1: [1], Target 1: [2]
Context 2: [2], Target 2: [3]
...
Context 5: [5], Target 5: [1]
Context 6: [1], Target 6: [2]</code></pre>
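<p>Concretely, the context/target pairs are just adjacent token ids, which can be built with a one-position shift:</p>

```python
# Token ids for ["to", "be", ",", "or", "not", "to", "be"]
tokens = [1, 2, 3, 4, 5, 1, 2]

# Contexts are all tokens but the last; targets are all tokens but the first
contexts, targets = tokens[:-1], tokens[1:]
print(list(zip(contexts, targets)))
# [(1, 2), (2, 3), (3, 4), (4, 5), (5, 1), (1, 2)]
```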
<p>Then we create the JEPA model which uses the same embedding <img src="https://latex.codecogs.com/png.latex?f_%5Ctheta(%5Ccdot)"> for the context <img src="https://latex.codecogs.com/png.latex?x"> and target <img src="https://latex.codecogs.com/png.latex?y">, and then has predictor <img src="https://latex.codecogs.com/png.latex?g_%5Cvarphi(%5Ccdot)"> predict the target from the context. More formally,</p>
<p><img src="https://latex.codecogs.com/png.latex?%0Af_%5Ctheta(x)=h_x%0A"></p>
<p><img src="https://latex.codecogs.com/png.latex?%0Af_%5Ctheta(y)=h_y%0A"></p>
<p><img src="https://latex.codecogs.com/png.latex?%0Ag_%5Cvarphi(h_x)=%5Chat%20h_y%0A"></p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathcal%7BL%7D_%7BJEPA%7D(%5Ctheta,%5Cvarphi)=MSE(h_y,%5Chat%20h_y)+%5Clambda%5Ctext%7BSIGReg%7D(h_x)%0A"></p>
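<p>Before the full implementation below, here is a schematic numpy version of this loss for a linear encoder. All names and sizes are made up for illustration, and a simple moment-matching penalty (pushing batch mean and covariance toward those of a standard Gaussian) stands in for the actual SIGReg term:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim, lam = 100, 8, 0.05

# f_theta: a shared embedding table; g_phi: a linear predictor
E = rng.standard_normal((vocab_size, embed_dim)) * 0.1
W = rng.standard_normal((embed_dim, embed_dim)) * 0.1

x_ids = rng.integers(0, vocab_size, size=32)   # context token ids
y_ids = rng.integers(0, vocab_size, size=32)   # target token ids

h_x = E[x_ids]           # f_theta(x)
h_y = E[y_ids]           # f_theta(y), same encoder
h_y_hat = h_x @ W        # g_phi(h_x)

mse = np.mean(np.sum((h_y_hat - h_y) ** 2, axis=1))

# Stand-in for SIGReg: push batch moments of h_x toward N(0, I)
mean_pen = np.sum(h_x.mean(axis=0) ** 2)
cov = np.cov(h_x, rowvar=False)
cov_pen = np.sum((cov - np.eye(embed_dim)) ** 2)

loss = mse + lam * (mean_pen + cov_pen)
```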
<div id="be13a26c" class="cell" data-execution_count="50">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> pandas <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> pd</span>
<span id="cb3-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch</span>
<span id="cb3-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> matplotlib.pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb3-4"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb3-5"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> requests</span>
<span id="cb3-6"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> re</span></code></pre></div></div>
</details>
</div>
</section>
<section id="preparing-the-data" class="level1">
<h1>Preparing The Data</h1>
<p>The dataset is a text file containing the complete works of Shakespeare, downloaded from Project Gutenberg.</p>
<div id="4ff7a761" class="cell" data-execution_count="51">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1">txt_url <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"https://www.gutenberg.org/cache/epub/100/pg100.txt"</span></span>
<span id="cb4-2">response <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> requests.get(txt_url)</span>
<span id="cb4-3">txt <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> response.text</span>
<span id="cb4-4">txt[:<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>]</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display" data-execution_count="51">
<pre><code>'\ufeffThe Project Gutenberg eBook of The Complete Works of William Shakespeare\r\n    \r\nThis eBook is for t'</code></pre>
</div>
</div>
<p>To turn this into our dataset, we’ll convert everything to lower-case, split out punctuation, and then split on spaces to get our tokens. We’ll treat every token that appears fewer than 25 times as an unknown token. Each example consists of a single context word, and the target is simply the next word.</p>
<div id="4f4866bd" class="cell" data-execution_count="52">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Lowercase then put spaces around punctuation and \n and then split on spaces</span></span>
<span id="cb6-2">tokens <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> re.findall(<span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">r"</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">\w</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">|</span><span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">[^</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">\w\s</span><span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">]</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>, txt.lower(), re.UNICODE)</span>
<span id="cb6-3"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(tokens[:<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>])  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># tokens[:10]</span></span>
<span id="cb6-4"></span>
<span id="cb6-5">min_freq <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">25</span></span>
<span id="cb6-6">vocab_freq <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.Series(tokens).value_counts()</span>
<span id="cb6-7">vocab <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> vocab_freq[vocab_freq <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> min_freq].index.tolist()</span>
<span id="cb6-8"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Vocab size: "</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(vocab))</span>
<span id="cb6-9"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"First 10 vocab tokens: "</span>, vocab[:<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>])</span>
<span id="cb6-10"></span>
<span id="cb6-11"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># add &lt;unk&gt; token for out-of-vocab words</span></span>
<span id="cb6-12">vocab.insert(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"&lt;unk&gt;"</span>)</span>
<span id="cb6-13">token_to_id <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {token: idx <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> idx, token <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(vocab)}</span>
<span id="cb6-14">id_to_token <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {idx: token <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> idx, token <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(vocab)}</span>
<span id="cb6-15"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> encode(tokens):</span>
<span id="cb6-16">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> [token_to_id.get(token, token_to_id[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"&lt;unk&gt;"</span>]) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> token <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> tokens]</span>
<span id="cb6-17"></span>
<span id="cb6-18"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> decode(token_ids):</span>
<span id="cb6-19">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> [id_to_token.get(token_id, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"&lt;unk&gt;"</span>) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> token_id <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> token_ids]  </span>
<span id="cb6-20"></span>
<span id="cb6-21">encoded <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> encode(tokens)</span>
<span id="cb6-22"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(encoded[:<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">15</span>]) </span>
<span id="cb6-23"></span>
<span id="cb6-24">x0 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> encoded[:<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]</span>
<span id="cb6-25">x1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> encoded[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>:]</span>
<span id="cb6-26"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Dataset size: "</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(x0))</span>
<span id="cb6-27">pd.DataFrame({<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"x0"</span>: x0, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"x1"</span>: x1}).head()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>['\ufeff', 'the', 'project', 'gutenberg', 'ebook', 'of', 'the', 'complete', 'works', 'of']
Vocab size:  3227
First 10 vocab tokens:  [',', '.', 'the', 'and', '’', 'i', 'to', 'of', 'a', 'you']
[0, 3, 1071, 1098, 0, 8, 3, 0, 1523, 8, 1174, 0, 27, 0, 16]
Dataset size:  1262243</code></pre>
</div>
<div class="cell-output cell-output-display" data-execution_count="52">
<div>


<table class="dataframe caption-top table table-sm table-striped small" data-border="1">
<thead>
<tr class="header">
<th data-quarto-table-cell-role="th"></th>
<th data-quarto-table-cell-role="th">x0</th>
<th data-quarto-table-cell-role="th">x1</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<th data-quarto-table-cell-role="th">0</th>
<td>0</td>
<td>3</td>
</tr>
<tr class="even">
<th data-quarto-table-cell-role="th">1</th>
<td>3</td>
<td>1071</td>
</tr>
<tr class="odd">
<th data-quarto-table-cell-role="th">2</th>
<td>1071</td>
<td>1098</td>
</tr>
<tr class="even">
<th data-quarto-table-cell-role="th">3</th>
<td>1098</td>
<td>0</td>
</tr>
<tr class="odd">
<th data-quarto-table-cell-role="th">4</th>
<td>0</td>
<td>8</td>
</tr>
</tbody>
</table>

</div>
</div>
</div>
</section>
<section id="sigreg" class="level1">
<h1>SIGReg</h1>
<p>We use the same SIGReg code as in my <a href="../../posts/sigreg-sketched-isotropic-gaussian-regularization/">prior blog post</a>. This is what pushes the embedding distribution toward an isotropic Gaussian, avoiding dimensional collapse, as you’ll see later in Figure&nbsp;1, which plots the first two embedding dimensions against each other.</p>
<div id="1c4016f1" class="cell" data-execution_count="53">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> SIGReg(x, num_slices<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>, k<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">17</span>):</span>
<span id="cb8-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># x: (N, D) samples</span></span>
<span id="cb8-3">    N, D <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x.shape</span>
<span id="cb8-4">    device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x.device</span>
<span id="cb8-5"></span>
<span id="cb8-6">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># --- Projection directions ---</span></span>
<span id="cb8-7">    A <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn(D, num_slices, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device)</span>
<span id="cb8-8">    A <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/=</span> A.norm(dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># normalize columns → unit directions</span></span>
<span id="cb8-9"></span>
<span id="cb8-10">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Project to 1D: shape → (N, num_slices)</span></span>
<span id="cb8-11">    X_proj <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> A</span>
<span id="cb8-12"></span>
<span id="cb8-13">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># --- Integration points ---</span></span>
<span id="cb8-14">    t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.linspace(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, k, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># (k,)</span></span>
<span id="cb8-15">    phi_normal <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.exp(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> t<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)          <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># (k,)</span></span>
<span id="cb8-16">    weight <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> phi_normal                          <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Gaussian window</span></span>
<span id="cb8-17"></span>
<span id="cb8-18">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Broadcast shapes: (N, M, 1) ⋅ (1, 1, k)</span></span>
<span id="cb8-19">    X_t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X_proj.unsqueeze(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> t</span>
<span id="cb8-20"></span>
<span id="cb8-21">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Empirical characteristic function across samples</span></span>
<span id="cb8-22">    ecf <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.exp(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="ot" style="color: #003B4F;
background-color: null;
font-style: inherit;">1j</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> X_t).mean(dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># (M, k)</span></span>
<span id="cb8-23"></span>
<span id="cb8-24">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Squared difference</span></span>
<span id="cb8-25">    diff_sq <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (ecf <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> phi_normal).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">abs</span>()<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># (M, k)</span></span>
<span id="cb8-26"></span>
<span id="cb8-27">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Weighted integration for all projections → shape (M,)</span></span>
<span id="cb8-28">    per_direction_T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.trapz(diff_sq <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> weight, t, dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> N</span>
<span id="cb8-29"></span>
<span id="cb8-30">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># GLOBAL aggregation — MEAN instead of MAX</span></span>
<span id="cb8-31">    T_global <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> per_direction_T.mean()</span>
<span id="cb8-32"></span>
<span id="cb8-33">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> T_global</span></code></pre></div></div>
</details>
</div>
</section>
<section id="making-and-training-the-embedding-model" class="level1">
<h1>Making and Training the Embedding Model</h1>
<p>Now we’re ready to train a model. The encoder in this case is just an Embedding module, and the predictor is just a Linear module. The objective is the MSE loss between the predicted next-word embedding and the actual next-word embedding, plus a SIGReg regularization term on the embeddings, scaled by a small coefficient.</p>
<div id="3369ceed" class="cell" data-execution_count="54">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1">device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cuda"</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> torch.cuda.is_available() <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cpu"</span></span>
<span id="cb9-2"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Using device: "</span>, device)</span>
<span id="cb9-3"></span>
<span id="cb9-4">x0 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.tensor(x0, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">long</span>).to(device)</span>
<span id="cb9-5">x1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.tensor(x1, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">long</span>).to(device)</span>
<span id="cb9-6"></span>
<span id="cb9-7">embedding_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span></span>
<span id="cb9-8">encoder <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.nn.Embedding(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(vocab), embedding_dim).to(device)</span>
<span id="cb9-9">next_encoding_predictor <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.nn.Linear(embedding_dim, embedding_dim).to(device)</span>
<span id="cb9-10">sigreg_lambda <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span></span>
<span id="cb9-11"></span>
<span id="cb9-12">n_epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span></span>
<span id="cb9-13">batch_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2048</span></span>
<span id="cb9-14">optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.optim.Adam(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(encoder.parameters()) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(next_encoding_predictor.parameters()), lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span>)</span>
<span id="cb9-15">loss_fn <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.nn.MSELoss()</span>
<span id="cb9-16"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> epoch <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(n_epochs):</span>
<span id="cb9-17">    total_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb9-18">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(x0), batch_size):</span>
<span id="cb9-19">        x0_batch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x0[i:i<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span>batch_size]</span>
<span id="cb9-20">        x1_batch <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x1[i:i<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span>batch_size]</span>
<span id="cb9-21">        </span>
<span id="cb9-22">        x0_embedded <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> encoder(x0_batch)</span>
<span id="cb9-23">        x1_embedded <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> encoder(x1_batch)</span>
<span id="cb9-24">        </span>
<span id="cb9-25">        x1_predicted <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> next_encoding_predictor(x0_embedded)</span>
<span id="cb9-26">        </span>
<span id="cb9-27">        loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> loss_fn(x1_predicted, x1_embedded)</span>
<span id="cb9-28">        sigreg_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> SIGReg(x0_embedded)</span>
<span id="cb9-29">        loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> sigreg_lambda<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>sigreg_loss</span>
<span id="cb9-30">        optimizer.zero_grad()</span>
<span id="cb9-31">        loss.backward()</span>
<span id="cb9-32">        optimizer.step()</span>
<span id="cb9-33">        </span>
<span id="cb9-34">        total_loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> loss.item()</span>
<span id="cb9-35">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> (epoch<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span> (n_epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>:</span>
<span id="cb9-36">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Epoch </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epoch<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>n_epochs<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, Loss: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>total_loss<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(x0)<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>Using device:  cuda
Epoch 1/10, Loss: 0.00048023610461530306
Epoch 2/10, Loss: 0.0004523864227085027
Epoch 3/10, Loss: 0.0004391360668865251
Epoch 4/10, Loss: 0.0004303184247303124
Epoch 5/10, Loss: 0.00042514875212748543
Epoch 6/10, Loss: 0.00042165358612022394
Epoch 7/10, Loss: 0.0004195104785958759
Epoch 8/10, Loss: 0.0004182606700457514
Epoch 9/10, Loss: 0.0004172710015018738
Epoch 10/10, Loss: 0.00041632126264852835</code></pre>
</div>
</div>
</section>
<section id="visualize" class="level1">
<h1>Visualize</h1>
<div id="cell-fig-embspace" class="cell" data-execution_count="55">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> visualize_embeddings(encoder, vocab, token_to_id, max_tokens<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">200</span>):</span>
<span id="cb11-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""</span></span>
<span id="cb11-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Visualize token embeddings.</span></span>
<span id="cb11-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    </span></span>
<span id="cb11-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    If embedding_dim == 2 → plot directly.</span></span>
<span id="cb11-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    If embedding_dim &gt; 2 → plot first 2 dimensions</span></span>
<span id="cb11-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    </span></span>
<span id="cb11-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    max_tokens limits how many tokens to plot for readability.</span></span>
<span id="cb11-9"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    """</span></span>
<span id="cb11-10"></span>
<span id="cb11-11">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># optionally limit tokens for clarity</span></span>
<span id="cb11-12">    tokens <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> vocab[:max_tokens]</span>
<span id="cb11-13">    indices <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [token_to_id[t] <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> t <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> tokens]</span>
<span id="cb11-14">    emb_subset <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> encoder(torch.tensor(indices).to(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">next</span>(encoder.parameters()).device)).detach().cpu().numpy()</span>
<span id="cb11-15">    </span>
<span id="cb11-16"></span>
<span id="cb11-17">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># plot</span></span>
<span id="cb11-18">    plt.figure(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>))</span>
<span id="cb11-19">    plt.scatter(emb_subset[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], emb_subset[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span>
<span id="cb11-20">    </span>
<span id="cb11-21">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># annotate tokens</span></span>
<span id="cb11-22">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, token <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(tokens):</span>
<span id="cb11-23">        plt.annotate(token, (emb_subset[i, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], emb_subset[i, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]), fontsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>)</span>
<span id="cb11-24">    </span>
<span id="cb11-25">    plt.title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Token Embeddings Visualization"</span>)</span>
<span id="cb11-26">    plt.xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Dim 1"</span>)</span>
<span id="cb11-27">    plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Dim 2"</span>)</span>
<span id="cb11-28">    plt.grid()</span>
<span id="cb11-29">    plt.show()</span>
<span id="cb11-30"></span>
<span id="cb11-31">visualize_embeddings(encoder, vocab, token_to_id)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div id="fig-embspace" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-embspace-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://shonczinner.github.io/posts/small-jepa-language-model/index_files/figure-html/fig-embspace-output-1.png" id="fig-embspace" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig quarto-uncaptioned" id="fig-embspace-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;1
</figcaption>
</figure>
</div>
</div>
</div>
<p>The visualization above shows only the first two dimensions of the embedding space. We can see several clusters, including character names, royal titles, and tokens that follow apostrophes in words like ne’er, ’tis, and o’er. This suggests that the JEPA model is learning informative embeddings.</p>
</section>
<section id="further-embedding-investigation" class="level1">
<h1>Further Embedding Investigation</h1>
<p>We can also observe what words are closest in embedding space.</p>
<div id="d21a8295" class="cell" data-execution_count="61">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> neighbour_table_l2(words, n<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>):</span>
<span id="cb12-2">    device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">next</span>(encoder.parameters()).device</span>
<span id="cb12-3">    encoder.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb12-4"></span>
<span id="cb12-5">    vocab_ids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.arange(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(vocab)).to(device)</span>
<span id="cb12-6"></span>
<span id="cb12-7">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb12-8">        vocab_emb <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> encoder(vocab_ids)</span>
<span id="cb12-9"></span>
<span id="cb12-10">    rows <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb12-11"></span>
<span id="cb12-12">    word_order <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {w: i <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, w <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(words)}</span>
<span id="cb12-13"></span>
<span id="cb12-14">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> word <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> words:</span>
<span id="cb12-15">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> word <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">not</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> token_to_id:</span>
<span id="cb12-16">            <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">continue</span></span>
<span id="cb12-17"></span>
<span id="cb12-18">        word_id <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> token_to_id[word]</span>
<span id="cb12-19"></span>
<span id="cb12-20">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb12-21">            query_emb <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> encoder(torch.tensor([word_id]).to(device))</span>
<span id="cb12-22"></span>
<span id="cb12-23">            x_sq <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (query_emb <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>(dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, keepdim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb12-24">            v_sq <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (vocab_emb <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>(dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>).unsqueeze(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb12-25">            cross <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.matmul(query_emb, vocab_emb.T)</span>
<span id="cb12-26"></span>
<span id="cb12-27">            distances <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (x_sq <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> v_sq <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> cross).squeeze(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb12-28"></span>
<span id="cb12-29">            distances[word_id] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"inf"</span>)</span>
<span id="cb12-30"></span>
<span id="cb12-31">            top_ids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.topk(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>distances, n).indices.tolist()</span>
<span id="cb12-32"></span>
<span id="cb12-33">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> rank, i <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(top_ids, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>):</span>
<span id="cb12-34">            rows.append({</span>
<span id="cb12-35">                <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"query"</span>: word,</span>
<span id="cb12-36">                <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"query_order"</span>: word_order[word],</span>
<span id="cb12-37">                <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"rank"</span>: rank,</span>
<span id="cb12-38">                <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"token"</span>: id_to_token[i],</span>
<span id="cb12-39">                <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"l2_distance"</span>: distances[i].item()</span>
<span id="cb12-40">            })</span>
<span id="cb12-41"></span>
<span id="cb12-42">    df <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.DataFrame(rows)</span>
<span id="cb12-43"></span>
<span id="cb12-44">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># enforce deterministic ordering for display</span></span>
<span id="cb12-45">    df <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> df.sort_values([<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"query_order"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"rank"</span>])</span>
<span id="cb12-46"></span>
<span id="cb12-47">    pivot_tokens <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (</span>
<span id="cb12-48">        df.pivot(index<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"query"</span>, columns<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"rank"</span>, values<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"token"</span>)</span>
<span id="cb12-49">        .reindex(words)</span>
<span id="cb12-50">    ).T</span>
<span id="cb12-51"></span>
<span id="cb12-52">    display(pivot_tokens)</span>
<span id="cb12-53"></span>
<span id="cb12-54">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> df</span>
<span id="cb12-55"></span>
<span id="cb12-56"></span>
<span id="cb12-57"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># ---- run ----</span></span>
<span id="cb12-58">words <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"young"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"king"</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"romeo"</span>]</span>
<span id="cb12-59">df <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> neighbour_table_l2(words, n<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>


<table class="dataframe caption-top table table-sm table-striped small" data-border="1">
<thead>
<tr class="header">
<th data-quarto-table-cell-role="th">query</th>
<th data-quarto-table-cell-role="th">young</th>
<th data-quarto-table-cell-role="th">king</th>
<th data-quarto-table-cell-role="th">romeo</th>
</tr>
<tr class="even">
<th data-quarto-table-cell-role="th">rank</th>
<th data-quarto-table-cell-role="th"></th>
<th data-quarto-table-cell-role="th"></th>
<th data-quarto-table-cell-role="th"></th>
</tr>
</thead>
<tbody>
<tr class="odd">
<th data-quarto-table-cell-role="th">1</th>
<td>old</td>
<td>friar</td>
<td>malcolm</td>
</tr>
<tr class="even">
<th data-quarto-table-cell-role="th">2</th>
<td>delicate</td>
<td>chief</td>
<td>hamlet</td>
</tr>
<tr class="odd">
<th data-quarto-table-cell-role="th">3</th>
<td>civil</td>
<td>tamora</td>
<td>wolsey</td>
</tr>
<tr class="even">
<th data-quarto-table-cell-role="th">4</th>
<td>honourable</td>
<td>taking</td>
<td>lucius</td>
</tr>
<tr class="odd">
<th data-quarto-table-cell-role="th">5</th>
<td>troubled</td>
<td>perfect</td>
<td>viola</td>
</tr>
</tbody>
</table>

</div>
</div>
</div>
<p>It’s encouraging that “king” is near other professions, “young” is near other adjectives (including its opposite, “old”), and “romeo” is near other names.</p>
</section>
<section id="future-directions" class="level1">
<h1>Future Directions</h1>
<p>As I mentioned earlier, it would be easy to extend this methodology to non-linear models (e.g.&nbsp;MLP, CNN, RNN, or Transformer) and to use larger contexts and targets than single words. It’s also possible to experiment with other hyperparameters, such as the SIGReg regularizer coefficient, the embedding dimension, and the hidden dimension, and to try larger datasets.</p>
<p>Compared to many prior methods for learning word embeddings, this approach is also relatively simple: it avoids machinery like the negative sampling used by word2vec.</p>
<p>There’s also recent work on approximations to SIGReg that are likely more computationally efficient with very little downside <span class="citation" data-cites="akbar2026">(Akbar 2026)</span>.</p>



</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-bibliography"><h2 class="anchored quarto-appendix-heading">References</h2><div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-akbar2026" class="csl-entry">
Akbar, Habibullah. 2026. <em>Weak-SIGReg: Covariance Regularization for Stable Deep Learning</em>. <a href="https://arxiv.org/abs/2603.05924">https://arxiv.org/abs/2603.05924</a>.
</div>
<div id="ref-bengio2000" class="csl-entry">
Bengio, Yoshua, Réjean Ducharme, and Pascal Vincent. 2000. <span>“A Neural Probabilistic Language Model.”</span> In <em>Advances in Neural Information Processing Systems</em>, edited by T. Leen, T. Dietterich, and V. Tresp, vol. 13. MIT Press. <a href="https://proceedings.neurips.cc/paper_files/paper/2000/file/728f206c2a01bf572b5940d7d9a8fa4c-Paper.pdf">https://proceedings.neurips.cc/paper_files/paper/2000/file/728f206c2a01bf572b5940d7d9a8fa4c-Paper.pdf</a>.
</div>
<div id="ref-dhillon2011" class="csl-entry">
Dhillon, Paramveer, Dean P Foster, and Lyle Ungar. 2011. <span>“Multi-View Learning of Word Embeddings via CCA.”</span> In <em>Advances in Neural Information Processing Systems</em>, edited by J. Shawe-Taylor, R. Zemel, P. Bartlett, F. Pereira, and K. Weinberger, vol. 24. Curran Associates, Inc. <a href="https://proceedings.neurips.cc/paper_files/paper/2011/file/6c4b761a28b734fe93831e3fb400ce87-Paper.pdf">https://proceedings.neurips.cc/paper_files/paper/2011/file/6c4b761a28b734fe93831e3fb400ce87-Paper.pdf</a>.
</div>
<div id="ref-maes2026" class="csl-entry">
Maes, Lucas, Quentin Le Lidec, Damien Scieur, Yann LeCun, and Randall Balestriero. 2026. <em>LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels</em>. <a href="https://arxiv.org/abs/2603.19312">https://arxiv.org/abs/2603.19312</a>.
</div>
<div id="ref-mikolov2013" class="csl-entry">
Mikolov, Tomas, Kai Chen, Greg Corrado, and Jeffrey Dean. 2013. <em>Efficient Estimation of Word Representations in Vector Space</em>. <a href="https://arxiv.org/abs/1301.3781">https://arxiv.org/abs/1301.3781</a>.
</div>
<div id="ref-shao2025" class="csl-entry">
Shao, Chenze, Darren Li, Fandong Meng, and Jie Zhou. 2025. <em>Continuous Autoregressive Language Models</em>. <a href="https://arxiv.org/abs/2510.27688">https://arxiv.org/abs/2510.27688</a>.
</div>
</div></section></div> ]]></description>
  <category>ai</category>
  <category>code</category>
  <category>jepa</category>
  <guid>https://shonczinner.github.io/posts/small-jepa-language-model/</guid>
  <pubDate>Thu, 30 Apr 2026 00:00:00 GMT</pubDate>
</item>
<item>
  <title>Sketched Isotropic Gaussian Regularization (SIGReg) Explained</title>
  <dc:creator>Shon Czinner</dc:creator>
  <link>https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/</link>
  <description><![CDATA[ 




<p>The paper “LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics” <span class="citation" data-cites="balestriero2025">(Balestriero and LeCun 2025)</span> proposed an interesting method of regularizing latent spaces that I thought I’d explore.</p>
<p>This is a regularization technique that pushes latent embeddings toward an isotropic Gaussian distribution: the dimensions are encouraged to be mutually independent and each Gaussian distributed. The authors argue that Gaussian latent embeddings are optimal because they yield unbiased, lower-variance estimators for downstream tasks.</p>
<p>They push latent embeddings toward a Gaussian distribution using one-dimensional tests of normality along random directions, based on characteristic functions.</p>
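<p>In outline, and as my own toy sketch rather than the paper’s code: sample random unit directions, project the batch of embeddings onto each direction, and each projection becomes a one-dimensional sample whose normality can be tested as in the sections below.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
Z = rng.standard_normal((512, 32))   # toy stand-in for a batch of 32-d embeddings

# random directions, normalized to unit length (columns of A)
A = rng.standard_normal((32, 4))
A /= np.linalg.norm(A, axis=0, keepdims=True)

# each column of proj is a 1-D sample to feed into a normality test
proj = Z @ A
assert proj.shape == (512, 4)
```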
<section id="one-dimension" class="level1">
<h1>One-Dimension</h1>
<section id="ecdf-tests" class="level2">
<h2 class="anchored" data-anchor-id="ecdf-tests">ECDF Tests</h2>
<p>The paper first motivates measuring how Gaussian a distribution is via its empirical (i.e.&nbsp;observed) cumulative distribution function (ECDF). Let’s look at this in one dimension.</p>
<p>First we generate two univariate samples: one uniform and one standard Gaussian. The target distribution here is the standard Gaussian, so we simply compare each sample’s ECDF with the theoretical Gaussian CDF.</p>
<div id="d347bc82" class="cell" data-execution_count="3">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb1-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> scipy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> stats</span>
<span id="cb1-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> matplotlib.pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb1-4"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch</span></code></pre></div></div>
</details>
</div>
<div id="cell-fig-ecdfs" class="cell" data-execution_count="4">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Sample size</span></span>
<span id="cb2-2">n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span></span>
<span id="cb2-3"></span>
<span id="cb2-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Generate samples</span></span>
<span id="cb2-5">X_uniform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.rand(n)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># uniform(-3,3)</span></span>
<span id="cb2-6">X_gaussian <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.randn(n)      <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Standard Normal</span></span>
<span id="cb2-7"></span>
<span id="cb2-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Define points for theoretical gaussian CDF</span></span>
<span id="cb2-9">x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.linspace(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000</span>)</span>
<span id="cb2-10">gaussian_cdf <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> stats.norm.cdf(x)</span>
<span id="cb2-11"></span>
<span id="cb2-12"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># ECDF function</span></span>
<span id="cb2-13"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> ecdf(data):</span>
<span id="cb2-14">    x_sorted <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.sort(data)</span>
<span id="cb2-15">    y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.arange(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(x_sorted)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(x_sorted)</span>
<span id="cb2-16">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> x_sorted, y</span>
<span id="cb2-17"></span>
<span id="cb2-18"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Compute ECDFs</span></span>
<span id="cb2-19">x_ecdf_uniform, y_ecdf_uniform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ecdf(X_uniform)</span>
<span id="cb2-20">x_ecdf_gaussian, y_ecdf_gaussian <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ecdf(X_gaussian)</span>
<span id="cb2-21"></span>
<span id="cb2-22"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Plotting using a loop but keeping samples separate</span></span>
<span id="cb2-23">plt.figure(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>))</span>
<span id="cb2-24"></span>
<span id="cb2-25">plots <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [</span>
<span id="cb2-26">    (<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Uniform Sample ECDF"</span>, x_ecdf_uniform, y_ecdf_uniform),</span>
<span id="cb2-27">    (<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Gaussian Sample ECDF"</span>, x_ecdf_gaussian, y_ecdf_gaussian)</span>
<span id="cb2-28">]</span>
<span id="cb2-29"></span>
<span id="cb2-30"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, (title, x_ecdf, y_ecdf) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(plots, start<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>):</span>
<span id="cb2-31">    plt.subplot(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, i)</span>
<span id="cb2-32">    plt.plot(x_ecdf, y_ecdf, marker<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'.'</span>, linestyle<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'none'</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">r'ECDF </span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">F_n</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">(</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">x</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">)</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb2-33">    plt.plot(x, gaussian_cdf, color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'red'</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">r'Gaussian CDF </span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">F</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">(</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">x</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">)</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb2-34">    plt.title(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>title<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> vs Gaussian CDF'</span>)</span>
<span id="cb2-35">    plt.xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'x'</span>)</span>
<span id="cb2-36">    plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'CDF'</span>)</span>
<span id="cb2-37">    plt.legend()</span>
<span id="cb2-38">    plt.grid(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb2-39"></span>
<span id="cb2-40">plt.tight_layout()</span>
<span id="cb2-41">plt.show()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div id="fig-ecdfs" class="quarto-float quarto-figure quarto-figure-center anchored">
<figure class="quarto-float quarto-float-fig figure">
<div aria-describedby="fig-ecdfs-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
<img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/fig-ecdfs-output-1.png" id="fig-ecdfs" class="img-fluid figure-img">
</div>
<figcaption class="quarto-float-caption-bottom quarto-float-caption quarto-float-fig quarto-uncaptioned" id="fig-ecdfs-caption-0ceaefa1-69ba-4598-a22c-09a6ac19f8ca">
Figure&nbsp;1
</figcaption>
</figure>
</div>
</div>
</div>
<p>As you’d expect, in Figure&nbsp;1 the ECDF of the uniform sample shows large gaps from the Gaussian CDF, while the ECDF of the Gaussian sample tracks it closely.</p>
<p>We can quantify this “gap,” i.e.&nbsp;how far the empirical CDF is from the desired theoretical CDF, with the following statistic,</p>
<p><img src="https://latex.codecogs.com/png.latex?T=n%5Cint_%7B-%5Cinfty%7D%5E%5Cinfty%20(F_n(x)-F(x))%5E2%20w(x)%20dF(x)"></p>
<p>where <img src="https://latex.codecogs.com/png.latex?F_n(x)"> is the empirical CDF, <img src="https://latex.codecogs.com/png.latex?F(x)"> is the theoretical CDF, and <img src="https://latex.codecogs.com/png.latex?w(x)"> is a weighting function. With <img src="https://latex.codecogs.com/png.latex?w(x)=1"> this is known as the Cramér–von Mises test statistic, and with <img src="https://latex.codecogs.com/png.latex?w(x)=(F(x)(1-F(x)))%5E%7B-1%7D"> it is known as the Anderson–Darling test statistic. We can compute both in scipy.</p>
<p>With some algebraic manipulation, these test statistics have closed-form formulas. However, anything built on the ECDF is still computationally expensive and non-differentiable, because it requires sorting the data. We compute the Cramér–von Mises and Anderson–Darling test statistics below.</p>
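<p>For reference, the Cramér–von Mises closed form works on the sorted sample. Here is a quick sketch (the helper name <code>cvm_statistic</code> is mine) checking it against scipy; note the <code>np.sort</code> call, which is exactly where differentiability is lost:</p>

```python
import numpy as np
from scipy import stats

def cvm_statistic(x, cdf=stats.norm.cdf):
    # closed-form Cramér–von Mises statistic against a given theoretical CDF:
    # T = 1/(12n) + sum_i (F(x_(i)) - (2i-1)/(2n))^2 over the sorted sample
    n = len(x)
    u = cdf(np.sort(x))          # probability transform of the order statistics
    i = np.arange(1, n + 1)
    return 1.0 / (12 * n) + np.sum((u - (2 * i - 1) / (2 * n)) ** 2)

rng = np.random.default_rng(0)
x = rng.standard_normal(100)
ours = cvm_statistic(x)
scipys = stats.cramervonmises(x, cdf=stats.norm.cdf).statistic
assert np.isclose(ours, scipys)
```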
<div id="5a77b522" class="cell" data-execution_count="5">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Tests</span></span>
<span id="cb3-2">cvm_1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> stats.cramervonmises(X_uniform, cdf<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>stats.norm.cdf)</span>
<span id="cb3-3">cvm_2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> stats.cramervonmises(X_gaussian, cdf<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>stats.norm.cdf)</span>
<span id="cb3-4"></span>
<span id="cb3-5">ad_1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> stats.anderson(X_uniform, dist<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'norm'</span>, method<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'interpolate'</span>)</span>
<span id="cb3-6">ad_2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> stats.anderson(X_gaussian, dist<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'norm'</span>, method<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'interpolate'</span>)</span>
<span id="cb3-7"></span>
<span id="cb3-8">alpha <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.05</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># significance level</span></span>
<span id="cb3-9"></span>
<span id="cb3-10"></span>
<span id="cb3-11"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> cvm_report(name, cvm):</span>
<span id="cb3-12">    decision <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"REJECT H0 (not Gaussian)"</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> cvm.pvalue <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> alpha <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"FAIL TO REJECT H0 (Gaussian-consistent)"</span></span>
<span id="cb3-13">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>name<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">:"</span>)</span>
<span id="cb3-14">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"  statistic (T) = </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>cvm<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>statistic<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb3-15">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"  p-value   = </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>cvm<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>pvalue<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4g}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb3-16">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"  decision  = </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>decision<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\n</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb3-17"></span>
<span id="cb3-18"></span>
<span id="cb3-19"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> ad_report(name, ad):</span>
<span id="cb3-20">    decision <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"REJECT H0 (not Gaussian)"</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> ad.pvalue <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> alpha <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"FAIL TO REJECT H0 (Gaussian-consistent)"</span></span>
<span id="cb3-21">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>name<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">:"</span>)</span>
<span id="cb3-22">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"  statistic (T) = </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>ad<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>statistic<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb3-23">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"  p-value   = </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>ad<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>pvalue<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4g}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb3-24">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"  decision  = </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>decision<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\n</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb3-25"></span>
<span id="cb3-26"></span>
<span id="cb3-27"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Cramér–von Mises Test Results"</span>)</span>
<span id="cb3-28"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"-----------------------------"</span>)</span>
<span id="cb3-29">cvm_report(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Uniform sample"</span>, cvm_1)</span>
<span id="cb3-30">cvm_report(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Gaussian sample"</span>, cvm_2)</span>
<span id="cb3-31"></span>
<span id="cb3-32"></span>
<span id="cb3-33"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Anderson–Darling Test Results"</span>)</span>
<span id="cb3-34"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"-----------------------------"</span>)</span>
<span id="cb3-35">ad_report(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Uniform sample"</span>, ad_1)</span>
<span id="cb3-36">ad_report(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Gaussian sample"</span>, ad_2)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>Cramér–von Mises Test Results
-----------------------------
Uniform sample:
  statistic (T) = 1.7954
  p-value   = 3.264e-05
  decision  = REJECT H0 (not Gaussian)

Gaussian sample:
  statistic (T) = 0.3512
  p-value   = 0.09743
  decision  = FAIL TO REJECT H0 (Gaussian-consistent)

Anderson–Darling Test Results
-----------------------------
Uniform sample:
  statistic (T) = 1.6875
  p-value   = 0.01
  decision  = REJECT H0 (not Gaussian)

Gaussian sample:
  statistic (T) = 0.2015
  p-value   = 0.15
  decision  = FAIL TO REJECT H0 (Gaussian-consistent)
</code></pre>
</div>
</div>
<p>As expected, the Gaussian hypothesis is rejected for the uniform sample and not rejected for the Gaussian sample.</p>
<p>Note: scipy.stats.anderson with dist set to “norm” tests for normality, not standard normality: it standardizes the data by subtracting the sample mean and dividing by the sample standard deviation before computing the statistic.</p>
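<p>A quick way to see this internal standardization (my own check, not from the paper): the Anderson–Darling statistic is unchanged under any affine transformation of the data.</p>

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
x = rng.standard_normal(200)

a_raw = stats.anderson(x, dist='norm').statistic
a_affine = stats.anderson(5.0 + 3.0 * x, dist='norm').statistic  # shifted and scaled

# identical, because the test standardizes by the sample mean and std internally
assert np.isclose(a_raw, a_affine)
```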
</section>
<section id="characteristic-functions" class="level2">
<h2 class="anchored" data-anchor-id="characteristic-functions">Characteristic Functions</h2>
<p>A distribution is uniquely determined not only by its CDF but also by its characteristic function. The characteristic function (CF) of a random variable <img src="https://latex.codecogs.com/png.latex?X"> is defined as,</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Cvarphi_X(t)=E(e%5E%7BitX%7D)"></p>
<p>where <img src="https://latex.codecogs.com/png.latex?i"> is the imaginary unit. For a sample, we have the empirical characteristic function (ECF),</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Chat%5Cvarphi_n(t)=%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7Bj=1%7D%5En%20e%5E%7Bitx_j%7D"></p>
<p>Reminiscent of how the CDF gap was defined, for CFs we can quantify the gap using,</p>
<p><img src="https://latex.codecogs.com/png.latex?T=n%5Cint_%7B-%5Cinfty%7D%5E%5Cinfty%20%7C%5Chat%5Cvarphi_n(t)-%5Cvarphi(t)%7C%5E2%20w(t)%20dt"></p>
<p>which is known as the Epps–Pulley test statistic. The weight function is typically <img src="https://latex.codecogs.com/png.latex?w(t)=e%5E%7B-t%5E2/2%7D">, which is also the theoretical CF of the standard Gaussian distribution.</p>
<p>An advantage of the Epps–Pulley statistic is that it is continuous and differentiable in the data: moving one data point by a small amount changes the statistic proportionally. In contrast, moving a data point shifts the jump locations of the ECDF, which makes ECDF-based statistics non-differentiable.</p>
<p>This integral can be approximated with the basic trapezoidal rule on a finite grid.</p>
<div id="c6f54dbd" class="cell" data-execution_count="6">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> eps_pulley(X, k, plot<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>, title<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">""</span>):</span>
<span id="cb5-2">    t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.linspace(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, k)</span>
<span id="cb5-3"></span>
<span id="cb5-4">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># theoretical characteristic function of N(0,1)</span></span>
<span id="cb5-5">    phi_gauss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.exp(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> t<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb5-6"></span>
<span id="cb5-7">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># empirical characteristic function</span></span>
<span id="cb5-8">    X_t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.outer(X, t)</span>
<span id="cb5-9">    exp_itx <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.exp(<span class="ot" style="color: #003B4F;
background-color: null;
font-style: inherit;">1j</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> X_t)   <span class="co" style="color: #5E5E5E;
background-color: null;
# switched">
font-style: inherit;"># e^{itx} terms, one row per data point</span></span>
<span id="cb5-10">    phi_emp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.mean(exp_itx, axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb5-11"></span>
<span id="cb5-12">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># integrated squared error</span></span>
<span id="cb5-13">    err <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">abs</span>(phi_emp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> phi_gauss)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span></span>
<span id="cb5-14">    T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.trapezoid(err, t)</span>
<span id="cb5-15"></span>
<span id="cb5-16">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> plot:</span>
<span id="cb5-17">        plt.figure(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>))</span>
<span id="cb5-18"></span>
<span id="cb5-19">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># plot individual exp(itx) curves (real parts)</span></span>
<span id="cb5-20">        <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(X)):</span>
<span id="cb5-21">            plt.plot(</span>
<span id="cb5-22">                t,</span>
<span id="cb5-23">                exp_itx[i].real,</span>
<span id="cb5-24">                color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"lightgrey"</span>,</span>
<span id="cb5-25">                alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.3</span>,</span>
<span id="cb5-26">                linewidth<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.8</span></span>
<span id="cb5-27">            )</span>
<span id="cb5-28"></span>
<span id="cb5-29">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># empirical mean</span></span>
<span id="cb5-30">        plt.plot(</span>
<span id="cb5-31">            t,</span>
<span id="cb5-32">            phi_emp.real,</span>
<span id="cb5-33">            color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"dimgray"</span>,</span>
<span id="cb5-34">            linewidth<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>,</span>
<span id="cb5-35">            label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">r"Empirical mean </span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">Re</span><span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">[E[e^{itX}]</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">]=Re</span><span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">[</span><span class="er" style="color: #AD0000;
background-color: null;
font-style: inherit;">\</span><span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">hat</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\v</span><span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">arphi_n(t)]</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span></span>
<span id="cb5-36">        )</span>
<span id="cb5-37"></span>
<span id="cb5-38">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># theoretical Gaussian CF</span></span>
<span id="cb5-39">        plt.plot(</span>
<span id="cb5-40">            t,</span>
<span id="cb5-41">            phi_gauss,</span>
<span id="cb5-42">            color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"black"</span>,</span>
<span id="cb5-43">            linestyle<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"--"</span>,</span>
<span id="cb5-44">            linewidth<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>,</span>
<span id="cb5-45">            label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">r"Gaussian CF </span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">e</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">^</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">{-t</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">^</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">2/2}=</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\v</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">arphi</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">(</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">t</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">)</span><span class="er" style="color: #AD0000;
background-color: null;
font-style: inherit;">)</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span></span>
<span id="cb5-46">        )</span>
<span id="cb5-47"></span>
<span id="cb5-48">        plt.xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"t"</span>)</span>
<span id="cb5-49">        plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Real part"</span>)</span>
<span id="cb5-50">        plt.title(title)</span>
<span id="cb5-51">        plt.legend()</span>
<span id="cb5-52">        plt.grid(alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.3</span>)</span>
<span id="cb5-53">        plt.show()</span>
<span id="cb5-54"></span>
<span id="cb5-55">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> T</span>
<span id="cb5-56"></span>
<span id="cb5-57"></span>
<span id="cb5-58">k <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000</span></span>
<span id="cb5-59"></span>
<span id="cb5-60">T_gaussian <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eps_pulley(</span>
<span id="cb5-61">    X_gaussian, k,</span>
<span id="cb5-62">    plot<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb5-63">    title<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Gaussian Sample"</span></span>
<span id="cb5-64">)</span>
<span id="cb5-65"></span>
<span id="cb5-66">T_uniform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> eps_pulley(</span>
<span id="cb5-67">    X_uniform, k,</span>
<span id="cb5-68">    plot<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb5-69">    title<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Uniform Sample"</span></span>
<span id="cb5-70">)</span>
<span id="cb5-71"></span>
<span id="cb5-72"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Epps–Pulley Test Results"</span>)</span>
<span id="cb5-73"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"-------------------------"</span>)</span>
<span id="cb5-74"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Gaussian sample: T = </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>T_gaussian<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.6f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb5-75"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Uniform sample:  T = </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>T_uniform<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.6f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb5-76"></span>
<span id="cb5-77"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> T_gaussian <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> T_uniform:</span>
<span id="cb5-78">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\n</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">Conclusion: Gaussian sample is closer to the standard normal reference distribution."</span>)</span>
<span id="cb5-79"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span>:</span>
<span id="cb5-80">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\n</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">Conclusion: Uniform sample appears closer to the standard normal reference distribution."</span>)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-5-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-5-output-2.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-stdout">
<pre><code>Epps–Pulley Test Results
-------------------------
Gaussian sample: T = 0.064450
Uniform sample:  T = 0.924244

Conclusion: Gaussian sample is closer to the standard normal reference distribution.</code></pre>
</div>
</div>
<p>The lower T statistic for the Gaussian sample indicates a smaller “gap”. I haven’t looked into the limiting distribution of the Epps–Pulley test statistic, which would be needed to compute p-values.</p>
<p>We can also check robustness to the number of grid points in the trapezoidal approximation, and we see that the statistic converges quite quickly.</p>
<div id="58c3509a" class="cell" data-execution_count="7">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1">ks <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [x<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> x <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)]</span>
<span id="cb7-2">T_gaussian <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [eps_pulley(X_gaussian, k) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> k <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> ks]</span>
<span id="cb7-3">T_uniform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [eps_pulley(X_uniform, k) <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> k <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> ks]</span>
<span id="cb7-4"></span>
<span id="cb7-5">plots <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [</span>
<span id="cb7-6">    (<span class="st" style="color: #20794D;
background-color: null;
&quot;uniform">
font-style: inherit;">"Uniform Sample"</span>, T_uniform),</span>
<span id="cb7-7">    (<span class="st" style="color: #20794D;
background-color: null;
&quot;gaussian">
font-style: inherit;">"Gaussian Sample"</span>, T_gaussian)</span>
<span id="cb7-8">]</span>
<span id="cb7-9"></span>
<span id="cb7-10"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> i, (title, T) <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(plots, start<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>):</span>
<span id="cb7-11">    plt.plot(ks, T, marker<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'.'</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
&#39;ECDF&#39;">
font-style: inherit;">'T'</span>)</span>
<span id="cb7-12">    plt.title(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>title<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> Epps-Pulley Test'</span>)</span>
<span id="cb7-13">    plt.xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'k'</span>)</span>
<span id="cb7-14">    plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'T'</span>)</span>
<span id="cb7-15">    plt.legend()</span>
<span id="cb7-16">    plt.grid(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb7-17">    plt.show()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-6-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-6-output-2.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
</section>
<section id="gradient-descent" class="level2">
<h2 class="anchored" data-anchor-id="gradient-descent">Gradient Descent</h2>
<p>Now that we know how to quantify the gap between data and a target distribution with Epps–Pulley, let’s make a neural network output data that fits the desired target distribution. First, let’s send 1D data through a random, untrained neural network and check its output distribution.</p>
<div id="dc796af9" class="cell" data-execution_count="8">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> plot_output_distribution(model, X, bins<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span>):</span>
<span id="cb8-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""</span></span>
<span id="cb8-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Plot model output histogram against standard normal PDF.</span></span>
<span id="cb8-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    </span></span>
<span id="cb8-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Args:</span></span>
<span id="cb8-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        model: PyTorch model</span></span>
<span id="cb8-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        X: input tensor</span></span>
<span id="cb8-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        bins: histogram bins</span></span>
<span id="cb8-9"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    """</span></span>
<span id="cb8-10">    model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb8-11"></span>
<span id="cb8-12">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb8-13">        Y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(X).squeeze().cpu().numpy()</span>
<span id="cb8-14"></span>
<span id="cb8-15">    plt.figure(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>))</span>
<span id="cb8-16"></span>
<span id="cb8-17">    plt.hist(</span>
<span id="cb8-18">        Y,</span>
<span id="cb8-19">        bins<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>bins,</span>
<span id="cb8-20">        density<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb8-21">        alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.6</span>,</span>
<span id="cb8-22">        label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Network outputs"</span></span>
<span id="cb8-23">    )</span>
<span id="cb8-24"></span>
<span id="cb8-25">    x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.linspace(Y.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, Y.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">500</span>)</span>
<span id="cb8-26"></span>
<span id="cb8-27">    plt.plot(</span>
<span id="cb8-28">        x,</span>
<span id="cb8-29">        stats.norm.pdf(x, loc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, scale<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>),</span>
<span id="cb8-30">        <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'r--'</span>,</span>
<span id="cb8-31">        linewidth<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>,</span>
<span id="cb8-32">        label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">r'</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$\m</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">athcal{N}</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">(</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">0,1</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">)</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;"> PDF'</span></span>
<span id="cb8-33">    )</span>
<span id="cb8-34"></span>
<span id="cb8-35">    plt.xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Output value"</span>)</span>
<span id="cb8-36">    plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Density"</span>)</span>
<span id="cb8-37">    plt.title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Distribution of Network Outputs"</span>)</span>
<span id="cb8-38">    plt.legend()</span>
<span id="cb8-39">    plt.grid(alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.3</span>)</span>
<span id="cb8-40">    plt.show()</span>
<span id="cb8-41"></span>
<span id="cb8-42"></span>
<span id="cb8-43"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># -----------------------</span></span>
<span id="cb8-44"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Example setup</span></span>
<span id="cb8-45"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># -----------------------</span></span>
<span id="cb8-46">n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000</span></span>
<span id="cb8-47">hidden <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span></span>
<span id="cb8-48"></span>
<span id="cb8-49">X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.rand(n) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span></span>
<span id="cb8-50">X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X.unsqueeze(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb8-51"></span>
<span id="cb8-52">X_train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X[:n<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>]</span>
<span id="cb8-53">X_test <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X[n<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>:]</span>
<span id="cb8-54"></span>
<span id="cb8-55">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.nn.Sequential(</span>
<span id="cb8-56">    torch.nn.Linear(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, hidden),</span>
<span id="cb8-57">    torch.nn.ReLU(),</span>
<span id="cb8-58">    torch.nn.Linear(hidden, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb8-59">)</span>
<span id="cb8-60"></span>
<span id="cb8-61"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Before training</span></span>
<span id="cb8-62">plot_output_distribution(model, X_test)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-7-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>We can see the output isn’t very Gaussian. Now let’s train the model using the Epps–Pulley regularizer:</p>
<div id="2528c5f8" class="cell" data-execution_count="9">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> epps_pulley_1d(X, k):</span>
<span id="cb9-2">    t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.linspace(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, k, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>X.device)</span>
<span id="cb9-3"></span>
<span id="cb9-4">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># theoretical CF of N(0,1)</span></span>
<span id="cb9-5">    phi_gaussian <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.exp(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> t<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb9-6"></span>
<span id="cb9-7">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># empirical CF</span></span>
<span id="cb9-8">    X_t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X.unsqueeze(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> t</span>
<span id="cb9-9">    phi_emp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.mean(torch.exp(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="ot" style="color: #003B4F;
background-color: null;
font-style: inherit;">1j</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> X_t), dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb9-10"></span>
<span id="cb9-11">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># weighted squared error</span></span>
<span id="cb9-12">    err <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> phi_gaussian <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">abs</span>(phi_emp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> phi_gaussian)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span></span>
<span id="cb9-13"></span>
<span id="cb9-14">    T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.trapz(err, t)</span>
<span id="cb9-15"></span>
<span id="cb9-16">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> T</span>
<span id="cb9-17"></span>
<span id="cb9-18">epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2000</span></span>
<span id="cb9-19"></span>
<span id="cb9-20">optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.optim.Adam(model.parameters(), lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span>)</span>
<span id="cb9-21">loss_fn <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> x: epps_pulley_1d(x, k<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">17</span>)</span>
<span id="cb9-22"></span>
<span id="cb9-23"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> epoch <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(epochs):</span>
<span id="cb9-24">    optimizer.zero_grad()</span>
<span id="cb9-25">    loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> loss_fn(model(X_train))</span>
<span id="cb9-26">    loss.backward()</span>
<span id="cb9-27">    optimizer.step()</span>
<span id="cb9-28">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> (epoch<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span> (epochs<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>:</span>
<span id="cb9-29">      <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Epoch </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epoch<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epochs<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, Loss: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>item()<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb9-30"></span>
<span id="cb9-31">plot_output_distribution(model, X_test)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>Epoch 200/2000, Loss: 0.0006097470759414136
Epoch 400/2000, Loss: 0.0005753693985752761
Epoch 600/2000, Loss: 0.0005708417156711221
Epoch 800/2000, Loss: 0.0005661727045662701
Epoch 1000/2000, Loss: 0.0005581520381383598
Epoch 1200/2000, Loss: 0.0005494060460478067
Epoch 1400/2000, Loss: 0.0005422477843239903
Epoch 1600/2000, Loss: 0.0005335460300557315
Epoch 1800/2000, Loss: 0.0005124190938659012
Epoch 2000/2000, Loss: 0.0004806756041944027</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-8-output-2.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>Now we can see the trained model’s output is more Gaussian.</p>
</section>
</section>
<section id="multiple-dimensions" class="level1">
<h1>Multiple Dimensions</h1>
<section id="multidimensional-gaussian-regularization" class="level2">
<h2 class="anchored" data-anchor-id="multidimensional-gaussian-regularization">Multidimensional Gaussian Regularization</h2>
<p>It’s not immediately clear how to apply this in higher dimensions without running into the curse of dimensionality. Naively approximating the CF integral on a grid would require a number of evaluation points that scales exponentially with dimension. In the 1D example we used <img src="https://latex.codecogs.com/png.latex?17"> points; we want to avoid needing <img src="https://latex.codecogs.com/png.latex?17%5E2"> points in 2D, <img src="https://latex.codecogs.com/png.latex?17%5E3"> in 3D, and so on.</p>
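<p>To make that scaling concrete (a one-off illustration, not from the paper):</p>

```python
# Points needed for a tensor-product grid with 17 points per axis
for d in range(1, 6):
    print(f"{d}D: {17**d:,} points")
# 1D: 17, 2D: 289, 3D: 4,913, 4D: 83,521, 5D: 1,419,857
```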
<p>The paper’s idea is instead to sample directions in the <img src="https://latex.codecogs.com/png.latex?n">-dimensional space, project the <img src="https://latex.codecogs.com/png.latex?n">-dimensional points onto these directions, and then compute the corresponding univariate test statistics.</p>
<p>For a multivariate Gaussian random variable</p>
<p><img src="https://latex.codecogs.com/png.latex?%0AX%20%5Csim%20%5Cmathcal%7BN%7D(0,%20I)%0A"></p>
<p>any linear projection <img src="https://latex.codecogs.com/png.latex?u%5ET%20X"> is also Gaussian. And when <img src="https://latex.codecogs.com/png.latex?u"> is a unit vector <img src="https://latex.codecogs.com/png.latex?%5C%7Cu%5C%7C=1">, the variance remains unchanged:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Ctext%7BVar%7D(u%5ET%20X)=u%5ET%20I%20u%20=%201%0A"></p>
<p>This means every projection of a standard isotropic Gaussian looks like:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0Au%5ET%20X%20%5Csim%20%5Cmathcal%7BN%7D(0,1)%0A"></p>
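<p>A quick numerical check of this claim (a sketch; the dimension and sample size here are arbitrary):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((100_000, 5))  # samples from N(0, I_5)

u = rng.standard_normal(5)
u = u / np.linalg.norm(u)              # random unit direction

p = X @ u                              # one-dimensional projection
print(round(p.mean(), 3), round(p.var(), 3))  # close to 0 and 1
```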
<p>The converse is what makes this useful: by the Cramér–Wold theorem, if <strong>every</strong> one-dimensional projection of a distribution is standard Gaussian, then the full multivariate distribution must itself be a standard (isotropic) multivariate Gaussian.</p>
<p>This gives us a scalable approximation strategy: instead of testing Gaussianity over the full <img src="https://latex.codecogs.com/png.latex?n">-dimensional space, we test many random one-dimensional projections.</p>
<p>That’s the core idea behind SIGReg: approximate high-dimensional Gaussian regularization using randomized projections rather than exponentially expensive multidimensional integration.</p>
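<p>That strategy can be sketched in NumPy as follows. This is a minimal illustration under my own naming, not the paper’s API: <code>sliced_epps_pulley</code> and its parameters (<code>n_directions</code>, <code>k</code>) are hypothetical, and the integral is approximated with a simple rectangle rule.</p>

```python
import numpy as np

def epps_pulley_1d(x, k=17):
    # Weighted squared distance between the empirical characteristic
    # function of the 1D sample x and the CF of N(0, 1)
    t = np.linspace(-5.0, 5.0, k)
    phi_gauss = np.exp(-0.5 * t**2)
    phi_emp = np.mean(np.exp(1j * x[:, None] * t), axis=0)
    err = phi_gauss * np.abs(phi_emp - phi_gauss) ** 2
    return np.sum(err) * (t[1] - t[0])  # rectangle-rule integral

def sliced_epps_pulley(X, n_directions=32, k=17, rng=None):
    # Average the 1D statistic over random unit directions ("slices")
    rng = np.random.default_rng(rng)
    d = X.shape[1]
    total = 0.0
    for _ in range(n_directions):
        u = rng.standard_normal(d)
        u = u / np.linalg.norm(u)  # unit vector: projection keeps variance 1
        total += epps_pulley_1d(X @ u, k=k)
    return total / n_directions
```

<p>On unit-variance data, the statistic should be near zero for isotropic Gaussian samples and noticeably larger for, say, uniform samples, mirroring the 1D behavior above.</p>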
<p>Here I’ll demonstrate this in two dimensions. First, we’ll inspect scatter plots of 2D data, then choose a random direction on the 2D unit circle, project the points onto that direction, and examine the resulting histograms. The key observation is that these projections reduce the problem back to one dimension, where we can run the same Epps–Pulley test as before.</p>
<div id="18f370b5" class="cell" data-execution_count="10">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1">n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">400</span></span>
<span id="cb11-2"></span>
<span id="cb11-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Generate samples</span></span>
<span id="cb11-4">X_uniform_2d <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.rand(n, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span></span>
<span id="cb11-5">X_gaussian_2d <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.randn(n, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb11-6"></span>
<span id="cb11-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Random unit direction</span></span>
<span id="cb11-8">v <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.randn(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb11-9">v <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> v <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> np.linalg.norm(v)</span>
<span id="cb11-10"></span>
<span id="cb11-11"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Projections</span></span>
<span id="cb11-12">proj_uniform <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X_uniform_2d <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> v</span>
<span id="cb11-13">proj_gaussian <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X_gaussian_2d <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> v</span>
<span id="cb11-14"></span>
<span id="cb11-15"></span>
<span id="cb11-16"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># ----------------------------</span></span>
<span id="cb11-17"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># 1. Scatter plot with direction</span></span>
<span id="cb11-18"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># ----------------------------</span></span>
<span id="cb11-19">plt.figure(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>))</span>
<span id="cb11-20"></span>
<span id="cb11-21">plt.scatter(</span>
<span id="cb11-22">    X_uniform_2d[:,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>],</span>
<span id="cb11-23">    X_uniform_2d[:,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>],</span>
<span id="cb11-24">    alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.3</span>,</span>
<span id="cb11-25">    label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Uniform"</span></span>
<span id="cb11-26">)</span>
<span id="cb11-27"></span>
<span id="cb11-28">plt.scatter(</span>
<span id="cb11-29">    X_gaussian_2d[:,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>],</span>
<span id="cb11-30">    X_gaussian_2d[:,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>],</span>
<span id="cb11-31">    alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.3</span>,</span>
<span id="cb11-32">    label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Gaussian"</span></span>
<span id="cb11-33">)</span>
<span id="cb11-34"></span>
<span id="cb11-35">plt.arrow(</span>
<span id="cb11-36">    <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>,</span>
<span id="cb11-37">    <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>v[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>v[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>],</span>
<span id="cb11-38">    width<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.05</span>,</span>
<span id="cb11-39">    color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"black"</span>,</span>
<span id="cb11-40">    length_includes_head<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span></span>
<span id="cb11-41">)</span>
<span id="cb11-42"></span>
<span id="cb11-43">plt.text(</span>
<span id="cb11-44">    <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>v[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>],</span>
<span id="cb11-45">    <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>v[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>],</span>
<span id="cb11-46">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"projection direction"</span>,</span>
<span id="cb11-47">    fontsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span></span>
<span id="cb11-48">)</span>
<span id="cb11-49"></span>
<span id="cb11-50">plt.axhline(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'gray'</span>, lw<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>)</span>
<span id="cb11-51">plt.axvline(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'gray'</span>, lw<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>)</span>
<span id="cb11-52">plt.legend()</span>
<span id="cb11-53">plt.title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"2D data with random projection direction"</span>)</span>
<span id="cb11-54">plt.axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"equal"</span>)</span>
<span id="cb11-55">plt.show()</span>
<span id="cb11-56"></span>
<span id="cb11-57"></span>
<span id="cb11-58"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># ----------------------------</span></span>
<span id="cb11-59"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># 2. Projection geometry comparison</span></span>
<span id="cb11-60"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># ----------------------------</span></span>
<span id="cb11-61"></span>
<span id="cb11-62">datasets <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [</span>
<span id="cb11-63">    (<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Uniform Projection Geometry"</span>, X_uniform_2d),</span>
<span id="cb11-64">    (<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Gaussian Projection Geometry"</span>, X_gaussian_2d)</span>
<span id="cb11-65">]</span>
<span id="cb11-66"></span>
<span id="cb11-67"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> title, X_subset <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> datasets:</span>
<span id="cb11-68">    plt.figure(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>))</span>
<span id="cb11-69">    plt.scatter(X_subset[:,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X_subset[:,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span>
<span id="cb11-70"></span>
<span id="cb11-71">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># projection line</span></span>
<span id="cb11-72">    t_line <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.linspace(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>)</span>
<span id="cb11-73">    plt.plot(</span>
<span id="cb11-74">        t_line<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>v[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>],</span>
<span id="cb11-75">        t_line<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>v[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>],</span>
<span id="cb11-76">        <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'k--'</span>,</span>
<span id="cb11-77">        label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'projection line'</span></span>
<span id="cb11-78">    )</span>
<span id="cb11-79"></span>
<span id="cb11-80">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># projection segments</span></span>
<span id="cb11-81">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> x <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> X_subset:</span>
<span id="cb11-82">        proj <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> v) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> v</span>
<span id="cb11-83">        plt.plot(</span>
<span id="cb11-84">            [x[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], proj[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]],</span>
<span id="cb11-85">            [x[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], proj[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]],</span>
<span id="cb11-86">            <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'r-'</span>,</span>
<span id="cb11-87">            alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.4</span></span>
<span id="cb11-88">        )</span>
<span id="cb11-89"></span>
<span id="cb11-90">    plt.legend()</span>
<span id="cb11-91">    plt.title(title)</span>
<span id="cb11-92">    plt.axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"equal"</span>)</span>
<span id="cb11-93">    plt.show()</span>
<span id="cb11-94"></span>
<span id="cb11-95"></span>
<span id="cb11-96"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># ----------------------------</span></span>
<span id="cb11-97"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># 3. Histogram comparison</span></span>
<span id="cb11-98"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># ----------------------------</span></span>
<span id="cb11-99">x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.linspace(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">400</span>)</span>
<span id="cb11-100"></span>
<span id="cb11-101">plt.figure(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>))</span>
<span id="cb11-102"></span>
<span id="cb11-103">plt.hist(</span>
<span id="cb11-104">    proj_uniform,</span>
<span id="cb11-105">    bins<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">30</span>,</span>
<span id="cb11-106">    density<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb11-107">    alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.6</span>,</span>
<span id="cb11-108">    label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Uniform projected"</span></span>
<span id="cb11-109">)</span>
<span id="cb11-110"></span>
<span id="cb11-111">plt.hist(</span>
<span id="cb11-112">    proj_gaussian,</span>
<span id="cb11-113">    bins<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">30</span>,</span>
<span id="cb11-114">    density<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb11-115">    alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.6</span>,</span>
<span id="cb11-116">    label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Gaussian projected"</span></span>
<span id="cb11-117">)</span>
<span id="cb11-118"></span>
<span id="cb11-119">plt.plot(</span>
<span id="cb11-120">    x,</span>
<span id="cb11-121">    stats.norm.pdf(x),</span>
<span id="cb11-122">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'k--'</span>,</span>
<span id="cb11-123">    lw<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>,</span>
<span id="cb11-124">    label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'N(0,1) density'</span></span>
<span id="cb11-125">)</span>
<span id="cb11-126"></span>
<span id="cb11-127">plt.legend()</span>
<span id="cb11-128">plt.title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"1D projections along a random direction"</span>)</span>
<span id="cb11-129">plt.show()</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-9-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-9-output-2.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-9-output-3.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-9-output-4.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
</section>
<section id="sigreg" class="level2">
<h2 class="anchored" data-anchor-id="sigreg">SIGReg</h2>
<p>SIGReg functions the same as above, except that it averages the test statistic over many random slices (<code>num_slices</code> in the code below). See the code block below, which is similar to the code in the paper.</p>
<div id="d721a524" class="cell" data-execution_count="11">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> SIGReg(x, num_slices<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>, k<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">17</span>):</span>
<span id="cb12-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># x: (N, D) samples</span></span>
<span id="cb12-3">    N, D <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x.shape</span>
<span id="cb12-4">    device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x.device</span>
<span id="cb12-5"></span>
<span id="cb12-6">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># --- Projection directions ---</span></span>
<span id="cb12-7">    A <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn(D, num_slices, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device)</span>
<span id="cb12-8">    A <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/=</span> A.norm(dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># normalize columns → unit directions</span></span>
<span id="cb12-9"></span>
<span id="cb12-10">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Project to 1D: shape → (N, num_slices)</span></span>
<span id="cb12-11">    X_proj <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> A</span>
<span id="cb12-12"></span>
<span id="cb12-13">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># --- Integration points ---</span></span>
<span id="cb12-14">    t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.linspace(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, k, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># (k,)</span></span>
<span id="cb12-15">    phi_normal <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.exp(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> t<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)          <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># (k,)</span></span>
<span id="cb12-16">    weight <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> phi_normal                          <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Gaussian window</span></span>
<span id="cb12-17"></span>
<span id="cb12-18">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Broadcast shapes: (N, M, 1) ⋅ (1, 1, k)</span></span>
<span id="cb12-19">    X_t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X_proj.unsqueeze(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> t</span>
<span id="cb12-20"></span>
<span id="cb12-21">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Empirical characteristic function across samples</span></span>
<span id="cb12-22">    ecf <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.exp(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="ot" style="color: #003B4F;
background-color: null;
font-style: inherit;">1j</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> X_t).mean(dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># (M, k)</span></span>
<span id="cb12-23"></span>
<span id="cb12-24">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Squared difference</span></span>
<span id="cb12-25">    diff_sq <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (ecf <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> phi_normal).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">abs</span>()<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># (M, k)</span></span>
<span id="cb12-26"></span>
<span id="cb12-27">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Weighted integration for all projections → shape (M,)</span></span>
<span id="cb12-28">    per_direction_T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.trapz(diff_sq <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> weight, t, dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> N</span>
<span id="cb12-29"></span>
<span id="cb12-30">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># GLOBAL aggregation — MEAN instead of MAX</span></span>
<span id="cb12-31">    T_global <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> per_direction_T.mean()</span>
<span id="cb12-32"></span>
<span id="cb12-33">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">return</span> T_global</span></code></pre></div></div>
</details>
</div>
<p>Now let’s make a network’s 10-dimensional output standard Gaussian. We’ll plot a histogram along the first dimension, before and after training with the regularizer.</p>
<div id="ad4eec8a" class="cell" data-execution_count="12">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1">device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.device(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cuda"</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> torch.cuda.is_available() <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">else</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cpu"</span>)</span>
<span id="cb13-2"></span>
<span id="cb13-3">n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000</span></span>
<span id="cb13-4">hidden <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span></span>
<span id="cb13-5">D <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span></span>
<span id="cb13-6">X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.rand((n,D))<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>   <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># uniform(-5, 5)</span></span>
<span id="cb13-7"></span>
<span id="cb13-8">X_train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X[:n<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>].to(device)</span>
<span id="cb13-9">X_test <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X[n<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>:].to(device)</span>
<span id="cb13-10"></span>
<span id="cb13-11">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.nn.Sequential(</span>
<span id="cb13-12">    torch.nn.Linear(D, hidden),</span>
<span id="cb13-13">    torch.nn.ReLU(),</span>
<span id="cb13-14">    torch.nn.Linear(hidden, D)</span>
<span id="cb13-15">).to(device)</span>
<span id="cb13-16"></span>
<span id="cb13-17">Y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(X_test)</span>
<span id="cb13-18"></span>
<span id="cb13-19"><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">def</span> plot_output_distribution(model, X, bins<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span>):</span>
<span id="cb13-20">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""</span></span>
<span id="cb13-21"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Plot model output histogram against standard normal PDF.</span></span>
<span id="cb13-22"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    </span></span>
<span id="cb13-23"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Args:</span></span>
<span id="cb13-24"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        model: PyTorch model</span></span>
<span id="cb13-25"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        X: input tensor</span></span>
<span id="cb13-26"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        bins: histogram bins</span></span>
<span id="cb13-27"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    """</span></span>
<span id="cb13-28">    model.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">eval</span>()</span>
<span id="cb13-29"></span>
<span id="cb13-30">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb13-31">        Y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(X).squeeze().cpu().numpy()[:,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]</span>
<span id="cb13-32"></span>
<span id="cb13-33">    plt.figure(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>,<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>))</span>
<span id="cb13-34"></span>
<span id="cb13-35">    plt.hist(</span>
<span id="cb13-36">        Y,</span>
<span id="cb13-37">        bins<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>bins,</span>
<span id="cb13-38">        density<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>,</span>
<span id="cb13-39">        alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.6</span>,</span>
<span id="cb13-40">        label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Network outputs"</span></span>
<span id="cb13-41">    )</span>
<span id="cb13-42"></span>
<span id="cb13-43">    x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.linspace(Y.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, Y.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">500</span>)</span>
<span id="cb13-44"></span>
<span id="cb13-45">    plt.plot(</span>
<span id="cb13-46">        x,</span>
<span id="cb13-47">        stats.norm.pdf(x, loc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, scale<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>),</span>
<span id="cb13-48">        <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'r--'</span>,</span>
<span id="cb13-49">        linewidth<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>,</span>
<span id="cb13-50">        label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">r'</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$\m</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">athcal{N}</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">(</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">0,1</span><span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">)</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">$</span><span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;"> PDF'</span></span>
<span id="cb13-51">    )</span>
<span id="cb13-52"></span>
<span id="cb13-53">    plt.xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Output value"</span>)</span>
<span id="cb13-54">    plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Density"</span>)</span>
<span id="cb13-55">    plt.title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Distribution of Network Outputs"</span>)</span>
<span id="cb13-56">    plt.legend()</span>
<span id="cb13-57">    plt.grid(alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.3</span>)</span>
<span id="cb13-58">    plt.show()</span>
<span id="cb13-59"></span>
<span id="cb13-60"></span>
<span id="cb13-61">plot_output_distribution(model, X_test)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-11-output-1.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>Again, not very Gaussian. Let’s train using SIGReg:</p>
<div id="43d5385c" class="cell" data-execution_count="13">
<details class="code-fold">
<summary>Code</summary>
<div class="code-copy-outer-scaffold"><div class="sourceCode cell-code" id="cb14" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1">epochs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000</span></span>
<span id="cb14-2"></span>
<span id="cb14-3">optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.optim.Adam(model.parameters(), lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.01</span>)</span>
<span id="cb14-4">loss_fn <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">lambda</span> x: SIGReg(x)</span>
<span id="cb14-5"></span>
<span id="cb14-6"><span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">for</span> epoch <span class="kw" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(epochs):</span>
<span id="cb14-7">    optimizer.zero_grad()</span>
<span id="cb14-8">    loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> loss_fn(model(X_train))</span>
<span id="cb14-9">    loss.backward()</span>
<span id="cb14-10">    optimizer.step()</span>
<span id="cb14-11">    <span class="cf" style="color: #003B4F;
background-color: null;
font-weight: bold;
font-style: inherit;">if</span> (epoch<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span> (epochs<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>:</span>
<span id="cb14-12">      <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Epoch </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epoch<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">/</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epochs<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, Loss: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>item()<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb14-13"></span>
<span id="cb14-14">Y<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>model(X_test)</span>
<span id="cb14-15">plot_output_distribution(model, X_test)</span></code></pre></div></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>Epoch 100/1000, Loss: 2.2469592094421387
Epoch 200/1000, Loss: 0.9632902145385742
Epoch 300/1000, Loss: 0.7503044605255127
Epoch 400/1000, Loss: 0.548032283782959
Epoch 500/1000, Loss: 0.42803165316581726
Epoch 600/1000, Loss: 0.4888719618320465
Epoch 700/1000, Loss: 0.4376848340034485
Epoch 800/1000, Loss: 0.37319087982177734
Epoch 900/1000, Loss: 0.3438933491706848
Epoch 1000/1000, Loss: 0.2580392360687256</code></pre>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><img src="https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/index_files/figure-html/cell-12-output-2.png" class="img-fluid figure-img"></p>
</figure>
</div>
</div>
</div>
<p>And now, after training, the output is noticeably more Gaussian. Adding this regularizer to an objective function therefore encourages Gaussian latent embeddings.</p>
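<p>To make concrete how the regularizer slots into a larger objective, here is a minimal, self-contained sketch: a condensed version of the SIGReg function above combined with a simple MSE prediction term. The MSE term, the embeddings, and the weighting factor <code>lam</code> are illustrative placeholders, not the exact setup from the paper.</p>

```python
import torch

def sigreg(x, num_slices=256, k=17):
    """Condensed version of the SIGReg function above."""
    N, D = x.shape
    # Random unit directions to slice the D-dimensional embeddings.
    A = torch.randn(D, num_slices, device=x.device)
    A /= A.norm(dim=0)
    proj = x @ A                                     # (N, num_slices)
    # Integration grid and N(0,1) characteristic function.
    t = torch.linspace(-5.0, 5.0, k, device=x.device)
    phi = torch.exp(-0.5 * t ** 2)
    # Empirical characteristic function along each slice: (num_slices, k).
    ecf = torch.exp(-1j * proj.unsqueeze(-1) * t).mean(dim=0)
    diff_sq = (ecf - phi).abs() ** 2
    # Weighted integral per slice, then mean over slices.
    return (torch.trapz(diff_sq * phi, t, dim=1) * N).mean()

# Illustrative combined objective: prediction loss + lam * SIGReg.
# `pred` and `target` stand in for predictor and target embeddings;
# `lam` is a placeholder weighting hyperparameter.
pred = torch.randn(128, 10, requires_grad=True)
target = torch.randn(128, 10)
lam = 0.1
loss = ((pred - target) ** 2).mean() + lam * sigreg(pred)
loss.backward()  # gradients flow through the regularizer
```

<p>Since the statistic is differentiable, the regularizer can simply be added to whatever prediction loss the model already uses and trained end-to-end.</p>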
<p>In a future blog post, I’d like to train a small JEPA model using SIGReg as in “LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels”<span class="citation" data-cites="maes2026">(Maes et al. 2026)</span>.</p>



</section>
</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-bibliography"><h2 class="anchored quarto-appendix-heading">References</h2><div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-balestriero2025" class="csl-entry">
Balestriero, Randall, and Yann LeCun. 2025. <em>LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics</em>. <a href="https://arxiv.org/abs/2511.08544">https://arxiv.org/abs/2511.08544</a>.
</div>
<div id="ref-maes2026" class="csl-entry">
Maes, Lucas, Quentin Le Lidec, Damien Scieur, Yann LeCun, and Randall Balestriero. 2026. <em>LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels</em>. <a href="https://arxiv.org/abs/2603.19312">https://arxiv.org/abs/2603.19312</a>.
</div>
</div></section></div> ]]></description>
  <category>ai</category>
  <category>code</category>
  <category>jepa</category>
  <guid>https://shonczinner.github.io/posts/sigreg-sketched-isotropic-gaussian-regularization/</guid>
  <pubDate>Sun, 26 Apr 2026 00:00:00 GMT</pubDate>
</item>
<item>
  <title>Motivation To Write</title>
  <dc:creator>Shon Czinner</dc:creator>
  <link>https://shonczinner.github.io/posts/welcome/</link>
  <description><![CDATA[ 




<p>I plan to write about AI, statistics, quantitative finance, economics, and more.</p>
<p>For my first post I’ve decided to write about why I’m writing.</p>
<ol type="1">
<li>Writing things down provides a reference for when I forget things.</li>
<li>Writing improves understanding.</li>
</ol>
<p>If other people find the things I write about interesting, that’s a bonus.</p>
<p>Feel free to get in touch if you find any mistakes!</p>



 ]]></description>
  <category>meta</category>
  <guid>https://shonczinner.github.io/posts/welcome/</guid>
  <pubDate>Sat, 25 Apr 2026 00:00:00 GMT</pubDate>
</item>
</channel>
</rss>
