-
-
Notifications
You must be signed in to change notification settings - Fork 51
Expand file tree
/
Copy pathTypechecker.ml
More file actions
2262 lines (2110 loc) · 91.3 KB
/
Typechecker.ml
File metadata and controls
2262 lines (2110 loc) · 91.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
(** a type/semantic checker for Stan ASTs
Functions which begin with "check_" return a typed version of their input
Functions which begin with "verify_" return unit if a check succeeds, or
else throw a TypecheckerException exception. Other functions which begin
with "infer"/"calculate" vary. Usually they return a value, but a few do
have error conditions.
All TypecheckerException exceptions are caught by check_program which turns
the ast or exception into a Result.t for external usage
A type environment (Env.t) is used to hold variables and functions,
including stan math functions. This is a functional map, meaning it is
handled immutably. *)
open Core
open Core.Poly
open Middle
open Ast
module Env = Environment
(** Private exception carrying the semantic error that aborted typechecking. *)
exception TypecheckerException of Semantic_error.t

(** [error err] aborts the current check by raising {!TypecheckerException}
    with [err]. It never returns normally, so it can be used at any result
    type. Errors in this file are meant to be raised only through this
    function. *)
let error semantic_err = raise (TypecheckerException semantic_err)
(* Warnings are accumulated in a global list, most recent first. *)
let warnings : Warnings.t list ref = ref []

(** [add_warning span message] records [message] as a warning attached to the
    source location [span]. *)
let add_warning (span : Location_span.t) (message : string) =
  let recorded_so_far = !warnings in
  warnings := (span, message) :: recorded_so_far

(** [attach_warnings x] pairs [x] with every warning recorded so far, in the
    order in which they were added. *)
let attach_warnings x =
  let in_recording_order = List.rev !warnings in
  (x, in_recording_order)
(* Functions registered through [needs_higher_order_autodiff], most recent
   first. Read by [verify_second_order_derivative_compatibility]. *)
let requires_higher_order_autodiff = ref []

(** [needs_higher_order_autodiff fn] registers [fn] in
    [requires_higher_order_autodiff]. *)
let needs_higher_order_autodiff fn =
  let registered = !requires_higher_order_autodiff in
  requires_higher_order_autodiff := fn :: registered
(* Name of the model being checked, kept as global mutable state (not ideal).
   [verify_identifier] rejects identifiers equal to this name. *)
let model_name : string ref = ref ""
(** Describes which function definition, if any, encloses the code currently
being checked. The [unit Fun_kind.suffix] payload is the suffix kind of the
enclosing function (the helpers below test it against [FnRng], [FnTarget],
[FnJacobian], [FnLpdf ()] and [FnLpmf ()]). *)
type function_indicator =
(* not inside any function definition; this is what [context] starts with *)
| NotInFunction
(* inside a function with no return type (per the constructor name) *)
| NonReturning of unit Fun_kind.suffix
(* inside a function whose declared return type is the carried
[UnsizedType.t] (per the constructor name) *)
| Returning of unit Fun_kind.suffix * UnsizedType.t
(* Record structure holding flags and other markers about context to be used for
error reporting. A fresh value is built by [context]. *)
type context_flags_record =
(* the program block currently being checked; compared below against e.g.
[Model], [GQuant], [TParam] and [TData] *)
{ current_block: Env.originblock
(* when set, [check_id] rejects variables originating from [Param], [TParam]
or [GQuant] and [verify_fn_rng] rejects [_rng] calls *)
; in_toplevel_decl: bool
(* the enclosing function definition, if any *)
; containing_function: function_indicator
(* starts at 0 in [context]. NOTE(review): presumably the nesting depth of
enclosing loops; it is not read in this part of the file -- confirm *)
; loop_depth: int }
(** [in_rng_function cf] is [true] exactly when the enclosing function recorded
    in [cf.containing_function] has the [FnRng] suffix, whether or not it
    returns a value. *)
let in_rng_function cf =
  match cf.containing_function with
  | Returning (FnRng, _) | NonReturning FnRng -> true
  | NotInFunction | Returning _ | NonReturning _ -> false
(** [in_lp_function cf] is [true] exactly when the enclosing function recorded
    in [cf.containing_function] has the [FnTarget] suffix, whether or not it
    returns a value. *)
let in_lp_function cf =
  match cf.containing_function with
  | Returning (FnTarget, _) | NonReturning FnTarget -> true
  | NotInFunction | Returning _ | NonReturning _ -> false
(** [in_jacobian_function cf] is [true] exactly when the enclosing function
    recorded in [cf.containing_function] has the [FnJacobian] suffix, whether
    or not it returns a value. *)
let in_jacobian_function cf =
  match cf.containing_function with
  | Returning (FnJacobian, _) | NonReturning FnJacobian -> true
  | NotInFunction | Returning _ | NonReturning _ -> false
(** [in_udf_distribution cf] is [true] exactly when the enclosing function
    recorded in [cf.containing_function] has the [FnLpdf ()] or [FnLpmf ()]
    suffix, whether or not it returns a value. *)
let in_udf_distribution cf =
  match cf.containing_function with
  | NonReturning (FnLpdf () | FnLpmf ())
   |Returning ((FnLpdf () | FnLpmf ()), _) ->
      true
  | NotInFunction | NonReturning _ | Returning _ -> false
(** [context block] is the initial context for checking [block]: not inside a
    top-level declaration, not inside any function, and at loop depth 0. *)
let context current_block =
  { current_block
  ; loop_depth= 0
  ; containing_function= NotInFunction
  ; in_toplevel_decl= false }
(** [calculate_autodifftype cf origin ut] computes the autodiff level of a
    value of type [ut] that originates from block [origin].

    Array wrappers are stripped first. A tuple yields a [TupleAD] built
    component-wise. Otherwise the result is [AutoDiffable] when [origin] is one
    of [Param], [TParam], [Model] or [Functions], the element type is not
    discrete, and the block being checked is not [GQuant]; in every other case
    it is [DataOnly]. *)
let rec calculate_autodifftype cf origin ut =
  let element_type, _ = UnsizedType.unwind_array_type ut in
  match element_type with
  | UnsizedType.UTuple components ->
      UnsizedType.TupleAD
        (List.map components ~f:(fun component ->
             calculate_autodifftype cf origin component))
  | _ ->
      let origin_can_autodiff =
        match origin with
        | Env.(Param | TParam | Model | Functions) -> true
        | _ -> false in
      if
        origin_can_autodiff
        && (not (UnsizedType.is_discrete_type element_type))
        && cf.current_block <> GQuant
      then UnsizedType.AutoDiffable
      else UnsizedType.DataOnly
(** [arg_type x] is the (autodiff level, type) pair recorded on the typed
    expression [x]. *)
let arg_type x =
  let meta = x.emeta in
  (meta.ad_level, meta.type_)

(** [get_arg_types es] maps {!arg_type} over [es]. *)
let get_arg_types es = List.map es ~f:arg_type

(** [type_of_expr_typed ue] is the type recorded on [ue]. *)
let type_of_expr_typed ue =
  let meta = ue.emeta in
  meta.type_

(** [has_int_type ue] holds when [ue] has type [UInt]. *)
let has_int_type ue =
  match ue.emeta.type_ with UnsizedType.UInt -> true | _ -> false

(** [has_int_array_type ue] holds when [ue] is a one-dimensional array of
    [UInt]. *)
let has_int_array_type ue =
  match ue.emeta.type_ with
  | UnsizedType.UArray UnsizedType.UInt -> true
  | _ -> false
(** [name_of_lval lv] is the name of the variable at the root of the lvalue
    [lv], found by peeling off any tuple projections and indexing. *)
let rec name_of_lval lv =
  match lv.lval with
  | LTupleProjection (inner, _) | LIndexed (inner, _) -> name_of_lval inner
  | LVariable {name; _} -> name
(** [has_int_or_real_type ue] holds when [ue] has type [UInt] or [UReal]. *)
let has_int_or_real_type ue =
  let ty = ue.emeta.type_ in
  ty = UnsizedType.UInt || ty = UnsizedType.UReal
