Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions src/analysis_and_optimization/Memory_patterns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -562,12 +562,14 @@ let rec modify_stmt_pattern
; initialize=
Assign
({ pattern= FunApp (CompilerInternal (FnReadParam read_param), args)
; _ } as assigner) } ->
; _ } as assigner)
; decl_annotations } ->
let name = decl_id in
if Set.mem modifiable_set name then
Stmt.Fixed.Pattern.Decl
{ decl_id
; decl_adtype
; decl_annotations
; decl_type=
Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type)
; initialize=
Expand All @@ -582,6 +584,7 @@ let rec modify_stmt_pattern
Stmt.Fixed.Pattern.Decl
{ decl_id
; decl_adtype
; decl_annotations
; decl_type=
Type.Sized (SizedType.modify_sizedtype_mem SoA sized_type)
; initialize=
Expand Down Expand Up @@ -682,15 +685,18 @@ and modify_stmt (Stmt.Fixed.{pattern; _} as stmt)

let collect_mem_pattern_variables stmts =
let take_stmt acc = function
| Stmt.Fixed.{pattern= Decl {decl_id; decl_type= Type.Sized stype; _}; _}
| Stmt.Fixed.
{ pattern=
Decl {decl_id; decl_type= Type.Sized stype; decl_annotations; _}
; meta }
when SizedType.has_mem_pattern stype ->
(decl_id, stype) :: acc
(decl_id, meta, stype, decl_annotations) :: acc
| _ -> acc in
Mir_utils.fold_stmts ~take_expr:(fun acc _ -> acc) ~take_stmt ~init:[] stmts
|> List.rev

let pp_mem_patterns ppf (Program.{reverse_mode_log_prob; _} : Program.Typed.t) =
let pp_var ppf (name, stype) =
let pp_var ppf (name, _, stype, _) =
Fmt.pf ppf "%a %s: %a"
(SizedType.pp Expr.Typed.pp)
stype name Middle.Mem_pattern.pp
Expand All @@ -699,3 +705,22 @@ let pp_mem_patterns ppf (Program.{reverse_mode_log_prob; _} : Program.Typed.t) =
(* Collect all the sizedtypes which have a mem pattern *)
collect_mem_pattern_variables reverse_mode_log_prob in
Fmt.(pf ppf "@[<v>%a@.@]" (list pp_var)) mem_vars

let check_annotations (Program.{reverse_mode_log_prob; _} : Program.Typed.t) =
let mem_vars = collect_mem_pattern_variables reverse_mode_log_prob in
List.filter_map
~f:(fun (name, loc, stype, annotations) ->
match annotations with
| [] -> None
| _ ->
if
List.exists ~f:(fun x -> x = "debug_soa") annotations
&& SizedType.get_mem_pattern stype <> Mem_pattern.SoA
then
Some
( loc
, "Variable '" ^ name
^ "' was marked with '@debug_soa' but is not SoA after \
optimization." )
else None)
mem_vars
29 changes: 24 additions & 5 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,18 @@ let gen_inline_var (name : string) (id_var : string) =

let replace_fresh_local_vars (fname : string) stmt =
let f (m : (string, string) Core.Map.Poly.t) = function
| Stmt.Fixed.Pattern.Decl {decl_adtype; decl_type; decl_id; initialize} ->
| Stmt.Fixed.Pattern.Decl
{decl_adtype; decl_type; decl_id; decl_annotations; initialize} ->
let new_name =
match Map.Poly.find m decl_id with
| Some existing -> existing
| None -> gen_inline_var fname decl_id in
( Stmt.Fixed.Pattern.Decl
{decl_adtype; decl_id= new_name; decl_type; initialize}
{ decl_adtype
; decl_id= new_name
; decl_type
; decl_annotations
; initialize }
, Map.Poly.set m ~key:decl_id ~data:new_name )
| Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} ->
let new_name =
Expand Down Expand Up @@ -201,6 +206,7 @@ let handle_early_returns (fname : string) opt_var stmt =
{ decl_adtype= DataOnly
; decl_id= returned
; decl_type= Sized SInt
; decl_annotations= []
; initialize= Default }
; meta= Location_span.empty }
; Stmt.Fixed.
Expand Down Expand Up @@ -306,6 +312,7 @@ let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e)
(Type.to_unsized decl_type)
; decl_id= inline_return_name
; decl_type
; decl_annotations= []
; initialize= Uninit } ]
(* We should minimize the code that's having its variables
replaced to avoid conflict with the (two) new dummy
Expand Down Expand Up @@ -476,10 +483,20 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.{pattern; meta} =
Block (List.map l ~f:(inline_function_statement propto adt fim))
| SList l ->
SList (List.map l ~f:(inline_function_statement propto adt fim))
| Decl {decl_adtype; decl_id; decl_type; initialize= Assign expr} ->
| Decl
{ decl_adtype
; decl_id
; decl_type
; decl_annotations
; initialize= Assign expr } ->
let d, s, e = inline_function_expression propto adt fim expr in
slist_concat_no_loc (d @ s)
(Decl {decl_adtype; decl_id; decl_type; initialize= Assign e})
(Decl
{ decl_adtype
; decl_id
; decl_type
; decl_annotations
; initialize= Assign e })
| Decl r -> Decl r
| Skip -> Skip
| Break -> Break
Expand Down Expand Up @@ -993,6 +1010,7 @@ let lazy_code_motion ?(preserve_stability = false) (mir : Program.Typed.t) =
{ decl_adtype= Expr.Typed.adlevel_of key
; decl_id= data
; decl_type= Type.Unsized (Expr.Typed.type_of key)
; decl_annotations= [] (* TODO annotations: correct? *)
; initialize= Default }
; meta= Location_span.empty }
:: accum) in
Expand Down Expand Up @@ -1375,4 +1393,5 @@ let optimization_suite ?(settings = all_optimizations) mir =
let optimizations =
List.filter_map maybe_optimizations ~f:(fun (fn, flag) ->
if flag then Some fn else None) in
List.fold optimizations ~init:mir ~f:(fun mir opt -> opt mir)
let mir = List.fold optimizations ~init:mir ~f:(fun mir opt -> opt mir) in
mir
5 changes: 5 additions & 0 deletions src/driver/Entry.ml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ let stan2cpp model_name model (flags : Flags.t) (output : other_output -> unit)
output (Warnings parser_warnings);
let* result =
let* ast = ast in
let unused_annotations =
Frontend.Annotations.find_unrecognized
Stan_math_backend.Annotations.recognized_annotation ast in
output (Warnings unused_annotations);
if flags.debug_settings.print_ast then
output (DebugOutput (fmt_sexp [%sexp (ast : Ast.untyped_program)]));
let+ typed_ast, type_warnings =
Expand Down Expand Up @@ -132,6 +136,7 @@ let stan2cpp model_name model (flags : Flags.t) (output : other_output -> unit)
Optimize.optimization_suite
~settings:(Flags.get_optimization_settings flags)
tx_mir in
output (Warnings (Memory_patterns.check_annotations opt_mir));
if flags.debug_settings.print_mem_patterns then
output
(Memory_patterns (Fmt.str "%a" Memory_patterns.pp_mem_patterns opt_mir));
Expand Down
40 changes: 40 additions & 0 deletions src/frontend/Annotations.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
open Ast
open Core
open Middle

