Skip to content

Commit 5ae7c2c

Browse files
committed
1 parent 782ee19 commit 5ae7c2c

9 files changed

Lines changed: 402 additions & 185 deletions

File tree

.doctrees/environment.pickle

11.7 KB
Binary file not shown.

.doctrees/zoo/einet.doctree

314 Bytes
Binary file not shown.

_modules/spflow/learn/expectation_maximization.html

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,21 @@ <h1>Source code for spflow.learn.expectation_maximization</h1><div class="highli
286286
<span class="n">logger</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
287287

288288

289+
<span class="k">def</span><span class="w"> </span><span class="nf">_retain_cached_log_likelihood_grads</span><span class="p">(</span><span class="n">cache</span><span class="p">:</span> <span class="n">Cache</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
290+
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Retain gradients for cached non-leaf likelihood tensors consumed by EM.&quot;&quot;&quot;</span>
291+
<span class="k">for</span> <span class="n">lls</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">[</span><span class="s2">&quot;log_likelihood&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
292+
<span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">is_tensor</span><span class="p">(</span><span class="n">lls</span><span class="p">)</span> <span class="ow">and</span> <span class="n">lls</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">:</span>
293+
<span class="n">lls</span><span class="o">.</span><span class="n">retain_grad</span><span class="p">()</span>
294+
295+
296+
<span class="k">def</span><span class="w"> </span><span class="nf">_backward_accumulated_log_likelihood</span><span class="p">(</span><span class="n">acc_ll</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
297+
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Backpropagate one EM step without retaining the graph.&quot;&quot;&quot;</span>
298+
<span class="k">if</span> <span class="ow">not</span> <span class="n">acc_ll</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">:</span>
299+
<span class="k">return</span>
300+
301+
<span class="n">acc_ll</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
302+
303+
289304
<div class="viewcode-block" id="expectation_maximization">
290305
<a class="viewcode-back" href="../../../api/learning.html#spflow.learn.expectation_maximization.expectation_maximization">[docs]</a>
291306
<span class="k">def</span><span class="w"> </span><span class="nf">expectation_maximization</span><span class="p">(</span>
@@ -327,14 +342,8 @@ <h1>Source code for spflow.learn.expectation_maximization</h1><div class="highli
327342
<span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
328343
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Step </span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">: Average log-likelihood: </span><span class="si">{</span><span class="n">avg_ll</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
329344

330-
<span class="c1"># retain gradients for all module log-likelihoods</span>
331-
<span class="k">for</span> <span class="n">lls</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">[</span><span class="s2">&quot;log_likelihood&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
332-
<span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">is_tensor</span><span class="p">(</span><span class="n">lls</span><span class="p">)</span> <span class="ow">and</span> <span class="n">lls</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">:</span>
333-
<span class="n">lls</span><span class="o">.</span><span class="n">retain_grad</span><span class="p">()</span>
334-
335-
<span class="c1"># compute gradients (if there are differentiable parameters to begin with)</span>
336-
<span class="k">if</span> <span class="n">acc_ll</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">:</span>
337-
<span class="n">acc_ll</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
345+
<span class="n">_retain_cached_log_likelihood_grads</span><span class="p">(</span><span class="n">cache</span><span class="p">)</span>
346+
<span class="n">_backward_accumulated_log_likelihood</span><span class="p">(</span><span class="n">acc_ll</span><span class="p">)</span>
338347

339348
<span class="c1"># recursively perform expectation maximization</span>
340349
<span class="n">module</span><span class="o">.</span><span class="n">_expectation_maximization_step</span><span class="p">(</span>
@@ -393,12 +402,8 @@ <h1>Source code for spflow.learn.expectation_maximization</h1><div class="highli
393402
<span class="n">epoch_ll</span> <span class="o">=</span> <span class="n">epoch_ll</span> <span class="o">+</span> <span class="n">acc_ll</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
394403
<span class="n">num_samples</span> <span class="o">+=</span> <span class="n">batch_data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
395404

396-
<span class="k">for</span> <span class="n">lls</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">[</span><span class="s2">&quot;log_likelihood&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
397-
<span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">is_tensor</span><span class="p">(</span><span class="n">lls</span><span class="p">)</span> <span class="ow">and</span> <span class="n">lls</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">:</span>
398-
<span class="n">lls</span><span class="o">.</span><span class="n">retain_grad</span><span class="p">()</span>
399-
400-
<span class="k">if</span> <span class="n">acc_ll</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">:</span>
401-
<span class="n">acc_ll</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
405+
<span class="n">_retain_cached_log_likelihood_grads</span><span class="p">(</span><span class="n">cache</span><span class="p">)</span>
406+
<span class="n">_backward_accumulated_log_likelihood</span><span class="p">(</span><span class="n">acc_ll</span><span class="p">)</span>
402407

403408
<span class="n">module</span><span class="o">.</span><span class="n">_expectation_maximization_step</span><span class="p">(</span>
404409
<span class="n">data</span><span class="o">=</span><span class="n">batch_data</span><span class="p">,</span>

0 commit comments

Comments
 (0)