(* -- General checks ---------------------------------------------- *)

(* Identifiers rejected by [verify_identifier]. The parser already stops most
   keywords currently in use; these are extra names reserved for the future. *)
let reserved_keywords =
  [ "generated"
  ; "quantities"
  ; "transformed"
  ; "repeat"
  ; "until"
  ; "then"
  ; "true"
  ; "false"
  ; "typedef"
  ; "struct"
  ; "var"
  ; "export"
  ; "extern"
  ; "static"
  ; "auto" ]
(** [verify_identifier id] checks that [id] may be used as an identifier.

    Raises (via {!error}) [ident_is_model_name] when [id] equals the current
    model name, and otherwise [ident_is_keyword] when [id] ends in ["__"] or is
    one of [reserved_keywords]. *)
let verify_identifier id : unit =
  let clashes_with_model = String.equal id.name !model_name in
  let is_reserved =
    String.is_suffix id.name ~suffix:"__"
    || List.exists reserved_keywords ~f:(String.equal id.name) in
  if clashes_with_model then
    error (Semantic_error.ident_is_model_name id.id_loc id.name)
  else if is_reserved then
    error (Semantic_error.ident_is_keyword id.id_loc id.name)
(** [verify_name_fresh_var loc tenv name] verifies that the variable being
    declared is previously unused.

    Raises (via {!error}) [ident_has_unnormalized_suffix] when [name] is an
    unnormalized distribution name, or [ident_in_use] when [tenv] already binds
    [name] to a variable. Entries of any other kind are ignored, so a user
    variable may shadow a function name. *)
let verify_name_fresh_var loc tenv name =
  if Utils.is_unnormalized_distribution name then
    error (Semantic_error.ident_has_unnormalized_suffix loc name)
  else
    let earlier_declaration =
      List.find_map (Env.find tenv name) ~f:(function
        | {kind= `Variable {location; _}; _} -> Some location
        | _ -> None) in
    match earlier_declaration with
    | Some prev -> error (Semantic_error.ident_in_use loc name prev)
    | None -> ()
(** [verify_name_fresh_udf loc tenv name] verifies that [name] may be used for
    a user-defined function.

    Raises (via {!error}):
    - [ident_is_stanmath_name] when [name] is a special Stan Math function name
      (variadic functions aren't overloadable due to their separate
      typechecking);
    - [udf_is_unnormalized_fn] when [name] is an unnormalized distribution
      name;
    - [ident_in_use] when [tenv] already binds [name] to a variable. That case
      is not really possible as all functions are defined before data, but
      future-proofing is good. *)
let verify_name_fresh_udf loc tenv name =
  if Stan_math_signatures.is_special_function_name name then
    error (Semantic_error.ident_is_stanmath_name loc name)
  else if Utils.is_unnormalized_distribution name then
    error (Semantic_error.udf_is_unnormalized_fn loc name)
  else
    let variable_locations =
      List.filter_map (Env.find tenv name) ~f:(function
        | {kind= `Variable {location; _}; _} -> Some location
        | _ -> None) in
    Option.iter (List.hd variable_locations) ~f:(fun prev ->
        error (Semantic_error.ident_in_use loc name prev))
(** [verify_name_fresh tenv id ~is_udf] checks that the name [id] may be newly
    declared, using the function rules ({!verify_name_fresh_udf}) when [is_udf]
    is set and the variable rules ({!verify_name_fresh_var}) otherwise. Both
    reject the _lupdf/_lupmf suffix and names already in use by a variable. *)
let verify_name_fresh tenv id ~is_udf =
  let verify =
    if is_udf then verify_name_fresh_udf else verify_name_fresh_var in
  verify id.id_loc tenv id.name
(** [is_of_compatible_return_type rt1 srt2] holds when [rt1] is [Void], or when
    [rt1] is a [ReturnType] and [srt2] is [Complete]. *)
let is_of_compatible_return_type rt1 srt2 =
  UnsizedType.(
    match rt1 with
    | Void -> true
    | ReturnType _ -> ( match srt2 with Complete -> true | _ -> false ))
(* -- Expressions ------------------------------------------------- *)

(** [check_ternary_if loc pe te fe] builds the typed expression for
    [pe ? te : fe].

    The predicate [pe] must have type [UInt], the two branches must have a
    common type that is not a function type, and the three expressions must
    have a common autodiff level. Each branch is wrapped in a [Promotion] when
    its own type or autodiff level differs from the common one. Any other
    combination raises [illtyped_ternary_if] via {!error}. *)
let check_ternary_if loc pe te fe =
  let promote expr type_ ad_level =
    let same_type = UnsizedType.equal expr.emeta.type_ type_ in
    let same_ad_level =
      UnsizedType.compare_autodifftype expr.emeta.ad_level ad_level = 0 in
    if same_type && same_ad_level then expr
    else
      { expr= Promotion (expr, (UnsizedType.internal_scalar type_, ad_level))
      ; emeta= {expr.emeta with type_; ad_level} } in
  let ill_typed () =
    error
      (Semantic_error.illtyped_ternary_if loc pe.emeta.type_ te.emeta.type_
         fe.emeta.type_) in
  let branch_type = UnsizedType.common_type (te.emeta.type_, fe.emeta.type_) in
  let joint_ad_level = expr_ad_lub [pe; te; fe] in
  match pe.emeta.type_ with
  | UInt -> (
    match (branch_type, joint_ad_level) with
    | Some type_, Some ad_level when not (UnsizedType.is_fun_type type_) ->
        let then_branch = promote te type_ ad_level in
        let else_branch = promote fe type_ ad_level in
        mk_typed_expression
          ~expr:(TernaryIf (pe, then_branch, else_branch))
          ~ad_level ~type_ ~loc
    | _ -> ill_typed ())
  | _ -> ill_typed ()
(** [match_to_rt_option outcome] is the return type carried by a [UniqueMatch]
    outcome, and [None] for every other outcome. *)
let match_to_rt_option outcome =
  match outcome with
  | SignatureMismatch.UniqueMatch (return_type, _, _, _) -> Some return_type
  | _ -> None
(** [stan_math_return_type name arg_tys] is the return type of the Stan Math
    function [name] when applied to [arg_tys], if one can be determined.

    Variadic functions report their declared return type, reduce_sum functions
    report [UReal], and every other name is resolved against the library
    signatures, giving [None] unless exactly one signature matches. *)
let stan_math_return_type name arg_tys =
  match Stan_math_signatures.lookup_stan_math_variadic_function name with
  | Some {return_type; _} -> Some (UnsizedType.ReturnType return_type)
  | None ->
      if Stan_math_signatures.is_reduce_sum_fn name then
        Some (UnsizedType.ReturnType UnsizedType.UReal)
      else
        match_to_rt_option
          (SignatureMismatch.matching_stanlib_function name arg_tys)
(** [operator_stan_math_return_type op arg_tys] resolves operator [op] applied
    to [arg_tys] and returns the return type together with the promotions to
    apply to the operands, or [None] when the operator cannot be applied.

    Integer division is handled specially: it is only defined for two [UInt]
    operands, yields [UInt] and needs no promotions. Every other operator is
    resolved against the Stan Math functions implementing it, in the order
    given by [Stan_math_signatures.operator_to_stan_math_fns]; the first
    function with a unique matching signature wins. *)
let operator_stan_math_return_type op arg_tys =
  match (op, arg_tys) with
  | Operator.IntDivide, [(_, UnsizedType.UInt); (_, UInt)] ->
      Some (UnsizedType.(ReturnType UInt), [Promotion.NoPromotion; NoPromotion])
  | IntDivide, _ -> None
  | _ ->
      (* [find_map] stops at the first unique match, so the remaining
         candidates are not resolved and no intermediate list is built. *)
      Stan_math_signatures.operator_to_stan_math_fns op
      |> List.find_map ~f:(fun name ->
             match SignatureMismatch.matching_stanlib_function name arg_tys with
             | SignatureMismatch.UniqueMatch (rt, _, p, _) -> Some (rt, p)
             | _ -> None)