let unknown a = "Unknown annotation '" ^ a ^ "' will be ignored by the compiler"

let unsupported a t =
"Unsupported annotation '" ^ a ^ "' for type " ^ t
^ ". This will be ignored by the compiler"

let unused_list pred id ty anns =
List.filter_map
~f:(fun a ->
match pred a ty with
| `Fine -> None
| `Unknown -> Some (id.id_loc, unknown a)
| `WrongType ->
Some (id.id_loc, unsupported a (Fmt.to_to_string UnsizedType.pp ty)))
anns

let rec collect_stmt pred (acc : Warnings.t list) {stmt; _} =
match stmt with
| FunDef {annotations; funname; returntype; arguments; _} ->
let args = List.map ~f:(fun (ad, ty, _) -> (ad, ty)) arguments in
let ty =
UnsizedType.UFun
( args
, returntype
, Fun_kind.suffix_from_name funname.name
, Mem_pattern.AoS ) in
acc @ (unused_list pred funname ty) annotations
| VarDecl {annotations; variables; decl_type; _} ->
acc
@ (unused_list pred (List.hd_exn variables).identifier
(SizedType.to_unsized decl_type))
annotations
| _ -> fold_statement Fn.const (collect_stmt pred) Fn.const Fn.const acc stmt

let find_unrecognized pred (prog : Ast.untyped_program) : Warnings.t list =
fold_program (collect_stmt pred) [] prog
4 changes: 4 additions & 0 deletions src/frontend/Annotations.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
val find_unrecognized :
(string -> Middle.UnsizedType.t -> [`Fine | `Unknown | `WrongType])
-> Ast.untyped_program
-> Warnings.t list
2 changes: 2 additions & 0 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,15 @@ type ('e, 's, 'l, 'f) statement =
{ decl_type: 'e SizedType.t
; transformation: 'e Transformation.t
; is_global: bool
; annotations: string list
; variables: 'e variable list }
| FunDef of
{ returntype: UnsizedType.returntype
; funname: identifier
; arguments:
(Middle.UnsizedType.autodifftype * Middle.UnsizedType.t * identifier)
list
; annotations: string list
; body: 's }
[@@deriving sexp, hash, compare, map, fold]

Expand Down
34 changes: 25 additions & 9 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -477,8 +477,8 @@ let var_constrain_check_stmts dconstrain loc adlevel decl_id decl_var trans
@ check_decl decl_var st trans loc adlevel
| _ -> []

let create_decl_with_assign decl_id declc decl_type initial_value transform
smeta =
let create_decl_with_assign decl_id declc decl_type initial_value
decl_annotations transform smeta =
let rhs = Option.map ~f:trans_expr initial_value in
let decl_adtype =
UnsizedType.fill_adtype_for_type declc.dadlevel (Type.to_unsized decl_type)
Expand All @@ -493,7 +493,12 @@ let create_decl_with_assign decl_id declc decl_type initial_value transform
let decl =
Stmt.
{ Fixed.pattern=
Decl {decl_adtype; decl_id; decl_type; initialize= Default}
Decl
{ decl_adtype
; decl_id
; decl_type
; decl_annotations
; initialize= Default }
; meta= smeta } in
let rhs_assignment =
Option.map
Expand Down Expand Up @@ -620,6 +625,7 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
{ decl_adtype= Expr.Typed.adlevel_of iteratee'
; decl_id= loopvar.name
; decl_type= Unsized decl_type
; decl_annotations= []
; initialize= Default } } in
let assignment var =
Stmt.Fixed.
Expand All @@ -635,15 +641,16 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
Common.ICE.internal_compiler_error
[%message
"Found function definition statement outside of function block"]
| Ast.VarDecl {decl_type; transformation; variables; is_global= _} ->
| Ast.VarDecl {decl_type; transformation; variables; annotations; is_global= _}
->
List.concat_map
~f:(fun {identifier; initial_value} ->
let transform = Transformation.map trans_expr transformation in
let decl_id = identifier.Ast.name in
let size_checks, dt = check_sizedtype decl_id decl_type in
size_checks
@ create_decl_with_assign decl_id declc dt initial_value transform
smeta)
@ create_decl_with_assign decl_id declc dt initial_value annotations
transform smeta)
variables
| Ast.Block stmts -> Block (List.concat_map ~f:trans_stmt stmts) |> swrap
| Ast.Profile (name, stmts) ->
Expand All @@ -666,6 +673,7 @@ and trans_packed_assign loc trans_stmt lvals rhs assign_op =
{ decl_adtype= rhs.emeta.ad_level
; decl_id= sym
; decl_type= Unsized rhs_type
; decl_annotations= []
; initialize= Uninit }
; meta= rhs.emeta.loc } in
let assign =
Expand Down Expand Up @@ -733,12 +741,13 @@ and trans_single_assignment smeta assign_lhs assign_rhs assign_op =

