@@ -286,6 +286,21 @@ <h1>Source code for spflow.learn.expectation_maximization</h1><div class="highli
286286< span class ="n "> logger</ span > < span class ="o "> =</ span > < span class ="n "> logging</ span > < span class ="o "> .</ span > < span class ="n "> getLogger</ span > < span class ="p "> (</ span > < span class ="vm "> __name__</ span > < span class ="p "> )</ span >
287287
288288
289+ < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="nf "> _retain_cached_log_likelihood_grads</ span > < span class ="p "> (</ span > < span class ="n "> cache</ span > < span class ="p "> :</ span > < span class ="n "> Cache</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
290+ < span class ="w "> </ span > < span class ="sd "> """Retain gradients for cached non-leaf likelihood tensors consumed by EM."""</ span >
291+ < span class ="k "> for</ span > < span class ="n "> lls</ span > < span class ="ow "> in</ span > < span class ="n "> cache</ span > < span class ="p "> [</ span > < span class ="s2 "> "log_likelihood"</ span > < span class ="p "> ]</ span > < span class ="o "> .</ span > < span class ="n "> values</ span > < span class ="p "> ():</ span >
292+ < span class ="k "> if</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> is_tensor</ span > < span class ="p "> (</ span > < span class ="n "> lls</ span > < span class ="p "> )</ span > < span class ="ow "> and</ span > < span class ="n "> lls</ span > < span class ="o "> .</ span > < span class ="n "> requires_grad</ span > < span class ="p "> :</ span >
293+ < span class ="n "> lls</ span > < span class ="o "> .</ span > < span class ="n "> retain_grad</ span > < span class ="p "> ()</ span >
294+
295+
296+ < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="nf "> _backward_accumulated_log_likelihood</ span > < span class ="p "> (</ span > < span class ="n "> acc_ll</ span > < span class ="p "> :</ span > < span class ="n "> Tensor</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
297+ < span class ="w "> </ span > < span class ="sd "> """Backpropagate one EM step without retaining the graph."""</ span >
298+ < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> acc_ll</ span > < span class ="o "> .</ span > < span class ="n "> requires_grad</ span > < span class ="p "> :</ span >
299+ < span class ="k "> return</ span >
300+
301+ < span class ="n "> acc_ll</ span > < span class ="o "> .</ span > < span class ="n "> backward</ span > < span class ="p "> ()</ span >
302+
303+
289304< div class ="viewcode-block " id ="expectation_maximization ">
290305< a class ="viewcode-back " href ="../../../api/learning.html#spflow.learn.expectation_maximization.expectation_maximization "> [docs]</ a >
291306< span class ="k "> def</ span > < span class ="w "> </ span > < span class ="nf "> expectation_maximization</ span > < span class ="p "> (</ span >
@@ -327,14 +342,8 @@ <h1>Source code for spflow.learn.expectation_maximization</h1><div class="highli
327342 < span class ="k "> if</ span > < span class ="n "> verbose</ span > < span class ="p "> :</ span >
328343 < span class ="n "> logger</ span > < span class ="o "> .</ span > < span class ="n "> info</ span > < span class ="p "> (</ span > < span class ="sa "> f</ span > < span class ="s2 "> "Step </ span > < span class ="si "> {</ span > < span class ="n "> step</ span > < span class ="si "> }</ span > < span class ="s2 "> : Average log-likelihood: </ span > < span class ="si "> {</ span > < span class ="n "> avg_ll</ span > < span class ="si "> }</ span > < span class ="s2 "> "</ span > < span class ="p "> )</ span >
329344
330- < span class ="c1 "> # retain gradients for all module log-likelihoods</ span >
331- < span class ="k "> for</ span > < span class ="n "> lls</ span > < span class ="ow "> in</ span > < span class ="n "> cache</ span > < span class ="p "> [</ span > < span class ="s2 "> "log_likelihood"</ span > < span class ="p "> ]</ span > < span class ="o "> .</ span > < span class ="n "> values</ span > < span class ="p "> ():</ span >
332- < span class ="k "> if</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> is_tensor</ span > < span class ="p "> (</ span > < span class ="n "> lls</ span > < span class ="p "> )</ span > < span class ="ow "> and</ span > < span class ="n "> lls</ span > < span class ="o "> .</ span > < span class ="n "> requires_grad</ span > < span class ="p "> :</ span >
333- < span class ="n "> lls</ span > < span class ="o "> .</ span > < span class ="n "> retain_grad</ span > < span class ="p "> ()</ span >
334-
335- < span class ="c1 "> # compute gradients (if there are differentiable parameters to begin with)</ span >
336- < span class ="k "> if</ span > < span class ="n "> acc_ll</ span > < span class ="o "> .</ span > < span class ="n "> requires_grad</ span > < span class ="p "> :</ span >
337- < span class ="n "> acc_ll</ span > < span class ="o "> .</ span > < span class ="n "> backward</ span > < span class ="p "> (</ span > < span class ="n "> retain_graph</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
345+ < span class ="n "> _retain_cached_log_likelihood_grads</ span > < span class ="p "> (</ span > < span class ="n "> cache</ span > < span class ="p "> )</ span >
346+ < span class ="n "> _backward_accumulated_log_likelihood</ span > < span class ="p "> (</ span > < span class ="n "> acc_ll</ span > < span class ="p "> )</ span >
338347
339348 < span class ="c1 "> # recursively perform expectation maximization</ span >
340349 < span class ="n "> module</ span > < span class ="o "> .</ span > < span class ="n "> _expectation_maximization_step</ span > < span class ="p "> (</ span >
@@ -393,12 +402,8 @@ <h1>Source code for spflow.learn.expectation_maximization</h1><div class="highli
393402 < span class ="n "> epoch_ll</ span > < span class ="o "> =</ span > < span class ="n "> epoch_ll</ span > < span class ="o "> +</ span > < span class ="n "> acc_ll</ span > < span class ="o "> .</ span > < span class ="n "> detach</ span > < span class ="p "> ()</ span >
394403 < span class ="n "> num_samples</ span > < span class ="o "> +=</ span > < span class ="n "> batch_data</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span >
395404
396- < span class ="k "> for</ span > < span class ="n "> lls</ span > < span class ="ow "> in</ span > < span class ="n "> cache</ span > < span class ="p "> [</ span > < span class ="s2 "> "log_likelihood"</ span > < span class ="p "> ]</ span > < span class ="o "> .</ span > < span class ="n "> values</ span > < span class ="p "> ():</ span >
397- < span class ="k "> if</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> is_tensor</ span > < span class ="p "> (</ span > < span class ="n "> lls</ span > < span class ="p "> )</ span > < span class ="ow "> and</ span > < span class ="n "> lls</ span > < span class ="o "> .</ span > < span class ="n "> requires_grad</ span > < span class ="p "> :</ span >
398- < span class ="n "> lls</ span > < span class ="o "> .</ span > < span class ="n "> retain_grad</ span > < span class ="p "> ()</ span >
399-
400- < span class ="k "> if</ span > < span class ="n "> acc_ll</ span > < span class ="o "> .</ span > < span class ="n "> requires_grad</ span > < span class ="p "> :</ span >
401- < span class ="n "> acc_ll</ span > < span class ="o "> .</ span > < span class ="n "> backward</ span > < span class ="p "> (</ span > < span class ="n "> retain_graph</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
405+ < span class ="n "> _retain_cached_log_likelihood_grads</ span > < span class ="p "> (</ span > < span class ="n "> cache</ span > < span class ="p "> )</ span >
406+ < span class ="n "> _backward_accumulated_log_likelihood</ span > < span class ="p "> (</ span > < span class ="n "> acc_ll</ span > < span class ="p "> )</ span >
402407
403408 < span class ="n "> module</ span > < span class ="o "> .</ span > < span class ="n "> _expectation_maximization_step</ span > < span class ="p "> (</ span >
404409 < span class ="n "> data</ span > < span class ="o "> =</ span > < span class ="n "> batch_data</ span > < span class ="p "> ,</ span >
0 commit comments