(** [assignmentoperator_stan_math_return_type assop arg_tys] checks whether the
    compound assignment built on operator [assop] is valid for [arg_tys]
    (left-hand side first). It is [Some Void] when valid and [None] otherwise.

    The underlying binary operation must resolve to a return type equal to the
    type of the first argument. For [EltTimes] and [EltDivide] that type must
    additionally not be a scalar type. Only [Divide], [Plus], [Minus], [Times],
    [EltTimes] and [EltDivide] are supported; [Divide] is resolved directly
    against the ["divide"] library function. *)
let assignmentoperator_stan_math_return_type assop arg_tys =
  let binary_result =
    match assop with
    | Operator.Divide ->
        match_to_rt_option
          (SignatureMismatch.matching_stanlib_function "divide" arg_tys)
    | Plus | Minus | Times | EltTimes | EltDivide ->
        Option.map (operator_stan_math_return_type assop arg_tys) ~f:fst
    | _ -> None in
  let is_elementwise =
    assop = Operator.EltTimes || assop = Operator.EltDivide in
  match binary_result with
  | Some (UnsizedType.ReturnType rtype) ->
      let lhs_type = snd (List.hd_exn arg_tys) in
      if
        rtype = lhs_type
        && not (is_elementwise && UnsizedType.is_scalar_type rtype)
      then Some UnsizedType.Void
      else None
  | Some _ | None -> None
(** [check_binop loc op le re] builds the typed expression for [le op re],
    promoting each operand as required by the resolved operator signature.
    Raises [illtyped_binary_op] via {!error} when the operator does not resolve
    to a value-returning signature with exactly two promotions, or when the
    operands have no common autodiff level. *)
let check_binop loc op le re =
  let operands = [le; re] in
  let signature = operator_stan_math_return_type op (get_arg_types operands) in
  let joint_ad_level = expr_ad_lub operands in
  match (signature, joint_ad_level) with
  | Some (ReturnType type_, [lhs_promotion; rhs_promotion]), Some ad_level ->
      let lhs = Promotion.promote le lhs_promotion in
      let rhs = Promotion.promote re rhs_promotion in
      mk_typed_expression ~expr:(BinOp (lhs, op, rhs)) ~ad_level ~type_ ~loc
  | _ ->
      error
        (Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_)
(** [check_prefixop loc op te] builds the typed expression for the prefix
    operator [op] applied to [te]. The result keeps the autodiff level of [te].
    Raises [illtyped_prefix_op] via {!error} when the operator does not resolve
    to a value-returning signature. *)
let check_prefixop loc op te =
  match operator_stan_math_return_type op [arg_type te] with
  | Some (ReturnType type_, _) ->
      let ad_level = te.emeta.ad_level in
      mk_typed_expression ~expr:(PrefixOp (op, te)) ~ad_level ~type_ ~loc
  | _ -> error (Semantic_error.illtyped_prefix_op loc op te.emeta.type_)
(** [check_postfixop loc op te] builds the typed expression for the postfix
    operator [op] applied to [te]. The result keeps the autodiff level of [te].
    Raises [illtyped_postfix_op] via {!error} when the operator does not
    resolve to a value-returning signature. *)
let check_postfixop loc op te =
  match operator_stan_math_return_type op [arg_type te] with
  | Some (ReturnType type_, _) ->
      let ad_level = te.emeta.ad_level in
      mk_typed_expression ~expr:(PostfixOp (te, op)) ~ad_level ~type_ ~loc
  | _ -> error (Semantic_error.illtyped_postfix_op loc op te.emeta.type_)
(** [check_id cf loc tenv id] resolves the identifier [id] in [tenv] and
returns its [(autodiff level, type)] pair.

The lookup uses [Utils.stdlib_distribution_name id.name] and only the first
entry found is considered. The cases below are order-sensitive: the two guarded
error cases must be tried before the general variable/function cases.

Raises (via [error]):
- [ident_not_in_scope] when nothing is bound under that name;
- [non_data_variable_size_decl] when [cf.in_toplevel_decl] is set and the
identifier is a variable originating from [Param], [TParam] or [GQuant];
- [invalid_unnormalized_fn] when [id] is an unnormalized distribution name used
outside a user-defined distribution, an lp function and the model block. *)
let check_id cf loc tenv id =
match Env.find tenv (Utils.stdlib_distribution_name id.name) with
| [] ->
(* unknown name: report it together with the closest known identifier *)
Semantic_error.ident_not_in_scope loc id.name
(Env.nearest_ident tenv id.name)
|> error
| {kind= `StanMath; _} :: _ ->
( calculate_autodifftype cf MathLibrary UMathLibraryFunction
, UnsizedType.UMathLibraryFunction )
| { kind=
`Variable
{origin= (Param | TParam | GQuant) as origin; location= prev; _}
; _ }
:: _
when cf.in_toplevel_decl ->
Semantic_error.non_data_variable_size_decl loc origin prev |> error
| _ :: _
when Utils.is_unnormalized_distribution id.name
&& not
((in_udf_distribution cf || in_lp_function cf)
|| cf.current_block = Model) ->
Semantic_error.invalid_unnormalized_fn loc |> error
| {kind= `Variable {origin; _}; type_; _} :: _ ->
(calculate_autodifftype cf origin type_, type_)
| { kind= `UserDefined _ | `UserDeclared _
; type_= UFun (args, rt, (FnLpdf _ | FnLpmf _), mem_pattern)
; _ }
:: _ ->
(* for user-defined distributions the suffix stored in the environment is
replaced by the one derived from the name actually written at the use site *)
let type_ =
UnsizedType.UFun
(args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in
(calculate_autodifftype cf Functions type_, type_)
| {kind= `UserDefined _ | `UserDeclared _; type_; _} :: _ ->
(calculate_autodifftype cf Functions type_, type_)
(** [check_variable cf loc tenv id] is the typed [Variable] expression for
    [id], with type and autodiff level resolved by {!check_id} (which raises
    when the identifier cannot be used here). *)
let check_variable cf loc tenv id =
  match check_id cf loc tenv id with
  | ad_level, type_ ->
      mk_typed_expression ~loc ~type_ ~ad_level ~expr:(Variable id)
(** [get_consistent_types type_ es] folds [UnsizedType.common_type] over the
    types of [es], starting from [type_].

    On success it is [Ok (ad, ty, promotions)]: the common autodiff level, the
    common type, and one promotion per element of [es] towards [(ad, ty)]. At
    the first element with no common type it is [Error (ty, meta)], where [ty]
    is the type accumulated so far and [meta] the metadata of the offending
    element. *)
let get_consistent_types type_ es =
  let widen acc e =
    match acc with
    | Error _ as failed -> failed
    | Ok so_far -> (
      match UnsizedType.common_type (so_far, e.emeta.type_) with
      | Some wider -> Ok wider
      | None -> Error (so_far, e.emeta)) in
  match List.fold es ~init:(Ok type_) ~f:widen with
  | Error mismatch -> Error mismatch
  | Ok ty ->
      (* correctness: Result.Ok case only contains tuples of same lengths,
         expr_ad_lub cannot fail *)
      let ad = Option.value_exn (expr_ad_lub es) in
      let promotions =
        List.map (get_arg_types es)
          ~f:(Promotion.get_type_promotion_exn (ad, ty)) in
      Ok (ad, ty, promotions)
(** [check_texpression_is_tuple te name] requires [te] to be a tuple and
    returns its components as [(autodiff level, type)] pairs. Otherwise it
    raises [tuple_expected] via {!error}, passing [name] through to the error
    constructor. *)
let check_texpression_is_tuple te name =
  let ty = te.emeta.type_ in
  match (ty, te.emeta.ad_level) with
  | UnsizedType.UTuple component_types, UnsizedType.TupleAD component_ads ->
      List.zip_exn component_ads component_types
  | _ -> error (Semantic_error.tuple_expected te.emeta.loc name ty)
(** [check_array_expr loc es] builds the typed array literal whose elements are
    [es]. The elements are promoted to their common type and the result is an
    array of that type. Raises (via {!error}) [empty_array] for an empty list
    and [mismatched_array_types] when the elements have no common type. *)
let check_array_expr loc es =
  match List.hd es with
  | None ->
      (* NB: This is actually disallowed by parser *)
      error (Semantic_error.empty_array loc)
  | Some first -> (
    match get_consistent_types first.emeta.type_ es with
    | Ok (ad_level, element_type, promotions) ->
        let promoted = Promotion.promote_list es promotions in
        mk_typed_expression ~expr:(ArrayExpr promoted) ~ad_level
          ~type_:(UnsizedType.UArray element_type) ~loc
    | Error (ty, meta) ->
        error (Semantic_error.mismatched_array_types meta.loc ty meta.type_))
(** [check_rowvector loc es] builds the typed expression for the bracketed
    literal whose elements are [es]. The type of the first element selects the
    interpretation:
    - a row vector: the literal is a matrix ([UComplexMatrix] when the common
      element type is [UComplexRowVector], [UMatrix] otherwise);
    - a complex row vector: the literal is a [UComplexMatrix];
    - anything else (including no elements): the elements are unified starting
      from [UReal] and the literal is a [UComplexRowVector] when the common
      type is [UComplex], a [URowVector] otherwise.

    Raises (via {!error}) [invalid_matrix_types] in the first two cases and
    [invalid_row_vector_types] in the last when the elements have no common
    type. *)
let check_rowvector loc es =
  let finish ~ad_level ~type_ promotions =
    mk_typed_expression
      ~expr:(RowVectorExpr (Promotion.promote_list es promotions))
      ~ad_level ~type_ ~loc in
  match es with
  | {emeta= {type_= UnsizedType.URowVector; _}; _} :: _ -> (
    match get_consistent_types UnsizedType.URowVector es with
    | Error (_, meta) ->
        error (Semantic_error.invalid_matrix_types meta.loc meta.type_)
    | Ok (ad_level, UComplexRowVector, promotions) ->
        finish ~ad_level ~type_:UnsizedType.UComplexMatrix promotions
    | Ok (ad_level, _, promotions) ->
        finish ~ad_level ~type_:UnsizedType.UMatrix promotions)
  | {emeta= {type_= UnsizedType.UComplexRowVector; _}; _} :: _ -> (
    match get_consistent_types UnsizedType.UComplexRowVector es with
    | Error (_, meta) ->
        error (Semantic_error.invalid_matrix_types meta.loc meta.type_)
    | Ok (ad_level, _, promotions) ->
        finish ~ad_level ~type_:UnsizedType.UComplexMatrix promotions)
  | _ -> (
    match get_consistent_types UnsizedType.UReal es with
    | Error (_, meta) ->
        error (Semantic_error.invalid_row_vector_types meta.loc meta.type_)
    | Ok (ad_level, UComplex, promotions) ->
        finish ~ad_level ~type_:UnsizedType.UComplexRowVector promotions
    | Ok (ad_level, _, promotions) ->
        finish ~ad_level ~type_:UnsizedType.URowVector promotions)
(* index checking *)

(** [indexing_type idx] classifies an index: [`Single] for a single index whose
    expression has type [UInt], [`Multi] for every other index. *)
let indexing_type = function
  | Single {emeta= {type_= UnsizedType.UInt; _}; _} -> `Single
  | _ -> `Multi
(** [is_multiindex i] holds when {!indexing_type} classifies [i] as [`Multi]. *)
let is_multiindex i =
  match indexing_type i with `Multi -> true | `Single -> false