let trans_fun_def ud_dists (ts : Ast.typed_statement) =
match ts.stmt with
| Ast.FunDef {returntype; funname; arguments; body} ->
| Ast.FunDef {returntype; funname; arguments; annotations; body} ->
[ Program.
{ fdrt= returntype
; fdname= funname.name
; fdsuffix= Fun_kind.(suffix_from_name funname.name |> without_propto)
; fdargs= List.map ~f:trans_arg arguments
; fdannotations= annotations
; fdbody=
trans_stmt ud_dists
{transform_action= IgnoreTransform; dadlevel= AutoDiffable}
Expand Down Expand Up @@ -780,6 +789,7 @@ let rec trans_sizedtype_decl declc tr name st =
{ decl_type= Sized SInt
; decl_id
; decl_adtype= DataOnly
; decl_annotations= []
; initialize= Default }
; meta= e.meta.loc } in
let assign =
Expand Down Expand Up @@ -858,7 +868,12 @@ let trans_block ud_dists declc block prog =
let f stmt (accum1, accum2, accum3) =
match stmt with
| { Ast.stmt=
VarDecl {decl_type= type_; variables; transformation; is_global= true}
VarDecl
{ decl_type= type_
; variables
; transformation
; annotations
; is_global= true }
; smeta } ->
let outvars, sizes, stmts =
List.unzip3
Expand All @@ -877,10 +892,11 @@ let trans_block ud_dists declc block prog =
; out_unconstrained_st=
transform_sizedtype transform type_
; out_block= block
; out_annotations= annotations
; out_trans= transform } ) in
let stmts =
create_decl_with_assign decl_id declc (Sized type_)
initial_value transform smeta.loc in
initial_value annotations transform smeta.loc in
(outvar, size, stmts))
variables in
( outvars @ accum1
Expand Down
4 changes: 3 additions & 1 deletion src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,13 @@ let parens_lval = map_lval_with no_parens Fn.id
let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement =
let stmt =
match stmt with
| VarDecl {decl_type= d; transformation= t; variables; is_global} ->
| VarDecl
{decl_type= d; transformation= t; variables; annotations; is_global} ->
VarDecl
{ decl_type= Middle.SizedType.map no_parens d
; transformation= Middle.Transformation.map keep_parens t
; variables= List.map ~f:(map_variable no_parens) variables
; annotations
; is_global }
| For {loop_variable; lower_bound; upper_bound; loop_body} ->
For
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Environment.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type info =
; kind:
[ `Variable of varinfo
| `UserDeclared of Location_span.t
| `UserExtern of Location_span.t
| `StanMath
| `UserDefined ] }

Expand Down
2 changes: 2 additions & 0 deletions src/frontend/Environment.mli
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type info =
; kind:
[ `Variable of varinfo
| `UserDeclared of Location_span.t
| `UserExtern of Location_span.t
| `StanMath
| `UserDefined ] }

Expand All @@ -37,6 +38,7 @@ val add :
-> string
-> Middle.UnsizedType.t
-> [ `UserDeclared of Location_span.t
| `UserExtern of Location_span.t
| `StanMath
| `UserDefined
| `Variable of varinfo ]
Expand Down
Loading