(** [inferred_unsizedtype_of_indexed ~loc ut indices] is the type obtained by
indexing a value of type [ut] with [indices].

Each index is first classified by [indexing_type] as [`Single] or [`Multi].
Array dimensions are consumed from the outside in: a [`Single] index removes
the dimension and a [`Multi] index keeps it. Once a vector, row vector or
matrix is reached, the remaining indices must be consumed all at once, with the
results spelled out in the cases below (using the complex variants when the
type being indexed is complex).

Raises [not_indexable] (via [error]), reporting the original [ut] and the total
number of indices, when there are too many indices for the type or when the
type cannot be indexed at all. *)
let inferred_unsizedtype_of_indexed ~loc ut indices =
let rec aux type_ idcs =
(* result types for the current container, chosen by its complexness *)
let vec, rowvec, scalar =
if UnsizedType.is_complex_type type_ then
UnsizedType.(UComplexVector, UComplexRowVector, UComplex)
else (UVector, URowVector, UReal) in
match (type_, idcs) with
(* no indices left: the type is unchanged *)
| _, [] -> type_
| UnsizedType.UArray type_, `Single :: tl -> aux type_ tl
| UArray type_, `Multi :: tl -> aux type_ tl |> UnsizedType.UArray
(* one single index into a vector, or two into a matrix: a scalar *)
| (UVector | URowVector | UComplexRowVector | UComplexVector), [`Single]
|(UMatrix | UComplexMatrix), [`Single; `Single] ->
scalar
(* only multi-indices: the container type is preserved *)
| ( ( UVector | URowVector | UMatrix | UComplexVector | UComplexMatrix
| UComplexRowVector )
, [`Multi] )
|(UMatrix | UComplexMatrix), [`Multi; `Multi] ->
type_
(* a single row index into a matrix: a row vector *)
| (UMatrix | UComplexMatrix), ([`Single] | [`Single; `Multi]) -> rowvec
(* a single column index into a matrix: a vector *)
| (UMatrix | UComplexMatrix), [`Multi; `Single] -> vec
(* too many indices, or a type that cannot be indexed *)
| (UMatrix | UComplexMatrix), _ :: _ :: _ :: _
|(UVector | URowVector | UComplexRowVector | UComplexVector), _ :: _ :: _
|( (UInt | UReal | UComplex | UFun _ | UMathLibraryFunction | UTuple _)
, _ :: _ ) ->
Semantic_error.not_indexable loc ut (List.length indices) |> error in
aux ut (List.map ~f:indexing_type indices)
(** [inferred_ad_type_of_indexed at ut uindices] is the autodiff level of an
    indexed expression: the least upper bound of [at] and the autodiff levels
    of all index expressions in [uindices], shaped for type [ut] by
    [UnsizedType.fill_adtype_for_type]. An [All] index contributes
    [DataOnly]. *)
let inferred_ad_type_of_indexed at ut uindices =
  let index_ad_level = function
    | All -> UnsizedType.DataOnly
    | Single ue1 | Upfrom ue1 | Downfrom ue1 -> ue1.emeta.ad_level
    | Between (ue1, ue2) ->
        Option.value_exn
          (UnsizedType.lub_ad_type [ue1.emeta.ad_level; ue2.emeta.ad_level])
  in
  (* correctness: index expressions only contain int types, so lub_ad_type
     should never be [None]. *)
  let joint_ad_level =
    Option.value_exn
      (UnsizedType.lub_ad_type (at :: List.map uindices ~f:index_ad_level))
  in
  UnsizedType.fill_adtype_for_type joint_ad_level ut
(* function checking *)

(** [verify_conddist_name loc id] requires the name [id] to end in one of
    [Utils.conditioning_suffices]; otherwise it raises
    [conditional_notation_not_allowed] via {!error}. *)
let verify_conddist_name loc id =
  let has_conditioning_suffix =
    List.exists Utils.conditioning_suffices ~f:(fun suffix ->
        String.is_suffix id.name ~suffix) in
  if not has_conditioning_suffix then
    error (Semantic_error.conditional_notation_not_allowed loc)
(** [verify_fn_conditioning loc id] raises [conditioning_required] via {!error}
    when the name [id] ends in one of [Utils.conditioning_suffices]. *)
let verify_fn_conditioning loc id =
  let has_conditioning_suffix =
    List.exists Utils.conditioning_suffices ~f:(fun suffix ->
        String.is_suffix id.name ~suffix) in
  if has_conditioning_suffix then
    error (Semantic_error.conditioning_required loc)
(** `Target+=` can only be used in model and functions with right suffix (same
    for tilde etc).

    [verify_fn_target_plus_equals cf loc id] only acts on names ending in
    ["_lp"]. In transformed parameters the call is accepted with a deprecation
    warning; in the model block or inside an lp function it is accepted
    silently; anywhere else it raises
    [target_plusequals_outside_model_or_logprob] via {!error}. *)
let verify_fn_target_plus_equals cf loc id =
  if String.is_suffix id.name ~suffix:"_lp" then
    match cf.current_block with
    | TParam ->
        (* resolve https://github.com/stan-dev/stanc3/issues/1482 before
           removal *)
        add_warning loc
          "Using _lp functions in transformed parameters is deprecated and will \
           be disallowed in Stan 2.40. Use an _jacobian function instead, as \
           this allows change of variable adjustments which are conditionally \
           enabled by the algorithms."
    | Model -> ()
    | _ ->
        if not (in_lp_function cf) then
          error (Semantic_error.target_plusequals_outside_model_or_logprob loc)
(** [verify_fn_jacobian_plus_equals cf loc tenv id args] only acts on names
    ending in ["_jacobian"]. Such a call is allowed inside a jacobian function
    or in transformed parameters; anywhere else it raises
    [jacobian_plusequals_not_allowed] via {!error}.

    When the call is allowed but none of [args] has an autodiff level accepted
    by [UnsizedType.is_autodifftype], a warning is recorded. The warning
    suggests the corresponding [_constrain] function if [tenv] contains one. *)
let verify_fn_jacobian_plus_equals cf loc tenv id args =
  let suffix = "_jacobian" in
  if String.is_suffix id.name ~suffix then
    if not (in_jacobian_function cf || cf.current_block = TParam) then
      error (Semantic_error.jacobian_plusequals_not_allowed loc)
    else
      let has_parameter_argument =
        List.exists args ~f:(fun e ->
            UnsizedType.is_autodifftype e.emeta.ad_level) in
      if not has_parameter_argument then (
        let alt = String.chop_suffix_exn id.name ~suffix ^ "_constrain" in
        let hint =
          if Env.mem tenv alt then " Consider using " ^ alt ^ " instead."
          else "" in
        add_warning loc
          ("Calling a _jacobian function without any parameter arguments still \
            applies the Jacobian adjustments, ensure this is intentional!"
         ^ hint))
(** Rng functions cannot be used in Tp or Model and only in function defs with
    the right suffix.

    [verify_fn_rng cf loc id] only acts on names ending in ["_rng"]. It raises
    (via {!error}) [invalid_decl_rng_fn] inside a top-level declaration, and
    otherwise [invalid_rng_fn] unless the call occurs inside an rng function or
    in the [GQuant] or [TData] block. *)
let verify_fn_rng cf loc id =
  let is_rng_call = String.is_suffix id.name ~suffix:"_rng" in
  let block_allows_rng =
    match cf.current_block with GQuant | TData -> true | _ -> false in
  if is_rng_call then
    if cf.in_toplevel_decl then
      error (Semantic_error.invalid_decl_rng_fn loc)
    else if not (in_rng_function cf || block_allows_rng) then
      error (Semantic_error.invalid_rng_fn loc)
(** [verify_unnormalized cf loc id] raises [invalid_unnormalized_fn] via
    {!error} when [id] is an unnormalized distribution name (as recognised by
    [Utils.is_unnormalized_distribution]) used outside a user-defined
    _lpdf/_lpmf function and outside the model block.

    NOTE(review): the previous comment on this function also listed _lp udfs as
    an allowed context, and [check_id] does accept them through
    [in_lp_function], but this check does not. Confirm which is intended. *)
let verify_unnormalized cf loc id =
  let allowed_here = in_udf_distribution cf || cf.current_block = Model in
  if Utils.is_unnormalized_distribution id.name && not allowed_here then
    error (Semantic_error.invalid_unnormalized_fn loc)
(** [mk_fun_app ~is_cond_dist ~loc kind name args ~type_] builds the typed
    application of function [name] (of kind [kind]) to [args] with result type
    [type_]. The node is a [CondDistApp] when [is_cond_dist] is set and a
    [FunApp] otherwise.

    The result is [AutoDiffable] only when [type_] is not discrete and
    [UnsizedType.any_autodiff] holds for the arguments' autodiff levels;
    otherwise it is [DataOnly]. That level is then shaped for [type_] by
    [UnsizedType.fill_adtype_for_type]. *)
let mk_fun_app ~is_cond_dist ~loc kind name args ~type_ : Ast.typed_expression =
  let argument_ad_levels = List.map args ~f:(fun arg -> arg.emeta.ad_level) in
  let scalar_ad_level =
    if
      (not (UnsizedType.is_discrete_type type_))
      && UnsizedType.any_autodiff argument_ad_levels
    then UnsizedType.AutoDiffable
    else UnsizedType.DataOnly in
  let expr =
    if is_cond_dist then CondDistApp (kind, name, args)
    else FunApp (kind, name, args) in
  mk_typed_expression ~expr ~loc ~type_
    ~ad_level:(UnsizedType.fill_adtype_for_type scalar_ad_level type_)
(** [check_normal_fn ~is_cond_dist loc tenv id es] typechecks the application
of the function [id] to the already-typed arguments [es] and returns the typed
application (built by [mk_fun_app], with arguments promoted as the matched
signature requires).

The name is looked up in [tenv] under [Utils.normalized_name id.name]. Every
failure is raised through [error]:
- the name is bound to a variable and is not a Stan Math function name:
[returning_fn_expected_nonfn_found];
- nothing is bound: one of several "undeclared" errors, chosen from the
distribution suffix of the name (see the comments below);
- the unique matching signature returns [Void]:
[returning_fn_expected_nonreturning_found];
- several signatures match equally well: [ambiguous_function_promotion];
- no signature matches: [illtyped_fn_app]. *)
let check_normal_fn ~is_cond_dist loc tenv id es =
match Env.find tenv (Utils.normalized_name id.name) with
| {kind= `Variable {location= prev; _}; _} :: _
(* variables can sometimes shadow stanlib functions, so we have to check
this *)
when not
(Stan_math_signatures.is_stan_math_function_name
(Utils.normalized_name id.name)) ->
Semantic_error.returning_fn_expected_nonfn_found loc id.name (Some prev)
|> error
| [] ->
(* nothing is bound under this name: pick the most helpful error *)
(match Utils.split_distribution_suffix id.name with
| Some (prefix, suffix) -> (
let is_known_family s =
List.Assoc.mem Stan_math_signatures.distributions s
~equal:String.equal in
match suffix with
(* a pmf suffix was written but only the pdf exists in [tenv] *)
| ("lpmf" | "lupmf") when Env.mem tenv (prefix ^ "_lpdf") ->
Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc
(prefix, suffix)
(* a pdf suffix was written but only the pmf exists in [tenv] *)
| ("lpdf" | "lupdf") when Env.mem tenv (prefix ^ "_lpmf") ->
Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc
(prefix, suffix)
| _ ->
(* a known distribution family, but this particular suffix is not
declared for it *)
if
is_known_family prefix
&& List.mem ~equal:String.equal
Utils.cumulative_distribution_suffices_w_rng suffix
then
Semantic_error
.returning_fn_expected_undeclared_dist_suffix_found loc
(prefix, suffix)
else
Semantic_error.returning_fn_expected_undeclaredident_found loc
id.name
(Env.nearest_ident tenv id.name))
| None ->
Semantic_error.returning_fn_expected_undeclaredident_found loc
id.name
(Env.nearest_ident tenv id.name))
|> error
| _ (* a function *) -> (
(* NB: At present, [SignatureMismatch.matching_function] cannot handle
overloaded function types. This is not needed until UDFs can be
higher-order, as it is special cased for variadic functions *)
match
SignatureMismatch.matching_function tenv id.name (get_arg_types es)
with
| UniqueMatch (Void, _, _, prev) ->
Semantic_error.returning_fn_expected_nonreturning_found loc id.name
prev
|> error
| UniqueMatch (ReturnType ut, fnk, promotions, _) ->
(* [fnk] builds the function kind from the suffix derived from the name
written at the call site *)
mk_fun_app ~is_cond_dist ~loc
(fnk (Fun_kind.suffix_from_name id.name))
id
(Promotion.promote_list es promotions)
~type_:ut
| AmbiguousMatch sigs ->
Semantic_error.ambiguous_function_promotion loc id.name
(Some (List.map ~f:type_of_expr_typed es))
sigs
|> error
| SignatureErrors (l, b) ->
es
|> List.map ~f:(fun e -> e.emeta.type_)
|> Semantic_error.illtyped_fn_app loc id.name (l, b)
|> error)
(** [find_matching_first_order_fn tenv matches fname] looks up every entry
    bound in [tenv] under [Utils.stdlib_distribution_name fname.name] and
    applies the constraint function [matches] to each of them.

    - If [SignatureMismatch.unique_minimum_promotion] selects one of the [Ok]
      results, the outcome is a [UniqueMatch] of it.
    - If it reports several equally good candidates, the outcome is an
      [AmbiguousMatch] listing the function-typed ones.
    - Otherwise the outcome is [SignatureErrors] carrying the first [Error]
      produced by [matches].

    NOTE(review): the last case uses [List.hd_exn], so it raises if there are
    no errors to report, e.g. when the lookup finds nothing. Callers appear to
    pass names that were already resolved -- confirm. *)
let find_matching_first_order_fn tenv matches fname =
  let bound_entries =
    Env.find tenv (Utils.stdlib_distribution_name fname.name) in
  let ok, errs =
    List.partition_map bound_entries ~f:(fun entry ->
        Result.to_either (matches entry)) in
  match SignatureMismatch.unique_minimum_promotion ok with
  | Ok a -> SignatureMismatch.UniqueMatch a
  | Error (Some tys) ->
      SignatureMismatch.AmbiguousMatch
        (List.filter_map tys ~f:(fun (ty, loc) ->
             match ty with
             | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args, loc)
             | _ -> None))
  | Error None -> SignatureMismatch.SignatureErrors (List.hd_exn errs)
(** [make_function_variable cf loc id fn_type] is the typed [Variable]
    expression referring to the function [id], whose type is [fn_type].

    For a function type with an [FnLpdf] or [FnLpmf] suffix, the suffix is
    replaced by the one derived from the name [id]; other function types are
    used unchanged. Passing a non-function type is an internal compiler
    error. *)
let make_function_variable cf loc id fn_type =
  let type_ =
    match fn_type with
    | UnsizedType.UFun (args, rt, (FnLpdf _ | FnLpmf _), mem_pattern) ->
        UnsizedType.UFun
          (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern)
    | UnsizedType.UFun _ -> fn_type
    | type_ ->
        Common.ICE.internal_compiler_error
          [%message
            "Attempting to create function variable out of "
              (type_ : UnsizedType.t)] in
  mk_typed_expression ~expr:(Variable id)
    ~ad_level:(calculate_autodifftype cf Functions type_)
    ~type_ ~loc
(** Check that the functions in the list [requires_higher_order_autodiff] cannot
{b transitively} call stan math functions that don't have second derivative
support.
Note that this does not re-do overload resolution, so it implements a
correct-but-overly-conservative check.

Raises [laplace_compatibility] (via [error]) at the first offending call found.
The error is given both the location of the registered use that required higher
order autodiff and the location of the offending library call. *)
let verify_second_order_derivative_compatibility (ast : typed_program) =
(* statements of the bodies of all definitions named [fn_name] in the functions
block; definitions whose body is not a block contribute nothing *)
let get_function_bodies fn_name =
List.concat_map (Ast.get_stmts ast.functionblock) ~f:(fun s ->
match s.stmt with
| FunDef {funname= {name; _}; body= {stmt= Block b; _}; _}
when String.equal name fn_name ->
b
| _ -> []) in
(* [check_fun visited f] checks everything reachable from the function named by
[f] and returns [visited] extended with the names that were checked. [id_loc]
is the use-site location and is threaded unchanged through nested calls. *)
let rec check_fun (visited : String.Set.t) {name= fn_name; id_loc} =
(* already checked (or currently being checked): stop, which also terminates
recursive call chains *)
if Set.mem visited fn_name then visited
else
let rec check_expr seen = function
| {expr= FunApp (StanLib _, {name; id_loc= call_loc}, _); _}
when Stan_math_signatures.lacks_higher_order_autodiff name ->
(* note: we could possibly check all the arguments are DataOnly and
still allow it, but those seem like mostly useless cases. *)
Semantic_error.laplace_compatibility id_loc name call_loc |> error
| {expr= FunApp (UserDefined _, name, es); _} ->
(* we want the location to be the use-site no matter what *)
let seen' = check_fun seen {name with id_loc} in
List.fold ~f:check_expr ~init:seen' es
(* a function passed around as a value is treated like a call to it *)
| {expr= Variable name; emeta= {type_= UFun _; _}} ->
check_fun seen {name with id_loc}
| e -> Ast.fold_expression check_expr seen e.expr in
let check_lval acc l = fold_lval_with check_expr acc l in
let rec check_stmt seen s =
match s.stmt with
| NRFunApp (UserDefined _, name, es) ->
let seen' = check_fun seen {name with id_loc} in
List.fold ~f:check_expr ~init:seen' es
| stmt -> Ast.fold_statement check_expr check_stmt check_lval seen stmt
in
(* mark this function as visited before descending into its bodies *)
let visited' = Set.add visited fn_name in
List.fold ~f:check_stmt ~init:visited' (get_function_bodies fn_name) in
(* only the raised error matters; the final visited set is discarded *)
ignore
(List.fold ~f:check_fun ~init:String.Set.empty
!requires_higher_order_autodiff)
(** Currently only used by the laplace functions, this checks that a function in
[tenv] called [fname] can be invoked with the arguments from [arg_tupl].

- [caller_id] is the function on whose behalf the check is made; it is only
used in error messages.
- [required_args] is an optional list of [(name, type)] pairs for leading
arguments that the callback must accept exactly (no promotion), ahead of the
arguments forwarded from [arg_tupl].
- [arg_tupl] must be a tuple expression; its components are the forwarded
arguments.
- [required_fn_return_type] is the return type the callback must have.

Returns the pair [(fn, args)]: the typed function variable for [fname] and
[arg_tupl] promoted to the parameter types of the selected signature.

Every failure is raised through [error]: [arg_tupl] is not a tuple, [fname] is
not a function, its signature does not meet the requirements, the supplied
arguments do not fit, or the choice of overload is ambiguous. *)
let check_function_callable_with_tuple cf tenv caller_id fname
?(required_args = []) arg_tupl required_fn_return_type =
let arg_types =
check_texpression_is_tuple arg_tupl
(Printf.sprintf "Forwarded arguments to '%s' in call to '%s'" fname.name
caller_id.name) in
let required_arg_names, required_arg_types = List.unzip required_args in
(* the full argument list the callback is checked against *)
let required = required_arg_types @ arg_types in
(* constraint applied to each environment entry bound to [fname] *)
let matches info =
let location = Env.location info in
match info with
| Env.{type_= UnsizedType.UFun (args, return_type, sfx, _) as fn_type; _} ->
let open SignatureMismatch in
let open Common.Let_syntax.Result in
if return_type <> required_fn_return_type then
Error
(`FnRequirementsError
( ReturnTypeMismatch (required_fn_return_type, return_type)
, location ))
else if sfx <> FnPlain then
Error
(`FnRequirementsError
( SuffixMismatch (FnPlain, Fun_kind.forget_normalization sfx)
, location ))
else
(* the leading parameters that correspond to [required_args] *)
let no_prom_args, _ =
List.split_n args (List.length required_arg_types) in
let* () =
(let* () =
check_compatible_arguments_no_promotion required_arg_types
no_prom_args in
(* checking both ways around as this is the best way to catch
DataOnly misspecifications for these arguments *)
check_compatible_arguments_no_promotion no_prom_args
required_arg_types)
|> Result.map_error ~f:(fun x ->
`FnRequirementsError (InputMismatch x, location)) in
let+ promotions =
check_compatible_arguments_mod_conv args required
|> Result.map_error ~f:(fun x ->
`SuppliedArgsMismatch (InputMismatch x, location)) in
((fn_type, location), promotions)
| _ -> Error (`NonFunction location) in
match find_matching_first_order_fn tenv matches fname with
| SignatureMismatch.UniqueMatch ((ftype, _), promotions) ->
let fn = make_function_variable cf fname.id_loc fname ftype in
(* drop the promotions for the required leading arguments; the rest apply
to the components of the forwarded tuple *)
let args =
Promotion.promote arg_tupl
(Promotion.TuplePromotion
(snd @@ List.(split_n promotions (length required_arg_types))))
in
(fn, args)
| AmbiguousMatch ps ->
Semantic_error.ambiguous_function_promotion fname.id_loc fname.name None
ps
|> error
| SignatureErrors (`NonFunction prev) ->
Semantic_error.returning_fn_expected_nonfn_found fname.id_loc fname.name
prev
|> error
| SignatureErrors (`FnRequirementsError (details, prev)) ->
Semantic_error.forwarded_function_signature_error fname.id_loc
caller_id.name fname.name details prev
|> error
| SignatureErrors (`SuppliedArgsMismatch (details, prev)) ->
Semantic_error.forwarded_function_application_error arg_tupl.emeta.loc
caller_id.name fname.name required_arg_names details prev
|> error
(** [specialize_loc ~loc err args] narrows an error location to the argument
    that [err] is about, falling back to [loc] when [args] has no element at
    the computed position.

    For an [ArgError (i, _)] the position is [i - 1]. For an
    [ArgNumMismatch (expected, found)] it is [found - 1] when more arguments
    were expected than found, and [expected] otherwise. *)
let specialize_loc ~loc err (args : Ast.typed_expression list) =
  let position =
    match err with
    | SignatureMismatch.ArgError (i, _) -> i - 1
    | ArgNumMismatch (expected, found) ->
        if expected > found then found - 1 else expected in
  match List.nth args position with
  | Some offending -> offending.emeta.loc
  | None -> loc
(** [verify_laplace_control_args loc id args] checks the trailing control
    arguments of the laplace function [id].

    When the name of [id] does not contain ["_tol"], no control arguments are
    allowed. When it does, exactly one argument is required: a tuple whose
    components are compatible with
    [Stan_math_signatures.laplace_tolerance_argument_types].

    Raises (via {!error}) [illtyped_laplace_extra_args] for surplus arguments
    and [illtyped_laplace_tolerance_args] for a missing or ill-typed tolerance
    tuple; a non-tuple argument is rejected by [check_texpression_is_tuple]. *)
let verify_laplace_control_args loc id (args : typed_expression list) =
  let expects_tolerances = String.is_substring ~substring:"_tol" id.name in
  if not expects_tolerances then (
    match args with
    | [] -> ()
    | first_extra :: _ ->
        error
          (Semantic_error.illtyped_laplace_extra_args first_extra.emeta.loc
             id.name (List.length args)))
  else
    match args with
    | [] ->
        (* report the mismatch obtained by checking an empty argument list *)
        let mismatch =
          SignatureMismatch.check_compatible_arguments_mod_conv
            Stan_math_signatures.laplace_tolerance_argument_types []
          |> Result.error |> Option.value_exn in
        error
          (Semantic_error.illtyped_laplace_tolerance_args loc id.name mismatch)
    | [arg] -> (
        let arg_tys =
          check_texpression_is_tuple arg
            ("Control arguments for '" ^ id.name ^ "'") in
        match
          SignatureMismatch.check_compatible_arguments_mod_conv
            Stan_math_signatures.laplace_tolerance_argument_types arg_tys
        with
        | Ok _ -> ()
        | Error err ->
            (* point at the offending tuple element when the tuple is written
               out literally *)
            let elements =
              match arg.expr with TupleExpr elts -> elts | _ -> [] in
            let loc = specialize_loc ~loc err elements in
            error
              (Semantic_error.illtyped_laplace_tolerance_args loc id.name err))
    | _ :: first_extra :: _ ->
        error
          (Semantic_error.illtyped_laplace_extra_args first_extra.emeta.loc
             id.name
             (List.length args - 1))
(** [check_fn ~is_cond_dist loc cf tenv id tes] typechecks the application of
the function [id] to the typed arguments [tes], dispatching on the name of
[id]: variadic Stan Math functions, reduce_sum functions and embedded laplace
functions each have a dedicated checker, and every other name goes through
[check_normal_fn]. The tests are tried in that order. *)
let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list)
=
if Stan_math_signatures.is_stan_math_variadic_function_name id.name then
check_variadic ~is_cond_dist loc cf tenv id tes
else if Stan_math_signatures.is_reduce_sum_fn id.name then
check_reduce_sum ~is_cond_dist loc cf tenv id tes
else if Stan_math_signatures.is_embedded_laplace_fn id.name then
check_laplace_fn ~is_cond_dist loc cf tenv id tes
(* ordinary functions do not need the context flags *)
else check_normal_fn ~is_cond_dist loc tenv id tes
(** Reduce sum is a special case, even compared to the other variadic functions,
because it is polymorphic in the type of the first argument. The first,
fourth, and fifth arguments must agree, which is too complicated to be
captured declaratively.

[tes] holds the typed arguments of the call: the callback first, then the
sliced argument, then the rest. On success the result is the typed application
with type [UReal]. Every failure is raised through [error]. *)
and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
(* generic check against a real-array slice; used when the callback's own
signature cannot supply the expected sliced-argument type *)
let basic_mismatch () =
let mandatory_args =
UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in
let mandatory_fun_args =
UnsizedType.
[(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] in
SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args
mandatory_fun_args None UReal (get_arg_types tes) in
(* constraint applied to each environment entry bound to the callback name:
the expected sliced-argument type is taken from the candidate's first
parameter *)
let matching remaining_es fn =
match fn with
| Env.{type_= UnsizedType.UFun (sliced_arg_fun :: _, _, _, _) as ftype; _}
->
let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in
let mandatory_fun_args =
[sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in
let arg_types =
(calculate_autodifftype cf Functions ftype, ftype)
:: get_arg_types remaining_es in
SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args
mandatory_fun_args (Env.location fn) UReal arg_types
| _ -> basic_mismatch () in
match tes with
(* the callback is written as a plain name and is followed by at least the
sliced argument *)
| {expr= Variable fname; _}
:: ({emeta= {type_= slice_type; loc= slice_loc; _}; _} :: _ as remaining_es)
-> (
let slice_type, n = UnsizedType.unwind_array_type slice_type in
(* the sliced argument must be an array ... *)
if n = 0 then
Semantic_error.illtyped_reduce_sum_not_array slice_loc slice_type
|> error
(* ... whose element type is one of the supported slice types *)
else if
not
@@ List.mem Stan_math_signatures.reduce_sum_slice_types slice_type
~equal:( = )
then
Semantic_error.illtyped_reduce_sum_slice slice_loc slice_type |> error;
match find_matching_first_order_fn tenv (matching remaining_es) fname with
| SignatureMismatch.UniqueMatch ((ftype, _), promotions) ->
(* a valid signature exists *)
let tes = make_function_variable cf loc fname ftype :: remaining_es in
mk_fun_app ~is_cond_dist ~loc (StanLib FnPlain) id
(Promotion.promote_list tes promotions)
~type_:UnsizedType.UReal
| AmbiguousMatch ps ->
Semantic_error.ambiguous_function_promotion loc fname.name None ps
|> error
| SignatureErrors (expected_args, err, prev) ->
let loc = specialize_loc ~loc err tes in
Semantic_error.illtyped_reduce_sum loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err prev
|> error)
| _ ->
(* any other argument shape: report it using the generic signature *)
let expected_args, err, prev =
basic_mismatch () |> Result.error |> Option.value_exn in
let loc = specialize_loc ~loc err tes in
Semantic_error.illtyped_reduce_sum loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err prev
|> error
(** Embedded Laplace functions get a dedicated checker for two reasons:
    + the general forms take two user-defined callbacks rather than one;
    + each callback's arguments arrive packed in a tuple rather than as
      trailing arguments (a consequence of the first point, but worth calling
      out separately). *)
and check_laplace_fn ~is_cond_dist loc cf tenv id tes =
  (* Catch-all error for when the wrong number of arguments is passed, so
     that pattern matching cannot single out a specific argument to blame. *)
  let generic_failure ?(early = false) () =
    error
      (Semantic_error.illtyped_laplace_generic loc id.name early
         (List.map ~f:arg_type tes)) in
  (* Stage 1: the likelihood arguments. Helper variants have a non-empty list
     of required probability arguments; every other form takes a callback
     followed by a tuple of its arguments. *)
  let lik_args, after_lik =
    match Stan_math_signatures.laplace_helper_param_types id.name with
    | [] -> (
      (* likelihood callback check *)
      match tes with
      | {expr= Variable lik_name; _} :: lik_tuple :: remaining ->
          (* registers the function name in the global list that is checked
             later *)
          needs_higher_order_autodiff lik_name;
          let lik_fn, lik_tuple =
            check_function_callable_with_tuple cf tenv id lik_name lik_tuple
              ~required_args:
                [("latent gaussian vector", (AutoDiffable, UVector))]
              (UnsizedType.ReturnType UReal) in
          ([lik_fn; lik_tuple], remaining)
      | _ -> generic_failure ~early:true ())
    | _ :: _ as helper_params -> (
      (* helper argument check: the leading arguments must be compatible
         with the helper's parameter types *)
      let for_prob, remaining =
        List.split_n tes (List.length helper_params) in
      match
        SignatureMismatch.check_compatible_arguments_mod_conv helper_params
          (get_arg_types for_prob)
      with
      | Ok promotions ->
          (Promotion.promote_list for_prob promotions, remaining)
      | Error err ->
          let err_loc = specialize_loc ~loc err tes in
          error
            (Semantic_error.illtyped_laplace_helper_args err_loc id.name
               helper_params (SignatureMismatch.InputMismatch err))) in
  (* Stage 2: the hessian block size, which must be a data-only int. *)
  let hbs_arg, after_hbs =
    match after_lik with
    | hbs :: remaining ->
        let hbs_ty = arg_type hbs in
        if hbs_ty = UnsizedType.(DataOnly, UInt) then (hbs, remaining)
        else
          error
            (Semantic_error.illtyped_laplace_hessian_block_size_arg
               hbs.emeta.loc id.name (Some hbs_ty))
    | [] ->
        (* Nothing follows the likelihood arguments. Report at a span that
           begins where the last likelihood argument ends, or at the whole
           call when there is no such argument. *)
        let missing_loc =
          match List.last lik_args with
          | Some e -> {e.emeta.loc with begin_loc= e.emeta.loc.end_loc}
          | None -> loc in
        error
          (Semantic_error.illtyped_laplace_hessian_block_size_arg missing_loc
             id.name None) in
  (* Stage 3: the covariance callback, its argument tuple, and the trailing
     control arguments. *)
  match after_hbs with
  | {expr= Variable cov_name; _} :: cov_tuple :: control_args ->
      let cov_fn, cov_tuple =
        check_function_callable_with_tuple cf tenv id cov_name cov_tuple
          (UnsizedType.ReturnType UMatrix) in
      (* note for future: pred_rng typechecking would need to look at a
         second tuple for the training prediction and test prediction. This
         would probably require two more calls to
         [check_function_callable_with_tuple] *)
      verify_laplace_control_args loc id control_args;
      let return_type =
        match String.is_suffix id.name ~suffix:"_rng" with
        | true -> UnsizedType.UVector
        | false -> UnsizedType.UReal in
      let all_args =
        lik_args @ (hbs_arg :: cov_fn :: cov_tuple :: control_args) in
      mk_fun_app ~is_cond_dist ~loc
        (StanLib (Fun_kind.suffix_from_name id.name))
        id all_args ~type_:return_type
  | _ -> generic_failure ()
(** Type-check a call to a variadic Stan Math function whose signature is
    described by the record returned from
    [Stan_math_signatures.lookup_stan_math_variadic_function]. The first
    argument is expected to name the callback; the remaining arguments are
    checked against the callback's signature together with the function's
    control arguments. The lookup is unwrapped with [Option.value_exn], so
    [id] must name a registered variadic function. *)
and check_variadic ~is_cond_dist loc cf tenv id tes =
  let UnsizedType.
        { control_args= ctrl_tys
        ; required_fn_args= fn_arg_tys
        ; required_fn_rt= fn_rt
        ; return_type= result_ty } =
    Stan_math_signatures.lookup_stan_math_variadic_function id.name
    |> Option.value_exn in
  (* Signature check for a single candidate callback from the environment. *)
  let check_candidate trailing_es (Env.{type_= fn_ty; _} as candidate) =
    let actual_tys =
      (calculate_autodifftype cf Functions fn_ty, fn_ty)
      :: get_arg_types trailing_es in
    SignatureMismatch.check_variadic_args ~allow_lpdf:false ctrl_tys
      fn_arg_tys (Env.location candidate) fn_rt actual_tys in
  (* Raise the variadic signature error, located via [specialize_loc]
     applied to the full argument list. *)
  let report_mismatch (expected_args, err, prev) =
    let err_loc = specialize_loc ~loc err tes in
    error
      (Semantic_error.illtyped_variadic err_loc id.name
         (List.map ~f:type_of_expr_typed tes)
         expected_args fn_rt err prev) in
  match tes with
  | {expr= Variable fname; _} :: trailing_es -> (
    match
      find_matching_first_order_fn tenv (check_candidate trailing_es) fname
    with
    | SignatureMismatch.UniqueMatch ((fn_ty, _), promotions) ->
        let call_args =
          make_function_variable cf loc fname fn_ty :: trailing_es in
        mk_fun_app ~is_cond_dist ~loc (StanLib FnPlain) id
          (Promotion.promote_list call_args promotions)
          ~type_:result_ty
    | AmbiguousMatch ps ->
        error
          (Semantic_error.ambiguous_function_promotion loc fname.name None ps)
    | SignatureErrors (expected_args, err, prev) ->
        report_mismatch (expected_args, err, prev))
  | _ ->
      (* First argument is not a function name: re-run the check without a
         callback so that the resulting mismatch can be reported. *)
      SignatureMismatch.check_variadic_args ~allow_lpdf:false ctrl_tys
        fn_arg_tys None fn_rt (get_arg_types tes)
      |> Result.error |> Option.value_exn |> report_mismatch
(** Type-check the application of function [id] to the checked arguments
    [es], then run the identifier, naming, and context validations on the
    call. The argument check via [check_fn] runs first, so any error it
    raises takes precedence over the later validations. Returns the result
    of [check_fn]. *)
and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list) =
  let typed_call = check_fn ~is_cond_dist loc cf tenv id es in
  verify_identifier id;
  (* conditional-distribution calls and plain calls have different naming
     rules *)
  if is_cond_dist then verify_conddist_name loc id
  else verify_fn_conditioning loc id;
  verify_fn_target_plus_equals cf loc id;
  verify_fn_jacobian_plus_equals cf loc tenv id es;
  verify_fn_rng cf loc id;
  verify_unnormalized cf loc id;
  typed_call
(** Type-check the indexed expression formed by applying [indices] to [e].
    The indices are checked before [e] itself; the type and autodiff level
    of the result are then inferred from the checked pieces. *)
and check_indexed loc cf tenv e indices =
  let checked_indices =
    List.map indices ~f:(fun ix -> check_index cf tenv ix) in
  let base = check_expression cf tenv e in
  let indexed_ty =
    inferred_unsizedtype_of_indexed ~loc base.emeta.type_ checked_indices in
  let indexed_ad =
    inferred_ad_type_of_indexed base.emeta.ad_level indexed_ty checked_indices
  in
  mk_typed_expression
    ~expr:(Indexed (base, checked_indices))
    ~ad_level:indexed_ad ~type_:indexed_ty ~loc
and check_index cf tenv = function
| All -> All
(* Check that indexes have int (container) type *)
| Single e ->
let te = check_expression cf tenv e in
if has_int_type te || has_int_array_type te then Single te
else
Semantic_error.int_intarray_or_range_expected te.emeta.loc
te.emeta.type_