diff --git a/analysis/analysisError.ml b/analysis/analysisError.ml index 413d7647436..4e9e6710e7b 100644 --- a/analysis/analysisError.ml +++ b/analysis/analysisError.ml @@ -1228,8 +1228,10 @@ let rec messages ~concise ~signature location kind = in let actual = match actual with - | Group actual -> - Format.asprintf "type parameter group `[%a]`" Type.OrderedTypes.pp_concise actual + | VariadicExpression (Group actual) -> + Format.asprintf "type parameter group `[%a]`" Type.OrderedTypes.pp_concise_record actual + | VariadicExpression expression -> + Format.asprintf "type parameter group `%a`" Type.OrderedTypes.pp_concise expression | Single actual -> Format.asprintf "single type `%a`" Type.pp actual | CallableParameters actual -> Format.asprintf "callable parameters `%a`" Type.Callable.pp_parameters actual diff --git a/analysis/attributeResolution.ml b/analysis/attributeResolution.ml index eecf40c4ef7..f4d0e4fbe0d 100644 --- a/analysis/attributeResolution.ml +++ b/analysis/attributeResolution.ml @@ -887,22 +887,22 @@ class base class_metadata_environment dependency = else Type.Parameter.Single given, None | Unary _, CallableParameters _ - | Unary _, Type.Parameter.Group _ -> + | Unary _, Type.Parameter.VariadicExpression (Group _) -> ( Single Any, Some { name; kind = UnexpectedKind { expected = generic; actual = given } } ) | ListVariadic _, CallableParameters _ | ListVariadic _, Single _ -> - ( Group Any, + ( VariadicExpression (Group Any), Some { name; kind = UnexpectedKind { expected = generic; actual = given } } ) | ParameterVariadic _, Single _ - | ParameterVariadic _, Group _ -> + | ParameterVariadic _, VariadicExpression (Group _) -> ( CallableParameters Undefined, Some { name; kind = UnexpectedKind { expected = generic; actual = given } } ) | ParameterVariadic _, CallableParameters _ - | ListVariadic _, Group _ -> + | _, VariadicExpression _ -> (* TODO(T47346673): accept w/ new kind of validation *) given, None in @@ -920,7 +920,7 @@ class base class_metadata_environment dependency = name (List.map generics ~f:(function | Type.Variable.Unary _ -> Type.Parameter.Single Type.Any - | ListVariadic _ -> Group Any + | ListVariadic _ -> VariadicExpression (Group Any) | ParameterVariadic _ -> CallableParameters Undefined)), false ) in @@ -1740,6 +1740,7 @@ class base class_metadata_environment dependency = match generics with | [ListVariadic variable] -> let meta_generics = + let variable = Type.OrderedTypes.Concatenation.Middle.Variadic variable in Type.OrderedTypes.Concatenation.Middle.create ~variable ~mappers:["type"] |> Type.OrderedTypes.Concatenation.create in @@ -1764,7 +1765,7 @@ class base class_metadata_environment dependency = (* TODO:(T60536033) We'd really like to take FiniteList[Ts], but without that we can't actually return the correct metatype, which is a bummer *) - Type.Parameter.Group Any, Type.Any + Type.Parameter.VariadicExpression (Group Any), Type.Any | ParameterVariadic _ -> (* TODO:(T60536033) We'd really like to take FiniteList[Ts], but without that we can't actually return the correct metatype, which @@ -3155,10 +3156,12 @@ class base class_metadata_environment dependency = match kind with | SingleStar -> ( match resolved with - | Type.Tuple (Bounded ordered_types) -> Either.First ordered_types + | Type.Tuple (Bounded ordered_types) -> + Either.First (Type.OrderedTypes.Group ordered_types) (* We don't support expanding indefinite containers into ListVariadics *) | annotation -> Either.Second { expression; annotation } ) - | _ -> Either.First (Type.OrderedTypes.Concrete [resolved]) + | _ -> + Either.First (Type.OrderedTypes.Group (Type.OrderedTypes.Concrete [resolved])) in List.rev arguments |> List.partition_map ~f:extract in @@ -3172,7 +3175,7 @@ class base class_metadata_environment dependency = let concatenate extracted = let concatenated = match extracted with - | [] -> Some (Type.OrderedTypes.Concrete []) + | [] -> Some (Type.OrderedTypes.Group (Type.OrderedTypes.Concrete [])) | head :: tail -> let concatenate sofar next = sofar >>= fun left -> Type.OrderedTypes.concatenate ~left ~right:next @@ -3212,7 +3215,7 @@ class base class_metadata_environment dependency = match key, data with | Parameter.Variable (Concatenation concatenation), arguments -> bind_arguments_to_variadic - ~expected:(Type.OrderedTypes.Concatenation concatenation) + ~expected:(Type.OrderedTypes.Group (Type.OrderedTypes.Concatenation concatenation)) ~arguments | Parameter.Variable _, [] | Parameter.Keywords _, [] -> diff --git a/analysis/classHierarchy.ml b/analysis/classHierarchy.ml index 9b421116e1b..efe4ef89bf6 100644 --- a/analysis/classHierarchy.ml +++ b/analysis/classHierarchy.ml @@ -174,17 +174,17 @@ let immediate_parents (module Handler : Handler) class_name = let clean not_clean = - let open Type.OrderedTypes.Concatenation in List.map not_clean ~f:(function - | Type.Parameter.Single (Type.Variable variable) -> Some (Type.Variable.Unary variable) - | Group (Type.OrderedTypes.Concatenation concatenation) -> - unwrap_if_only_middle concatenation - >>= Middle.unwrap_if_bare - >>| fun variable -> Type.Variable.ListVariadic variable + | Type.Parameter.Single (Type.Variable variable) -> [Type.Variable.Unary variable] + | VariadicExpression expression -> + Type.Variable.GlobalTransforms.ListVariadic.collect_all + (Type.Parametric { name = ""; parameters = [VariadicExpression expression] }) + |> List.map ~f:(fun variable -> Type.Variable.ListVariadic variable) | CallableParameters (ParameterVariadicTypeVariable { head = []; variable }) -> - Some (ParameterVariadic variable) - | _ -> None) - |> Option.all + [ParameterVariadic variable] + | _ -> []) + |> List.concat + |> Option.some let variables ?(default = None) (module Handler : Handler) = function @@ -303,7 +303,7 @@ let instantiate_successors_parameters ((module Handler : Handler) as handler) ~s | Type.Bottom -> let to_any = function | Type.Variable.Unary _ -> Type.Parameter.Single Type.Any - | ListVariadic _ -> Group Any + | ListVariadic _ -> VariadicExpression (Group Any) | ParameterVariadic _ -> CallableParameters Undefined in index_of target @@ -315,12 +315,13 @@ let instantiate_successors_parameters ((module Handler : Handler) as handler) ~s let split = match Type.split source with | Primitive primitive, _ when not (contains handler primitive) -> None - | Primitive "tuple", [Type.Parameter.Group parameters] -> + | Primitive "tuple", [Type.Parameter.VariadicExpression (Group parameters)] -> Some ( "tuple", [ Type.Parameter.Single - (Type.weaken_literals (Type.OrderedTypes.union_upper_bound parameters)); + (Type.weaken_literals + (Type.OrderedTypes.union_upper_bound (Type.OrderedTypes.Group parameters))); ] ) | Primitive "tuple", [Type.Parameter.Single parameter] -> Some ("tuple", [Type.Parameter.Single (Type.weaken_literals parameter)]) @@ -351,24 +352,25 @@ let instantiate_successors_parameters ((module Handler : Handler) as handler) ~s | Type.Parameter.Single parameter, Type.Variable.Unary variable -> Type.Variable.UnaryPair (variable, parameter) | CallableParameters _, Unary variable - | Group _, Unary variable -> + | VariadicExpression _, Unary variable -> Type.Variable.UnaryPair (variable, Type.Any) - | Group parameter, ListVariadic variable -> - Type.Variable.ListVariadicPair (variable, parameter) + | VariadicExpression variadic_expression, ListVariadic variable -> + Type.Variable.ListVariadicPair (variable, variadic_expression) | CallableParameters _, ListVariadic variable | Single _, ListVariadic variable -> - Type.Variable.ListVariadicPair (variable, Any) + Type.Variable.ListVariadicPair (variable, Type.OrderedTypes.Group Any) | CallableParameters parameters, ParameterVariadic variable -> Type.Variable.ParameterVariadicPair (variable, parameters) | Single _, ParameterVariadic variable - | Group _, ParameterVariadic variable -> + | VariadicExpression _, ParameterVariadic variable -> Type.Variable.ParameterVariadicPair (variable, Undefined) in let replacement = let to_any = function | Type.Variable.Unary variable -> Type.Variable.UnaryPair (variable, Type.Any) | ListVariadic variable -> - Type.Variable.ListVariadicPair (variable, Type.OrderedTypes.Any) + Type.Variable.ListVariadicPair + (variable, Type.OrderedTypes.Group Type.OrderedTypes.Any) | ParameterVariadic variable -> Type.Variable.ParameterVariadicPair (variable, Undefined) in @@ -385,9 +387,16 @@ let instantiate_successors_parameters ((module Handler : Handler) as handler) ~s | Type.Parameter.Single single -> Type.Parameter.Single (TypeConstraints.Solution.instantiate replacement single) - | Group group -> - Group - (TypeConstraints.Solution.instantiate_ordered_types replacement group) + | VariadicExpression expression -> + let instantiated_expression = + Type.OrderedTypes.transform_variadic_expression + expression + ~f:(fun group -> + TypeConstraints.Solution.instantiate_ordered_types + replacement + (Type.OrderedTypes.Group group)) + in + VariadicExpression instantiated_expression | CallableParameters parameters -> CallableParameters (TypeConstraints.Solution.instantiate_callable_parameters diff --git a/analysis/constraintsSet.ml b/analysis/constraintsSet.ml index 01336c7191e..56d388bb425 100644 --- a/analysis/constraintsSet.ml +++ b/analysis/constraintsSet.ml @@ -155,7 +155,7 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct solve_ordered_types_less_or_equal order ~left ~right ~constraints in let concatenate left right = - left >>= fun left -> Type.OrderedTypes.concatenate ~left ~right + left >>= fun left -> Type.OrderedTypes.concatenate_record ~left ~right in List.map before_first_keyword ~f:extract_component |> Option.all @@ -595,22 +595,18 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct solve_less_or_equal order ~constraints ~left ~right) |> List.concat_map ~f:(fun constraints -> solve_less_or_equal order ~constraints ~left:right ~right:left) - | Group left_parameters, Group right_parameters, ListVariadic _ -> + | VariadicExpression left, VariadicExpression right, ListVariadic _ -> (* TODO(T47346673): currently all variadics are invariant, revisit this when we add variance *) constraints |> List.concat_map ~f:(fun constraints -> - solve_ordered_types_less_or_equal - order - ~constraints - ~left:left_parameters - ~right:right_parameters) + solve_variadic_expression_less_or_equal order ~constraints ~left ~right) |> List.concat_map ~f:(fun constraints -> - solve_ordered_types_less_or_equal + solve_variadic_expression_less_or_equal order ~constraints - ~left:right_parameters - ~right:left_parameters) + ~left:right + ~right:left) | CallableParameters left, CallableParameters right, ParameterVariadic _ -> let left = Type.Callable.create ~parameters:left ~annotation:Type.Any () in let right = Type.Callable.create ~parameters:right ~annotation:Type.Any () in @@ -666,7 +662,7 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct | Type.Tuple (Type.Unbounded left), Type.Tuple (Type.Unbounded right) -> solve_less_or_equal order ~constraints ~left ~right | Type.Tuple (Type.Bounded lefts), Type.Tuple (Type.Unbounded right) -> - let left = Type.OrderedTypes.union_upper_bound lefts in + let left = Type.OrderedTypes.union_upper_bound (Type.OrderedTypes.Group lefts) in solve_less_or_equal order ~constraints ~left ~right | Type.Tuple (Type.Bounded left), Type.Tuple (Type.Bounded right) -> solve_ordered_types_less_or_equal order ~left ~right ~constraints @@ -728,6 +724,27 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct solve_less_or_equal order ~constraints ~left:(Type.weaken_literals left) ~right + and solve_variadic_expression_less_or_equal order ~left ~right ~constraints = + let solve = solve_ordered_types_less_or_equal order ~constraints in + match left, right with + | Type.OrderedTypes.Group left, Type.OrderedTypes.Group right -> solve ~left ~right + | Broadcast _, Group right -> + let left = + Type.OrderedTypes.Concatenation + ( Type.OrderedTypes.Concatenation.Middle.create_bare (Expression left) + |> Type.OrderedTypes.Concatenation.create ) + in + solve ~left ~right + | Group left, Broadcast _ -> + let right = + Type.OrderedTypes.Concatenation + ( Type.OrderedTypes.Concatenation.Middle.create_bare (Expression right) + |> Type.OrderedTypes.Concatenation.create ) + in + solve ~left ~right + | _ -> impossible + + and solve_ordered_types_less_or_equal order ~left ~right ~constraints = let solve_concrete_against_concrete ~lefts ~rights constraints = let folded_constraints = @@ -743,118 +760,120 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct | List.Or_unequal_lengths.Unequal_lengths -> impossible in let solve_concrete_against_concatenation ~is_lower_bound ~bound ~concatenation = - let variable = Type.OrderedTypes.Concatenation.variable concatenation in - if Type.Variable.Variadic.List.is_free variable then - let handle_paired paired = - let left_and_right ~bound ~concatenated = - if is_lower_bound then bound, concatenated else concatenated, bound - in - let middle_vs_concrete ~concrete ~middle constraints = - let solve constraints = - match Type.OrderedTypes.Concatenation.Middle.unwrap_if_bare middle with - | Some variable -> - let add_bound = - if is_lower_bound then - OrderedConstraints.add_lower_bound - else - OrderedConstraints.add_upper_bound - in - add_bound - constraints - ~order - ~pair:(Type.Variable.ListVariadicPair (variable, Concrete concrete)) - |> Option.to_list - | None -> - (* Our strategy for solving Concrete[X0, X1, ... Xn] <: Map[mapper, mapped_var] - * is as follows: - * construct n "synthetic" unary type variables - * substitute them through the map, generating - * mapper[Synth0], mapper[Synth1], ... mapper[SynthN] - * pairwise solve the concrete memebers against the synthetics: - * X0 <: mapper[Synth0] && X1 <: mapper[Synth1] && ... Xn <: Mapper[SynthN] - * Solve the resulting constraints to Soln - * Add both upper and lower bounds on mapped_var to be - * Soln[Synth0], Soln[Synth1], ... Soln[SynthN] - *) - let synthetic_variables, synthetic_variable_constraints_set = - let namespace = Type.Variable.Namespace.create_fresh () in - let synthetic_solve index (synthetics_created_sofar, constraints_set) concrete = - let new_synthetic_variable = - Type.Variable.Unary.create (Int.to_string index) - |> Type.Variable.Unary.namespace ~namespace - in - let solve_against_concrete constraints = - let generated = - Type.OrderedTypes.Concatenation.Middle.singleton_replace_variable - middle - ~replacement:(Type.Variable new_synthetic_variable) + match Type.OrderedTypes.Concatenation.variable concatenation with + | Variadic variable when Type.Variable.Variadic.List.is_free variable -> + let handle_paired paired = + let left_and_right ~bound ~concatenated = + if is_lower_bound then bound, concatenated else concatenated, bound + in + let middle_vs_concrete ~concrete ~middle constraints = + let solve constraints = + match Type.OrderedTypes.Concatenation.Middle.unwrap_if_bare middle with + | Some (Variadic variable) -> + let add_bound = + if is_lower_bound then + OrderedConstraints.add_lower_bound + else + OrderedConstraints.add_upper_bound + in + add_bound + constraints + ~order + ~pair:(Type.Variable.ListVariadicPair (variable, Group (Concrete concrete))) + |> Option.to_list + | Some (Expression _) -> + impossible (*TODO: we can't resolve [L3,L4] <: BC[Ts1,Ts2] *) + | None -> + (* Our strategy for solving Concrete[X0, X1, ... Xn] <: Map[mapper, mapped_var] + * is as follows: + * construct n "synthetic" unary type variables + * substitute them through the map, generating + * mapper[Synth0], mapper[Synth1], ... mapper[SynthN] + * pairwise solve the concrete memebers against the synthetics: + * X0 <: mapper[Synth0] && X1 <: mapper[Synth1] && ... Xn <: Mapper[SynthN] + * Solve the resulting constraints to Soln + * Add both upper and lower bounds on mapped_var to be + * Soln[Synth0], Soln[Synth1], ... Soln[SynthN] + *) + let synthetic_variables, synthetic_variable_constraints_set = + let namespace = Type.Variable.Namespace.create_fresh () in + let synthetic_solve index (synthetics_created_sofar, constraints_set) concrete + = + let new_synthetic_variable = + Type.Variable.Unary.create (Int.to_string index) + |> Type.Variable.Unary.namespace ~namespace in - let left, right = - if is_lower_bound then - concrete, generated - else - generated, concrete + let solve_against_concrete constraints = + let generated = + Type.OrderedTypes.Concatenation.Middle.singleton_replace_variable + middle + ~replacement:(Type.Variable new_synthetic_variable) + in + let left, right = + if is_lower_bound then + concrete, generated + else + generated, concrete + in + solve_less_or_equal order ~constraints ~left ~right in - solve_less_or_equal order ~constraints ~left ~right + ( new_synthetic_variable :: synthetics_created_sofar, + List.concat_map constraints_set ~f:solve_against_concrete ) in - ( new_synthetic_variable :: synthetics_created_sofar, - List.concat_map constraints_set ~f:solve_against_concrete ) - in - List.foldi concrete ~f:synthetic_solve ~init:(impossible, [constraints]) - in - let consider_synthetic_variable_constraints synthetic_variable_constraints = - let instantiate_synthetic_variables solution = - List.map - synthetic_variables - ~f:(TypeConstraints.Solution.instantiate_single_variable solution) - |> Option.all + List.foldi concrete ~f:synthetic_solve ~init:(impossible, [constraints]) in - let add_bound concrete = - let add_bound ~adder constraints = - adder - constraints - ~order - ~pair:(Type.Variable.ListVariadicPair (variable, concrete)) + let consider_synthetic_variable_constraints synthetic_variable_constraints = + let instantiate_synthetic_variables solution = + List.map + synthetic_variables + ~f:(TypeConstraints.Solution.instantiate_single_variable solution) + |> Option.all in - add_bound ~adder:OrderedConstraints.add_lower_bound constraints - >>= add_bound ~adder:OrderedConstraints.add_upper_bound + let add_bound concrete = + let add_bound ~adder constraints = + adder + constraints + ~order + ~pair:(Type.Variable.ListVariadicPair (variable, concrete)) + in + add_bound ~adder:OrderedConstraints.add_lower_bound constraints + >>= add_bound ~adder:OrderedConstraints.add_upper_bound + in + OrderedConstraints.solve ~order synthetic_variable_constraints + >>= instantiate_synthetic_variables + >>| List.rev + >>| (fun substituted -> Type.Record.OrderedTypes.Group (Concrete substituted)) + >>= add_bound in - OrderedConstraints.solve ~order synthetic_variable_constraints - >>= instantiate_synthetic_variables - >>| List.rev - >>| (fun substituted -> Type.Record.OrderedTypes.Concrete substituted) - >>= add_bound - in - List.filter_map - synthetic_variable_constraints_set - ~f:consider_synthetic_variable_constraints + List.filter_map + synthetic_variable_constraints_set + ~f:consider_synthetic_variable_constraints + in + List.concat_map constraints ~f:solve in - List.concat_map constraints ~f:solve - in - let concrete_vs_concretes constraints ~pairs = - let solve_pair constraints (concatenated, bound) = - let left, right = left_and_right ~bound ~concatenated in - constraints - |> List.concat_map ~f:(fun constraints -> - solve_less_or_equal order ~constraints ~left ~right) + let concrete_vs_concretes constraints ~pairs = + let solve_pair constraints (concatenated, bound) = + let left, right = left_and_right ~bound ~concatenated in + constraints + |> List.concat_map ~f:(fun constraints -> + solve_less_or_equal order ~constraints ~left ~right) + in + List.fold ~init:constraints ~f:solve_pair pairs in - List.fold ~init:constraints ~f:solve_pair pairs + let middle, middle_bound = Type.OrderedTypes.Concatenation.middle paired in + concrete_vs_concretes ~pairs:(Type.OrderedTypes.Concatenation.head paired) [constraints] + |> middle_vs_concrete ~concrete:middle_bound ~middle + |> concrete_vs_concretes ~pairs:(Type.OrderedTypes.Concatenation.tail paired) in - let middle, middle_bound = Type.OrderedTypes.Concatenation.middle paired in - concrete_vs_concretes ~pairs:(Type.OrderedTypes.Concatenation.head paired) [constraints] - |> middle_vs_concrete ~concrete:middle_bound ~middle - |> concrete_vs_concretes ~pairs:(Type.OrderedTypes.Concatenation.tail paired) - in - Type.OrderedTypes.Concatenation.zip concatenation ~against:bound - >>| handle_paired - |> Option.value ~default:impossible - else - impossible + Type.OrderedTypes.Concatenation.zip concatenation ~against:bound + >>| handle_paired + |> Option.value ~default:impossible + | _ -> impossible in let open Type.OrderedTypes in let open Type.Variable.Variadic.List in match left, right with - | left, right when Type.OrderedTypes.equal left right -> [constraints] + | left, right when Type.OrderedTypes.equal_record_t left right -> [constraints] | Any, _ | _, Any -> [constraints] @@ -869,33 +888,35 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct >>= Type.OrderedTypes.Concatenation.Middle.unwrap_if_bare in match unwrap_if_only_variable left, unwrap_if_only_variable right with - | Some left_variable, Some right_variable + | Some (Variadic left_variable), Some (Variadic right_variable) when is_free left_variable && is_free right_variable -> (* Just as with unaries, we need to consider both possibilities *) let right_greater_than_left, left_less_than_right = ( OrderedConstraints.add_lower_bound constraints ~order - ~pair:(Type.Variable.ListVariadicPair (right_variable, Concatenation left)) + ~pair: + (Type.Variable.ListVariadicPair (right_variable, Group (Concatenation left))) |> Option.to_list, OrderedConstraints.add_upper_bound constraints ~order - ~pair:(Type.Variable.ListVariadicPair (left_variable, Concatenation right)) + ~pair: + (Type.Variable.ListVariadicPair (left_variable, Group (Concatenation right))) |> Option.to_list ) in right_greater_than_left @ left_less_than_right - | Some variable, _ when is_free variable -> + | Some (Variadic variable), _ when is_free variable -> OrderedConstraints.add_upper_bound constraints ~order - ~pair:(Type.Variable.ListVariadicPair (variable, Concatenation right)) + ~pair:(Type.Variable.ListVariadicPair (variable, Group (Concatenation right))) |> Option.to_list - | _, Some variable when is_free variable -> + | _, Some (Variadic variable) when is_free variable -> OrderedConstraints.add_lower_bound constraints ~order - ~pair:(Type.Variable.ListVariadicPair (variable, Concatenation left)) + ~pair:(Type.Variable.ListVariadicPair (variable, Group (Concatenation left))) |> Option.to_list | _ -> impossible ) @@ -936,10 +957,12 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct | Type.Variable.Unary variable -> Type.Parameter.Single (Type.Variable variable) | ListVariadic variable -> - Group - (Concatenation - ( Type.OrderedTypes.Concatenation.Middle.create_bare variable - |> Type.OrderedTypes.Concatenation.create )) + VariadicExpression + (Group + (Concatenation + ( Type.OrderedTypes.Concatenation.Middle.create_bare + (Type.OrderedTypes.Concatenation.Middle.Variadic variable) + |> Type.OrderedTypes.Concatenation.create ))) | ParameterVariadic variable -> CallableParameters (ParameterVariadicTypeVariable { head = []; variable })) in @@ -1040,11 +1063,11 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct | Type.Parameter.Single single -> Type.Parameter.Single (TypeConstraints.Solution.instantiate desanitization_solution single) - | Group group -> - Group + | VariadicExpression expression -> + VariadicExpression (TypeConstraints.Solution.instantiate_ordered_types desanitization_solution - group) + expression) | CallableParameters parameters -> CallableParameters (TypeConstraints.Solution.instantiate_callable_parameters @@ -1059,16 +1082,12 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct |> Option.value ~default:(Type.Variable variable) |> fun instantiated -> Type.Parameter.Single instantiated | ListVariadic variable -> - let default = - Type.OrderedTypes.Concatenation - ( Type.OrderedTypes.Concatenation.Middle.create_bare variable - |> Type.OrderedTypes.Concatenation.create ) - in + let default = Type.Variable.Variadic.List.self_reference variable in TypeConstraints.Solution.instantiate_single_list_variadic_variable solution variable |> Option.value ~default - |> fun instantiated -> Type.Parameter.Group instantiated + |> fun instantiated -> Type.Parameter.VariadicExpression instantiated | ParameterVariadic variable -> TypeConstraints.Solution.instantiate_single_parameter_variadic solution @@ -1100,7 +1119,7 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct solve_less_or_equal order ~constraints ~left ~right) | OrderedTypesLessOrEqual { left; right } -> List.concat_map existing_constraints ~f:(fun constraints -> - solve_ordered_types_less_or_equal order ~constraints ~left ~right) + solve_variadic_expression_less_or_equal order ~constraints ~left ~right) | VariableIsExactly pair -> let add_both_bounds constraints = OrderedConstraints.add_upper_bound constraints ~order ~pair diff --git a/analysis/test/annotatedBasesTest.ml b/analysis/test/annotatedBasesTest.ml index dd0977d46ac..a4c0a472e44 100644 --- a/analysis/test/annotatedBasesTest.ml +++ b/analysis/test/annotatedBasesTest.ml @@ -129,7 +129,7 @@ let test_inferred_generic_base context = (Type.parametric "typing.Generic" [ - Type.Parameter.Group + Type.Parameter.VariadicExpression (Type.Variable.Variadic.List.self_reference (Type.Variable.Variadic.List.create "test.Ts")); ]); diff --git a/analysis/test/annotatedSignatureTest.ml b/analysis/test/annotatedSignatureTest.ml index 47b019a1ba8..9c2a7605fcd 100644 --- a/analysis/test/annotatedSignatureTest.ml +++ b/analysis/test/annotatedSignatureTest.ml @@ -782,10 +782,8 @@ let test_unresolved_select context = (SignatureSelectionTypes.MismatchWithListVariadicTypeVariable { variable = - Concatenation - (Type.OrderedTypes.Concatenation.create - (Type.OrderedTypes.Concatenation.Middle.create_bare - (Type.Variable.Variadic.List.create "test.Ts"))); + Type.Variable.Variadic.List.self_reference + (Type.Variable.Variadic.List.create "test.Ts"); mismatch = NotDefiniteTuple { @@ -803,11 +801,9 @@ let test_unresolved_select context = (SignatureSelectionTypes.MismatchWithListVariadicTypeVariable { variable = - Concatenation - (Type.OrderedTypes.Concatenation.create - (Type.OrderedTypes.Concatenation.Middle.create_bare - (Type.Variable.Variadic.List.create "test.Ts"))); - mismatch = ConstraintFailure (Concrete [Type.float]); + Type.Variable.Variadic.List.self_reference + (Type.Variable.Variadic.List.create "test.Ts"); + mismatch = ConstraintFailure (Group (Concrete [Type.float])); }) )); assert_select "[pyre_extensions.type_variable_operators.Map[typing.List, Ts], int]" diff --git a/analysis/test/classHierarchyTest.ml b/analysis/test/classHierarchyTest.ml index 57229d8bee5..b3680749c06 100644 --- a/analysis/test/classHierarchyTest.ml +++ b/analysis/test/classHierarchyTest.ml @@ -412,13 +412,17 @@ let parametric_order = MockClassHierarchyHandler.handler parametric_order_base let variadic_order = let variadic = Type.Variable.Variadic.List.create "Ts" in - let simple_variadic = - [Type.Parameter.Group (Type.Variable.Variadic.List.self_reference variadic)] + let variadic_expression = + Type.Parameter.VariadicExpression (Type.Variable.Variadic.List.self_reference variadic) in let order = parametric_order_base in let open MockClassHierarchyHandler in insert order "UserTuple"; - connect order ~predecessor:"UserTuple" ~successor:"typing.Generic" ~parameters:simple_variadic; + connect + order + ~predecessor:"UserTuple" + ~successor:"typing.Generic" + ~parameters:[variadic_expression]; (* Contrived example *) connect @@ -432,65 +436,57 @@ let variadic_order = (Bounded (Concatenation (Type.OrderedTypes.Concatenation.create - (Type.OrderedTypes.Concatenation.Middle.create_bare variadic))))); + (Type.OrderedTypes.Concatenation.Middle.create_bare + (Type.OrderedTypes.Concatenation.Middle.Variadic variadic)))))); ]; insert order "SimpleTupleChild"; connect order ~predecessor:"SimpleTupleChild" ~successor:"typing.Generic" - ~parameters:simple_variadic; - connect order ~predecessor:"SimpleTupleChild" ~successor:"UserTuple" ~parameters:simple_variadic; + ~parameters:[variadic_expression]; + connect + order + ~predecessor:"SimpleTupleChild" + ~successor:"UserTuple" + ~parameters:[variadic_expression]; insert order "TupleOfLists"; - connect order ~predecessor:"TupleOfLists" ~successor:"typing.Generic" ~parameters:simple_variadic; + connect + order + ~predecessor:"TupleOfLists" + ~successor:"typing.Generic" + ~parameters:[variadic_expression]; connect order ~predecessor:"TupleOfLists" ~successor:"UserTuple" ~parameters: [ - Group - (Concatenation - (Type.OrderedTypes.Concatenation.create - (Type.OrderedTypes.Concatenation.Middle.create ~mappers:["list"] ~variable:variadic))); + VariadicExpression + (Group + (Concatenation + (Type.OrderedTypes.Concatenation.create + (Type.OrderedTypes.Concatenation.Middle.create + ~mappers:["list"] + ~variable:(Type.OrderedTypes.Concatenation.Middle.Variadic variadic))))); ]; insert order "DTypedTensor"; connect order ~predecessor:"DTypedTensor" ~successor:"typing.Generic" - ~parameters: - [ - Single (Type.Variable (Type.Variable.Unary.create "DType")); - Group - (Concatenation - (Type.OrderedTypes.Concatenation.create - (Type.OrderedTypes.Concatenation.Middle.create_bare variadic))); - ]; + ~parameters:[Single (Type.Variable (Type.Variable.Unary.create "DType")); variadic_expression]; insert order "IntTensor"; connect order ~predecessor:"IntTensor" ~successor:"typing.Generic" - ~parameters: - [ - Group - (Concatenation - (Type.OrderedTypes.Concatenation.create - (Type.OrderedTypes.Concatenation.Middle.create_bare variadic))); - ]; + ~parameters:[variadic_expression]; connect order ~predecessor:"IntTensor" ~successor:"DTypedTensor" - ~parameters: - [ - Single Type.integer; - Group - (Concatenation - (Type.OrderedTypes.Concatenation.create - (Type.OrderedTypes.Concatenation.Middle.create_bare variadic))); - ]; + ~parameters:[Single Type.integer; variadic_expression]; insert order "ClassParametricOnParamSpec"; connect order @@ -602,14 +598,14 @@ let test_instantiate_successors_parameters _ = in assert_equal (instantiate_successors_parameters variadic_order ~source:Type.Bottom ~target:"UserTuple") - (Some [Group Any]); + (Some [VariadicExpression (Group Any)]); assert_equal ~printer (instantiate_successors_parameters variadic_order ~source:(Type.parametric "SimpleTupleChild" ![Type.integer; Type.string; Type.bool]) ~target:"UserTuple") - (Some [Group (Concrete [Type.integer; Type.string; Type.bool])]); + (Some [VariadicExpression (Group (Concrete [Type.integer; Type.string; Type.bool]))]); assert_equal ~printer (instantiate_successors_parameters @@ -623,7 +619,11 @@ let test_instantiate_successors_parameters _ = variadic_order ~source:(Type.parametric "TupleOfLists" ![Type.integer; Type.string; Type.bool]) ~target:"UserTuple") - (Some [Group (Concrete [Type.list Type.integer; Type.list Type.string; Type.list Type.bool])]); + (Some + [ + VariadicExpression + (Group (Concrete [Type.list Type.integer; Type.list Type.string; Type.list Type.bool])); + ]); (* Concatenation *) assert_equal @@ -632,8 +632,14 @@ let test_instantiate_successors_parameters _ = variadic_order ~source:(Type.parametric "IntTensor" ![Type.literal_integer 4; Type.literal_integer 2]) ~target:"DTypedTensor") - (Some [Single Type.integer; Group (Concrete [Type.literal_integer 4; Type.literal_integer 2])]); - let list_variadic = Type.Variable.Variadic.List.create "Ts" in + (Some + [ + Single Type.integer; + VariadicExpression (Group (Concrete [Type.literal_integer 4; Type.literal_integer 2])); + ]); + let list_variadic = + Type.OrderedTypes.Concatenation.Middle.Variadic (Type.Variable.Variadic.List.create "Ts") + in assert_equal ~printer (instantiate_successors_parameters @@ -642,21 +648,23 @@ let test_instantiate_successors_parameters _ = (Type.parametric "IntTensor" [ - Group - (Type.OrderedTypes.Concatenation - (Type.OrderedTypes.Concatenation.create - ~tail:[Type.literal_integer 2] - (Type.OrderedTypes.Concatenation.Middle.create_bare list_variadic))); + VariadicExpression + (Group + (Type.OrderedTypes.Concatenation + (Type.OrderedTypes.Concatenation.create + ~tail:[Type.literal_integer 2] + (Type.OrderedTypes.Concatenation.Middle.create_bare list_variadic)))); ]) ~target:"DTypedTensor") (Some [ Single Type.integer; - Group - (Type.OrderedTypes.Concatenation - (Type.OrderedTypes.Concatenation.create - ~tail:[Type.literal_integer 2] - (Type.OrderedTypes.Concatenation.Middle.create_bare list_variadic))); + VariadicExpression + (Group + (Type.OrderedTypes.Concatenation + (Type.OrderedTypes.Concatenation.create + ~tail:[Type.literal_integer 2] + (Type.OrderedTypes.Concatenation.Middle.create_bare list_variadic)))); ]); assert_equal ~printer diff --git a/analysis/test/constraintsSetTest.ml b/analysis/test/constraintsSetTest.ml index 82016f58607..f1f5e8e0d29 100644 --- a/analysis/test/constraintsSetTest.ml +++ b/analysis/test/constraintsSetTest.ml @@ -250,10 +250,10 @@ let test_add_constraint context = in let parse_ordered_types ordered = if String.equal ordered "" then - Type.OrderedTypes.Concrete [] + Type.OrderedTypes.Group (Concrete []) else match parse_annotation (Printf.sprintf "typing.Tuple[%s]" ordered) with - | Type.Tuple (Bounded ordered) -> ordered + | Type.Tuple (Bounded ordered) -> Group ordered | _ -> failwith "impossible" in let global_resolution = @@ -1114,7 +1114,7 @@ let test_instantiate_protocol_parameters context = ~protocols:["VariadicCol", ["prop", "typing.Tuple[Ts]"]] ~candidate:"A" ~protocol:"VariadicCol" - (Some [Group (Concrete [Type.integer; Type.string])]); + (Some [VariadicExpression (Group (Concrete [Type.integer; Type.string]))]); assert_instantiate_protocol_parameters ~source: {| @@ -1125,7 +1125,7 @@ let test_instantiate_protocol_parameters context = ~protocols:["VariadicCol", ["method", "typing.Callable[Ts, bool]"]] ~candidate:"A" ~protocol:"VariadicCol" - (Some [Group (Concrete [Type.integer; Type.string])]); + (Some [VariadicExpression (Group (Concrete [Type.integer; Type.string]))]); () diff --git a/analysis/test/integration/annotationTest.ml b/analysis/test/integration/annotationTest.ml index 37283a42ad0..f8e9d2999f2 100644 --- a/analysis/test/integration/annotationTest.ml +++ b/analysis/test/integration/annotationTest.ml @@ -2561,6 +2561,292 @@ let test_check_typevar_division_simplify context = ] +let test_check_broadcast_outside_concatenation context = + let assert_type_errors = assert_type_errors ~context in + assert_type_errors + {| + from typing import Generic + from typing_extensions import Literal as L + from pyre_extensions import Broadcast as BC, ListVariadic + + Ts0 = ListVariadic("Ts0") + Ts1 = ListVariadic("Ts1") + + class Tensor(Generic[Ts0]): ... + + def f(x : Tensor[Ts0], y : Tensor[Ts1]) -> Tensor[BC[Ts0,Ts1]]: ... + + def foo() -> None: + x1 : Tensor[L[2],L[3]] + y1 : Tensor[L[2],L[3]] + z1 = f(x1,y1) + + x2 : Tensor[L[1],L[3]] + y2 : Tensor[L[2],L[1]] + z2 = f(x2,y2) + + x3 : Tensor[L[4],L[5],L[1],L[3]] + y3 : Tensor[L[2],L[1]] + z3 = f(x3,y3) + + reveal_type(z1) + reveal_type(z2) + reveal_type(z3) + |} + [ + "Revealed type [-1]: Revealed type for `z1` is `Tensor[typing_extensions.Literal[2], \ + typing_extensions.Literal[3]]`."; + "Revealed type [-1]: Revealed type for `z2` is `Tensor[typing_extensions.Literal[2], \ + typing_extensions.Literal[3]]`."; + "Revealed type [-1]: Revealed type for `z3` is `Tensor[typing_extensions.Literal[4], \ + typing_extensions.Literal[5], typing_extensions.Literal[2], typing_extensions.Literal[3]]`."; + ]; + assert_type_errors + {| + from typing import Generic + from typing_extensions import Literal as L + from pyre_extensions import Broadcast as BC, ListVariadic + + Ts0 = ListVariadic("Ts0") + Ts1 = ListVariadic("Ts1") + + class Tensor(Generic[Ts0]): ... + + def f(x : Tensor[Ts0], y : Tensor[Ts1]) -> Tensor[BC[Ts0,Ts1]]: ... + + def foo() -> None: + x1 : Tensor[L[2],L[4]] + y1 : Tensor[L[2],L[3]] + z1 = f(x1,y1) + + x2 : Tensor[L[2],int] + y2 : Tensor[L[2],L[3]] + z2 = f(x2,y2) + + x3 : Tensor[L[2],str] + y3 : Tensor[L[2],L[3]] + z3 = f(x3,y3) + + reveal_type(z1) + reveal_type(z2) + reveal_type(z3) + |} + [ + "Revealed type [-1]: Revealed type for `z1` is `Tensor[typing_extensions.Literal[2], \ + typing.Any]`."; + "Revealed type [-1]: Revealed type for `z2` is `Tensor[typing_extensions.Literal[2], int]`."; + "Revealed type [-1]: Revealed type for `z3` is `Tensor[typing_extensions.Literal[2], \ + typing.Any]`."; + ]; + assert_type_errors + {| + from typing import Generic + from typing_extensions import Literal as L + from pyre_extensions import Broadcast as BC, ListVariadic + + Ts0 = ListVariadic("Ts0") + Ts1 = ListVariadic("Ts1") + Ts2 = ListVariadic("Ts2") + + class Tensor(Generic[Ts0]): ... + + def f(x : Tensor[Ts0], y : Tensor[Ts1], z : Tensor[Ts2]) -> Tensor[BC[Ts0,BC[Ts1,Ts2]]]: ... + + def foo() -> None: + x1 : Tensor[L[2],L[3]] + y1 : Tensor[L[2],L[3]] + z1 : Tensor[L[2],L[3]] + v1 = f(x1,y1,z1) + + x2 : Tensor[L[3],L[2],L[1]] + y2 : Tensor[L[1],L[2],L[1]] + z2 : Tensor[L[1],L[2],L[4]] + v2 = f(x2,y2,z2) + + reveal_type(v1) + reveal_type(v2) + |} + [ + "Revealed type [-1]: Revealed type for `v1` is `Tensor[typing_extensions.Literal[2], \ + typing_extensions.Literal[3]]`."; + "Revealed type [-1]: Revealed type for `v2` is `Tensor[typing_extensions.Literal[3], \ + typing_extensions.Literal[2], typing_extensions.Literal[4]]`."; + ]; + assert_type_errors + {| + from typing import Generic + from typing_extensions import Literal as L + from pyre_extensions import Broadcast as BC, ListVariadic + + Ts = ListVariadic("Ts") + + class Tensor(Generic[Ts]): ... + + def f(x : Tensor[Ts]) -> Tensor[BC[[L[1],L[3]],Ts]]: ... + + def foo() -> None: + x1 : Tensor[L[2],L[3]] + y1 = f(x1) + reveal_type(y1) + + x2 : Tensor[L[2],L[4]] + y2 = f(x2) + reveal_type(y2) + |} + [ + "Revealed type [-1]: Revealed type for `y1` is `Tensor[typing_extensions.Literal[2], \ + typing_extensions.Literal[3]]`."; + "Revealed type [-1]: Revealed type for `y2` is `Tensor[typing_extensions.Literal[2], \ + typing.Any]`."; + ]; + assert_type_errors + {| + from typing import TypeVar, Generic + from typing_extensions import Literal as L + from pyre_extensions import Broadcast as BC, ListVariadic + from pyre_extensions.type_variable_operators import Concatenate as Cat + + A = TypeVar("A",bound=int) + Ts = ListVariadic("Ts") + Ts1 = ListVariadic("Ts1") + + class Tensor(Generic[Ts]): ... + + def f(x : Tensor[Cat[A,Ts]], y : Tensor[Ts1]) -> Tensor[BC[Cat[A,Ts],Ts1]]: ... + + def foo() -> None: + x1 : Tensor[L[1],L[3]] + x2 : Tensor[L[2],L[1]] + y1 = f(x1,x2) + reveal_type(y1) + |} + [ + "Revealed type [-1]: Revealed type for `y1` is `Tensor[typing_extensions.Literal[2], \ + typing_extensions.Literal[3]]`."; + ] + + +let test_check_broadcast_inside_concatenation context = + let assert_type_errors = assert_type_errors ~context in + assert_type_errors + {| + from typing import Generic, TypeVar + from typing_extensions import Literal as L + from pyre_extensions import Broadcast as BC, ListVariadic + from pyre_extensions.type_variable_operators import Concatenate as Cat + + + Ts = ListVariadic("Ts") + Ts1 = ListVariadic("Ts1") + Ts2 = ListVariadic("Ts2") + A = TypeVar("A",bound=int) + + class Tensor(Generic[Ts]): ... + + def g(x : Tensor[Ts], a: A) -> Tensor[Cat[Ts,A]]: ... + + def h(x : Tensor[Ts1], y : Tensor[Ts2]) -> Tensor[BC[Ts1,Ts2]]: ... + + def f(t1 : Tensor[Ts1], t2 : Tensor[Ts2]) -> Tensor[Cat[BC[Ts1,Ts2],L[3]]]: + x = h(t1,t2) + y = g(x,3) + return y + + def foo() -> None: + t1 : Tensor[L[2],L[4]] + t2 : Tensor[L[2],L[1]] + result = f(t1,t2) + reveal_type(result) + |} + [ + "Revealed type [-1]: Revealed type for `result` is `Tensor[typing_extensions.Literal[2], \ + typing_extensions.Literal[4], typing_extensions.Literal[3]]`."; + ]; + assert_type_errors + {| + from typing import Generic, TypeVar + from typing_extensions import Literal as L + from pyre_extensions import Broadcast as BC, ListVariadic + from pyre_extensions.type_variable_operators import Concatenate as Cat + + Ts1 = ListVariadic("Ts1") + Ts2 = ListVariadic("Ts2") + A = TypeVar("A",bound=int) + B = TypeVar("B",bound=int) + C = TypeVar("C",bound=int) + + class Tensor(Generic[Ts1]): ... + + def matmul(x: Tensor[Cat[Ts1,A,B]], y: Tensor[Cat[Ts2,B,C]]) -> Tensor[Cat[BC[Ts1,Ts2],A,C]]: ... + + def foo() -> None: + t1 : Tensor[L[5],L[2],L[4]] + t2 : Tensor[L[7],L[1],L[4],L[3]] + x = matmul(t1,t2) + reveal_type(x) + |} + [ + "Revealed type [-1]: Revealed type for `x` is `Tensor[typing_extensions.Literal[7], \ + typing_extensions.Literal[5], typing_extensions.Literal[2], typing_extensions.Literal[3]]`."; + ]; + assert_type_errors + {| + from typing import Generic + from typing_extensions import Literal as L + from pyre_extensions import Broadcast as BC, ListVariadic + from pyre_extensions.type_variable_operators import Concatenate as Cat + + Ts1 = ListVariadic("Ts1") + Ts2 = ListVariadic("Ts2") + + class Tensor(Generic[Ts1]): ... + + def f(x : Tensor[Ts1], y : Tensor[Ts2]) -> Tensor[ + BC[ + Cat[BC[Ts2,Ts1],L[1]], + Cat[BC[Ts1,Ts2],L[1]] + ] + ]: ... + + def foo() -> None: + t1 : Tensor[L[2],L[4]] + t2 : Tensor[L[1],L[4]] + x = f(t1,t2) + reveal_type(x) + |} + [ + "Revealed type [-1]: Revealed type for `x` is `Tensor[typing_extensions.Literal[2], \ + typing_extensions.Literal[4], typing_extensions.Literal[1]]`."; + ]; + assert_type_errors + {| + from typing import Generic + from typing_extensions import Literal as L + from pyre_extensions import Broadcast as BC, ListVariadic + from pyre_extensions.type_variable_operators import Concatenate as Cat + + Ts1 = ListVariadic("Ts1") + Ts2 = ListVariadic("Ts2") + + class Tensor(Generic[Ts1]): ... + + def f(t: Tensor[BC[Ts1,Ts2]]) -> Tensor[Ts1]: ... + + def foo() -> None: + t1 : Tensor[L[5],L[2],L[4]] + x = f(t1) + reveal_type(x) + |} + [ + "Incomplete type [37]: Type `Tensor[test.Ts1]` inferred for `x` is incomplete, add an \ + explicit annotation."; + "Incompatible parameter type [6]: Expected \ + `Tensor[pyre_extensions.Broadcast[test.Ts1,test.Ts2]]` for 1st positional only parameter to \ + call `f` but got `Tensor[int, int, int]`."; + "Revealed type [-1]: Revealed type for `x` is `Tensor[...]`."; + ] + + let () = "annotation" >::: [ @@ -2585,5 +2871,7 @@ let () = "check_variadic_arithmetic" >:: test_check_variadic_arithmetic; "check_typevar_division" >:: test_check_typevar_division; "check_typevar_division_simplify" >:: test_check_typevar_division_simplify; + "check_broadcast_outside_concatenation" >:: test_check_broadcast_outside_concatenation; + "check_broadcast_inside_concatenation" >:: test_check_broadcast_inside_concatenation; ] |> Test.run diff --git a/analysis/test/typeCheckTest.ml b/analysis/test/typeCheckTest.ml index 69bb1541e57..3e04dc6a3f1 100644 --- a/analysis/test/typeCheckTest.ml +++ b/analysis/test/typeCheckTest.ml @@ -846,7 +846,8 @@ let test_forward_expression context = pass |} "test.MyVariadic[int]" - (Type.meta (Type.parametric "test.MyVariadic" [Group (Concrete [Type.integer])])); + (Type.meta + (Type.parametric "test.MyVariadic" [VariadicExpression (Group (Concrete [Type.integer]))])); assert_forward ~environment: {| @@ -856,7 +857,10 @@ let test_forward_expression context = pass |} "test.MyVariadic[int, str]" - (Type.meta (Type.parametric "test.MyVariadic" [Group (Concrete [Type.integer; Type.string])])); + (Type.meta + (Type.parametric + "test.MyVariadic" + [VariadicExpression (Group (Concrete [Type.integer; Type.string]))])); (* We'd like for this to return typing.Type[MyVariadic[int, [bool, float]]], but we lose track of the inner types of the lists, since they're not tuples. *) assert_forward @@ -869,7 +873,8 @@ let test_forward_expression context = pass |} "test.MyVariadic[int, [bool, float]]" - (Type.meta (Type.parametric "test.MyVariadic" [Single Type.integer; Group Any])); + (Type.meta + (Type.parametric "test.MyVariadic" [Single Type.integer; VariadicExpression (Group Any)])); (* Resolved annotation field. *) let assert_annotation ?(precondition = []) ?(environment = "") expression annotation = diff --git a/analysis/test/typeConstraintsTest.ml b/analysis/test/typeConstraintsTest.ml index 2eace2e5948..f29a866d45b 100644 --- a/analysis/test/typeConstraintsTest.ml +++ b/analysis/test/typeConstraintsTest.ml @@ -23,16 +23,19 @@ let right_parent = Type.Primitive "right_parent" let grandparent = Type.Primitive "Grandparent" -let create_concatenation ?head ?tail ?mappers variable - : (Type.t Type.OrderedTypes.Concatenation.Middle.t, Type.t) Type.OrderedTypes.Concatenation.t - = +let create_concatenation ?head ?tail ?mappers variable = let mappers = Option.value mappers ~default:[] in - Type.OrderedTypes.Concatenation.create - ?head - ?tail - (Type.OrderedTypes.Concatenation.Middle.create ~mappers ~variable) + let variable = Type.OrderedTypes.Concatenation.Middle.Variadic variable in + Type.OrderedTypes.Group + (Concatenation + (Type.OrderedTypes.Concatenation.create + ?head + ?tail + (Type.OrderedTypes.Concatenation.Middle.create ~mappers ~variable))) +let wrap_concrete list = Type.OrderedTypes.Group (Concrete list) + module DiamondOrder = struct type t = unit @@ -138,25 +141,19 @@ let test_add_bound _ = ))); let list_variadic = Type.Variable.Variadic.List.create in assert_add_bound_succeeds - (`Lower - (ListVariadicPair (list_variadic "Ts", Type.OrderedTypes.Concrete [Type.integer; Type.string]))); + (`Lower (ListVariadicPair (list_variadic "Ts", wrap_concrete [Type.integer; Type.string]))); assert_add_bound_succeeds ~preconstraints: (add_bound (Some empty) - (`Lower - (ListVariadicPair - (list_variadic "Ts", Type.OrderedTypes.Concrete [Type.integer; Type.string])))) - (`Lower - (ListVariadicPair (list_variadic "Ts", Type.OrderedTypes.Concrete [Type.bool; Type.bool]))); + (`Lower (ListVariadicPair (list_variadic "Ts", wrap_concrete [Type.integer; Type.string])))) + (`Lower (ListVariadicPair (list_variadic "Ts", wrap_concrete [Type.bool; Type.bool]))); assert_add_bound_fails ~preconstraints: (add_bound (Some empty) - (`Lower - (ListVariadicPair - (list_variadic "Ts", Type.OrderedTypes.Concrete [Type.integer; Type.string])))) - (`Lower (ListVariadicPair (list_variadic "Ts", Type.OrderedTypes.Concrete [Type.bool]))); + (`Lower (ListVariadicPair (list_variadic "Ts", wrap_concrete [Type.integer; Type.string])))) + (`Lower (ListVariadicPair (list_variadic "Ts", wrap_concrete [Type.bool]))); () @@ -307,35 +304,34 @@ let test_single_variable_solution _ = let list_variadic = Type.Variable.Variadic.List.create "Ts" in assert_solution ~sequentially_applied_bounds: - [`Lower (ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [left_parent; child]))] - (Some [ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [left_parent; child])]); + [`Lower (ListVariadicPair (list_variadic, wrap_concrete [left_parent; child]))] + (Some [ListVariadicPair (list_variadic, wrap_concrete [left_parent; child])]); assert_solution ~sequentially_applied_bounds: [ - `Lower (ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [left_parent; child])); - `Lower (ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [right_parent; child])); + `Lower (ListVariadicPair (list_variadic, wrap_concrete [left_parent; child])); + `Lower (ListVariadicPair (list_variadic, wrap_concrete [right_parent; child])); ] - (Some [ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [grandparent; child])]); + (Some [ListVariadicPair (list_variadic, wrap_concrete [grandparent; child])]); assert_solution ~sequentially_applied_bounds: [ - `Lower (ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [left_parent; child])); - `Lower - (ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [right_parent; child; child])); + `Lower (ListVariadicPair (list_variadic, wrap_concrete [left_parent; child])); + `Lower (ListVariadicPair (list_variadic, wrap_concrete [right_parent; child; child])); ] None; assert_solution ~sequentially_applied_bounds: [ - `Lower (ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [left_parent; child])); - `Upper (ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [grandparent; child])); + `Lower (ListVariadicPair (list_variadic, wrap_concrete [left_parent; child])); + `Upper (ListVariadicPair (list_variadic, wrap_concrete [grandparent; child])); ] - (Some [ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [left_parent; child])]); + (Some [ListVariadicPair (list_variadic, wrap_concrete [left_parent; child])]); assert_solution ~sequentially_applied_bounds: [ - `Lower (ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [left_parent; child])); - `Upper (ListVariadicPair (list_variadic, Type.OrderedTypes.Concrete [right_parent; child])); + `Lower (ListVariadicPair (list_variadic, wrap_concrete [left_parent; child])); + `Upper (ListVariadicPair (list_variadic, wrap_concrete [right_parent; child])); ] None; assert_solution @@ -465,15 +461,13 @@ let test_multiple_variable_solution _ = assert_solution ~sequentially_applied_bounds: [ - `Lower - (ListVariadicPair (list_variadic_a, Concatenation (create_concatenation list_variadic_b))); - `Lower - (ListVariadicPair (list_variadic_b, Type.OrderedTypes.Concrete [Type.integer; Type.string])); + `Lower (ListVariadicPair (list_variadic_a, create_concatenation list_variadic_b)); + `Lower (ListVariadicPair (list_variadic_b, wrap_concrete [Type.integer; Type.string])); ] (Some [ - ListVariadicPair (list_variadic_a, Type.OrderedTypes.Concrete [Type.integer; Type.string]); - ListVariadicPair (list_variadic_b, Type.OrderedTypes.Concrete [Type.integer; Type.string]); + ListVariadicPair (list_variadic_a, wrap_concrete [Type.integer; Type.string]); + ListVariadicPair (list_variadic_b, wrap_concrete [Type.integer; Type.string]); ]); (* As with unaries, this trivial loop could be solvable, but we are choosing not to deal with this @@ -481,44 +475,38 @@ let test_multiple_variable_solution _ = assert_solution ~sequentially_applied_bounds: [ - `Lower - (ListVariadicPair (list_variadic_a, Concatenation (create_concatenation list_variadic_b))); - `Lower - (ListVariadicPair (list_variadic_b, Concatenation (create_concatenation list_variadic_a))); + `Lower (ListVariadicPair (list_variadic_a, create_concatenation list_variadic_b)); + `Lower (ListVariadicPair (list_variadic_b, create_concatenation list_variadic_a)); ] None; assert_solution ~sequentially_applied_bounds: [ - `Lower - (ListVariadicPair - (list_variadic_a, Type.OrderedTypes.Concrete [Type.Variable unconstrained_a])); + `Lower (ListVariadicPair (list_variadic_a, wrap_concrete [Type.Variable unconstrained_a])); `Lower (UnaryPair (unconstrained_a, Type.integer)); ] (Some [ - ListVariadicPair (list_variadic_a, Type.OrderedTypes.Concrete [Type.integer]); + ListVariadicPair (list_variadic_a, wrap_concrete [Type.integer]); UnaryPair (unconstrained_a, Type.integer); ]); assert_solution ~sequentially_applied_bounds: [ `Lower - (ListVariadicPair - (list_variadic_a, Concatenation (create_concatenation ~mappers:["Foo"] list_variadic_b))); - `Lower - (ListVariadicPair (list_variadic_b, Type.OrderedTypes.Concrete [Type.integer; Type.string])); + (ListVariadicPair (list_variadic_a, create_concatenation ~mappers:["Foo"] list_variadic_b)); + `Lower (ListVariadicPair (list_variadic_b, wrap_concrete [Type.integer; Type.string])); ] (Some [ ListVariadicPair ( list_variadic_a, - Concrete + wrap_concrete [ - Parametric { name = "Foo"; parameters = [Single Type.integer] }; + Type.Parametric { name = "Foo"; parameters = [Single Type.integer] }; Parametric { name = "Foo"; parameters = [Single Type.string] }; ] ); - ListVariadicPair (list_variadic_b, Type.OrderedTypes.Concrete [Type.integer; Type.string]); + ListVariadicPair (list_variadic_b, wrap_concrete [Type.integer; Type.string]); ]); assert_solution ~sequentially_applied_bounds: @@ -609,12 +597,10 @@ let test_partial_solution _ = ~variables:[Type.Variable.ListVariadic list_variadic_a] ~bounds: [ - `Lower - (ListVariadicPair (list_variadic_a, Concatenation (create_concatenation list_variadic_b))); - `Lower - (ListVariadicPair (list_variadic_b, Concatenation (create_concatenation list_variadic_a))); + `Lower (ListVariadicPair (list_variadic_a, create_concatenation list_variadic_b)); + `Lower (ListVariadicPair (list_variadic_b, create_concatenation list_variadic_a)); ] - (Some [ListVariadicPair (list_variadic_a, Concatenation (create_concatenation list_variadic_b))]) + (Some [ListVariadicPair (list_variadic_a, create_concatenation list_variadic_b)]) (Some []); () @@ -660,8 +646,7 @@ let test_exists _ = let list_variadic_b = Type.Variable.Variadic.List.create "TsB" in let constraints_with_list_variadic_b = let pair = - Type.Variable.ListVariadicPair - (list_variadic_a, Concatenation (create_concatenation list_variadic_b)) + Type.Variable.ListVariadicPair (list_variadic_a, create_concatenation list_variadic_b) in DiamondOrderedConstraints.add_lower_bound TypeConstraints.empty ~order ~pair |> function diff --git a/analysis/test/typeOrderTest.ml b/analysis/test/typeOrderTest.ml index 020eb7b96c8..f6dcef6f929 100644 --- a/analysis/test/typeOrderTest.ml +++ b/analysis/test/typeOrderTest.ml @@ -657,9 +657,14 @@ let test_less_or_equal context = ~left:(Type.tuple [Type.integer; Type.float]) ~right:(Type.Tuple (Type.Unbounded Type.float))); let list_variadic = - Type.Variable.Variadic.List.create "Ts" - |> Type.Variable.Variadic.List.mark_as_bound - |> Type.Variable.Variadic.List.self_reference + let variadic = + Type.Variable.Variadic.List.create "Ts" |> Type.Variable.Variadic.List.mark_as_bound + in + let middle = + Type.OrderedTypes.Concatenation.Middle.Variadic variadic + |> Type.OrderedTypes.Concatenation.Middle.create_bare + in + Type.OrderedTypes.Concatenation (Type.OrderedTypes.Concatenation.create middle) in assert_false (less_or_equal diff --git a/analysis/test/typeTest.ml b/analysis/test/typeTest.ml index cc88bd194ae..79f1320e727 100644 --- a/analysis/test/typeTest.ml +++ b/analysis/test/typeTest.ml @@ -20,12 +20,23 @@ let create_concatenation ?head ?tail ?mappers variable : (Type.t Type.OrderedTypes.Concatenation.Middle.t, Type.t) Type.OrderedTypes.Concatenation.t = let mappers = Option.value mappers ~default:[] in + let variable = Type.OrderedTypes.Concatenation.Middle.Variadic variable in Type.OrderedTypes.Concatenation.create ?head ?tail (Type.OrderedTypes.Concatenation.Middle.create ~mappers ~variable) +let wrap_expression_concatenation concatenation = + Type.Parameter.VariadicExpression (Group (Concatenation concatenation)) + + +let wrap_group_concatenation concatenation = Type.OrderedTypes.Group (Concatenation concatenation) + +let wrap_group_concrete concrete = Type.OrderedTypes.Group (Concrete concrete) + +let wrap_variadic variadic = Type.OrderedTypes.Concatenation.Middle.Variadic variadic + let test_create _ = let assert_create ?(aliases = fun _ -> None) source annotation = assert_equal @@ -49,7 +60,9 @@ let test_create _ = | "Ts" -> Some (VariableAlias (Type.Variable.ListVariadic variadic)) | _ -> None) "foo[Ts]" - (Type.parametric "foo" [Group (Type.Variable.Variadic.List.self_reference variadic)]); + (Type.parametric + "foo" + [VariadicExpression (Type.Variable.Variadic.List.self_reference variadic)]); assert_create ~aliases:(function | "Ts" -> Some (VariableAlias (Type.Variable.ListVariadic variadic)) @@ -57,7 +70,7 @@ let test_create _ = "foo[pyre_extensions.type_variable_operators.Map[typing.List, Ts]]" (Type.parametric "foo" - [Group (Concatenation (create_concatenation ~mappers:["list"] variadic))]); + [wrap_expression_concatenation (create_concatenation ~mappers:["list"] variadic)]); assert_create ~aliases:(function | "Ts" -> Some (VariableAlias (Type.Variable.ListVariadic variadic)) @@ -65,7 +78,10 @@ let test_create _ = "foo[pyre_extensions.type_variable_operators.Concatenate[int, bool, Ts]]" (Type.parametric "foo" - [Group (Concatenation (create_concatenation ~head:[Type.integer; Type.bool] variadic))]); + [ + wrap_expression_concatenation + (create_concatenation ~head:[Type.integer; Type.bool] variadic); + ]); assert_create ~aliases:(function | "Ts" -> Some (VariableAlias (Type.Variable.ListVariadic variadic)) @@ -73,7 +89,10 @@ let test_create _ = "foo[pyre_extensions.type_variable_operators.Concatenate[Ts, int, bool]]" (Type.parametric "foo" - [Group (Concatenation (create_concatenation ~tail:[Type.integer; Type.bool] variadic))]); + [ + wrap_expression_concatenation + (create_concatenation ~tail:[Type.integer; Type.bool] variadic); + ]); assert_create ~aliases:(function | "Ts" -> Some (VariableAlias (Type.Variable.ListVariadic variadic)) @@ -82,7 +101,8 @@ let test_create _ = (Type.parametric "foo" [ - Group (Concatenation (create_concatenation ~head:[Type.integer] ~tail:[Type.bool] variadic)); + wrap_expression_concatenation + (create_concatenation ~head:[Type.integer] ~tail:[Type.bool] variadic); ]); assert_create ~aliases:(function @@ -93,20 +113,15 @@ let test_create _ = (Type.parametric "foo" [ - Group - (Concatenation - (create_concatenation - ~head:[Type.integer] - ~tail:[Type.bool] - ~mappers:["list"] - variadic)); + wrap_expression_concatenation + (create_concatenation ~head:[Type.integer] ~tail:[Type.bool] ~mappers:["list"] variadic); ]); assert_create ~aliases:(function | "Ts" -> Some (VariableAlias (Type.Variable.ListVariadic variadic)) | _ -> None) "foo[...]" - (Type.parametric "foo" [Group Any]); + (Type.parametric "foo" [VariadicExpression (Group Any)]); assert_create "typing.List.__getitem__(int)" (Type.list Type.integer); assert_create "typing.Dict.__getitem__((int, str))" @@ -514,9 +529,15 @@ let test_expression _ = "typing.Tuple.__getitem__((int, ...))"; assert_expression (Type.parametric "list" ![Type.integer]) "typing.List.__getitem__(int)"; assert_expression - (Type.parametric - "foo.Variadic" - [Group (Concatenation (create_concatenation (Type.Variable.Variadic.List.create "Ts")))]) + (Type.Parametric + { + name = "foo.Variadic"; + parameters = + [ + wrap_expression_concatenation + (create_concatenation (Type.Variable.Variadic.List.create "Ts")); + ]; + }) "foo.Variadic.__getitem__(Ts)"; assert_expression (Type.parametric "foo.Variadic" [Group Any]) "foo.Variadic.__getitem__(...)"; assert_expression @@ -1935,7 +1956,7 @@ let test_replace_all _ = in assert_equal (Type.Variable.GlobalTransforms.ListVariadic.replace_all - (fun _ -> Some (Type.OrderedTypes.Concrete [Type.integer; Type.string])) + (fun _ -> Some (wrap_group_concrete [Type.integer; Type.string])) (Type.parametric "p" ![Type.integer; free_variable_tuple])) (Type.parametric "p" @@ -1950,7 +1971,7 @@ let test_replace_all _ = in assert_equal (Type.Variable.GlobalTransforms.ListVariadic.replace_all - (fun _ -> Some (Type.OrderedTypes.Concrete [Type.integer; Type.string])) + (fun _ -> Some (wrap_group_concrete [Type.integer; Type.string])) (Type.Callable.create ~parameters: (Defined @@ -1963,7 +1984,7 @@ let test_replace_all _ = (Type.Callable.create ~parameters:(Defined replaced) ~annotation:Type.integer ()); assert_equal (Type.Variable.GlobalTransforms.ListVariadic.replace_all - (fun _ -> Some (Type.OrderedTypes.Concrete [Type.integer; Type.string])) + (fun _ -> Some (wrap_group_concrete [Type.integer; Type.string])) (Tuple (Bounded (Concatenation (create_concatenation ~mappers:["Foo"; "Bar"] list_variadic))))) (Tuple (Bounded @@ -1974,11 +1995,14 @@ let test_replace_all _ = ]))); assert_equal (Type.Variable.GlobalTransforms.ListVariadic.replace_all - (fun _ -> Some (Type.OrderedTypes.Concrete [Type.integer; Type.string])) + (fun _ -> Some (wrap_group_concrete [Type.integer; Type.string])) (Type.parametric "Baz" - [Group (Concatenation (create_concatenation (Type.Variable.Variadic.List.create "Ts")))])) - (Type.parametric "Baz" [Group (Concrete [Type.integer; Type.string])]); + [ + wrap_expression_concatenation + (create_concatenation (Type.Variable.Variadic.List.create "Ts")); + ])) + (Type.parametric "Baz" [VariadicExpression (Group (Concrete [Type.integer; Type.string]))]); assert_equal (Type.Variable.GlobalTransforms.ParameterVariadic.replace_all (fun _ -> @@ -2045,7 +2069,10 @@ let test_collect_all _ = (Type.Variable.GlobalTransforms.ListVariadic.collect_all (Type.parametric "Huh" - [Group (Concatenation (create_concatenation ~mappers:["Foo"; "Bar"] list_variadic))])) + [ + wrap_expression_concatenation + (create_concatenation ~mappers:["Foo"; "Bar"] list_variadic); + ])) [Type.Variable.Variadic.List.create "Ts"]; assert_equal (Type.Variable.GlobalTransforms.ParameterVariadic.collect_all @@ -2094,7 +2121,10 @@ let test_middle_singleton_replace_variable _ = in let variable = Type.Variable.Variadic.List.create "Ts" in assert_replaces_into - ~middle:(Type.OrderedTypes.Concatenation.Middle.create ~mappers:["Foo"; "Bar"] ~variable) + ~middle: + (Type.OrderedTypes.Concatenation.Middle.create + ~mappers:["Foo"; "Bar"] + ~variable:(wrap_variadic variable)) ~replacement:Type.integer (Type.parametric "Foo" @@ -2103,7 +2133,10 @@ let test_middle_singleton_replace_variable _ = (* This approach is used to solve concretes against maps *) let unary_variable = Type.Variable.Unary.create "T" in assert_replaces_into - ~middle:(Type.OrderedTypes.Concatenation.Middle.create ~mappers:["Foo"; "Bar"] ~variable) + ~middle: + (Type.OrderedTypes.Concatenation.Middle.create + ~mappers:["Foo"; "Bar"] + ~variable:(wrap_variadic variable)) ~replacement:(Type.Variable unary_variable) (Type.parametric "Foo" @@ -2116,16 +2149,18 @@ let test_union_upper_bound _ = assert_equal (Type.OrderedTypes.union_upper_bound map) expected in assert_union_upper_bound - (Concrete [Type.integer; Type.string; Type.bool]) + (wrap_group_concrete [Type.integer; Type.string; Type.bool]) (Type.union [Type.integer; Type.string; Type.bool]); - assert_union_upper_bound Any Any; + assert_union_upper_bound (Group Any) Any; let variable = Type.Variable.Variadic.List.create "Ts" in - assert_union_upper_bound (Concatenation (create_concatenation variable)) Type.object_primitive; + assert_union_upper_bound + (wrap_group_concatenation (create_concatenation variable)) + Type.object_primitive; assert_union_upper_bound - (Concatenation (create_concatenation ~mappers:["Foo"; "Bar"] variable)) + (wrap_group_concatenation (create_concatenation ~mappers:["Foo"; "Bar"] variable)) Type.object_primitive; () @@ -2135,37 +2170,44 @@ let test_concatenation_operator_variable _ = assert_equal (Type.OrderedTypes.Concatenation.variable (create_concatenation ~head:[Type.integer] ~tail:[Type.bool] variable)) - variable; + (wrap_variadic variable); assert_equal (Type.OrderedTypes.Concatenation.variable (create_concatenation ~head:[Type.integer] ~tail:[Type.bool] ~mappers:["list"] variable)) - variable; + (wrap_variadic variable); () let test_concatenation_operator_replace_variable _ = let assert_replaces_into ~concatenation ~replacement expected = assert_equal - (Type.OrderedTypes.Concatenation.replace_variable concatenation ~replacement) + (Type.OrderedTypes.Concatenation.replace_variable + concatenation + ~replacement + ~replace_variadic:Fn.id) expected in let variable = Type.Variable.Variadic.List.create "Ts" in assert_replaces_into ~concatenation:(create_concatenation ~head:[Type.integer] ~tail:[Type.bool] variable) - ~replacement:(fun _ -> Some (Concrete [Type.string])) - (Some (Concrete [Type.integer; Type.string; Type.bool])); + ~replacement:(fun _ -> Some (wrap_group_concrete [Type.string])) + (Some (wrap_group_concrete [Type.integer; Type.string; Type.bool])); let assert_replaces_into ~middle ~replacement expected = let concatenation = Type.OrderedTypes.Concatenation.create ?head:None ?tail:None middle in assert_replaces_into ~concatenation ~replacement expected in let variable = Type.Variable.Variadic.List.create "Ts" in + let wrapped_variable = wrap_variadic variable in assert_replaces_into - ~middle:(Type.OrderedTypes.Concatenation.Middle.create ~mappers:["Foo"; "Bar"] ~variable) - ~replacement:(fun _ -> Some (Concrete [Type.integer])) + ~middle: + (Type.OrderedTypes.Concatenation.Middle.create + ~mappers:["Foo"; "Bar"] + ~variable:wrapped_variable) + ~replacement:(fun _ -> Some (wrap_group_concrete [Type.integer])) (Some - (Concrete + (wrap_group_concrete [ - Parametric + Type.Parametric { name = "Foo"; parameters = @@ -2173,12 +2215,15 @@ let test_concatenation_operator_replace_variable _ = }; ])); assert_replaces_into - ~middle:(Type.OrderedTypes.Concatenation.Middle.create ~mappers:["Foo"; "Bar"] ~variable) - ~replacement:(fun _ -> Some (Concrete [Type.integer; Type.string])) + ~middle: + (Type.OrderedTypes.Concatenation.Middle.create + ~mappers:["Foo"; "Bar"] + ~variable:wrapped_variable) + ~replacement:(fun _ -> Some (wrap_group_concrete [Type.integer; Type.string])) (Some - (Concrete + (wrap_group_concrete [ - Parametric + Type.Parametric { name = "Foo"; parameters = @@ -2187,22 +2232,33 @@ let test_concatenation_operator_replace_variable _ = Parametric { name = "Foo"; parameters = ![Type.parametric "Bar" ![Type.string]] }; ])); assert_replaces_into - ~middle:(Type.OrderedTypes.Concatenation.Middle.create ~mappers:["Foo"] ~variable) - ~replacement:(fun _ -> Some (Concatenation (create_concatenation ~mappers:["Bar"] variable))) - (Some (Concatenation (create_concatenation ~mappers:["Foo"; "Bar"] variable))); + ~middle: + (Type.OrderedTypes.Concatenation.Middle.create ~mappers:["Foo"] ~variable:wrapped_variable) + ~replacement:(fun _ -> + Some (wrap_group_concatenation (create_concatenation ~mappers:["Bar"] variable))) + (Some (wrap_group_concatenation (create_concatenation ~mappers:["Foo"; "Bar"] variable))); let other_variable = Type.Variable.Variadic.List.create "Ts2" in assert_replaces_into - ~middle:(Type.OrderedTypes.Concatenation.Middle.create ~mappers:["Foo"; "Bar"] ~variable) + ~middle: + (Type.OrderedTypes.Concatenation.Middle.create + ~mappers:["Foo"; "Bar"] + ~variable:wrapped_variable) ~replacement:(function - | _ -> Some (Concatenation (create_concatenation other_variable))) - (Some (Concatenation (create_concatenation ~mappers:["Foo"; "Bar"] other_variable))); + | _ -> Some (wrap_group_concatenation (create_concatenation other_variable))) + (Some (wrap_group_concatenation (create_concatenation ~mappers:["Foo"; "Bar"] other_variable))); assert_replaces_into - ~middle:(Type.OrderedTypes.Concatenation.Middle.create ~mappers:["Foo"; "Bar"] ~variable) + ~middle: + (Type.OrderedTypes.Concatenation.Middle.create + ~mappers:["Foo"; "Bar"] + ~variable:wrapped_variable) ~replacement:(function - | _ -> Some Any) - (Some Any); + | _ -> Some (Group Any)) + (Some (Group Any)); assert_replaces_into - ~middle:(Type.OrderedTypes.Concatenation.Middle.create ~mappers:["Foo"; "Bar"] ~variable) + ~middle: + (Type.OrderedTypes.Concatenation.Middle.create + ~mappers:["Foo"; "Bar"] + ~variable:wrapped_variable) ~replacement:(function | _ -> None) None; @@ -2458,11 +2514,13 @@ let test_add_polynomials_with_variadics _ = let x = Type.Variable.Unary.create "x" in let ts = Type.Variable.Variadic.List.create "Ts" + |> wrap_variadic |> Type.OrderedTypes.Concatenation.Middle.create_bare |> Type.OrderedTypes.Concatenation.create in let shape = Type.Variable.Variadic.List.create "Shape" + |> wrap_variadic |> Type.OrderedTypes.Concatenation.Middle.create_bare |> Type.OrderedTypes.Concatenation.create in @@ -2488,11 +2546,13 @@ let test_subtract_polynomials_with_variadics _ = let x = Type.Variable.Unary.create "x" in let ts = Type.Variable.Variadic.List.create "Ts" + |> wrap_variadic |> Type.OrderedTypes.Concatenation.Middle.create_bare |> Type.OrderedTypes.Concatenation.create in let shape = Type.Variable.Variadic.List.create "Shape" + |> wrap_variadic |> Type.OrderedTypes.Concatenation.Middle.create_bare |> Type.OrderedTypes.Concatenation.create in @@ -2522,6 +2582,7 @@ let test_multiply_polynomials_with_variadics _ = let x = Type.Variable.Unary.create "x" in let ts = Type.Variable.Variadic.List.create "Ts" + |> wrap_variadic |> Type.OrderedTypes.Concatenation.Middle.create_bare |> Type.OrderedTypes.Concatenation.create in diff --git a/analysis/type.ml b/analysis/type.ml index f15b8f7ea86..179c10cc791 100644 --- a/analysis/type.ml +++ b/analysis/type.ml @@ -142,8 +142,92 @@ module Record = struct [@@deriving compare, eq, sexp, show, hash] end - module OrderedTypes = struct - let map_public_name = "pyre_extensions.type_variable_operators.Map" + module rec OrderedTypes : sig + module RecordConcatenate : sig + module Middle : sig + type 'annotation variable = + | Variadic of 'annotation Variable.RecordVariadic.RecordList.record + | Expression of 'annotation OrderedTypes.variadic_expression + [@@deriving compare, eq, sexp, show, hash] + + type 'annotation t = { + variable: 'annotation variable; + mappers: Identifier.t list; + } + [@@deriving compare, eq, sexp, show, hash] + + val unwrap_if_bare : 'a t -> 'a variable option + end + + type 'annotation wrapping = { + head: 'annotation list; + tail: 'annotation list; + } + [@@deriving compare, hash] + + type ('middle, 'annotation) t = { + middle: 'middle; + wrapping: 'annotation wrapping; + } + [@@deriving compare, eq, sexp, show, hash] + + val public_name : unit -> string + + val pp_concatenation + : Format.formatter -> + ('a Middle.t, 'a) t -> + pp_type:(Format.formatter -> 'a -> unit) -> + unit + + val unwrap_if_only_middle : ('a, 'b) t -> 'a option + + val empty_wrap : 'a Middle.t -> ('a Middle.t, 'b) t + + val head : ('middle, 'outter) t -> 'outter list + + val middle : ('middle, 'outter) t -> 'middle + + val tail : ('middle, 'outter) t -> 'outter list + end + + type 'annotation record = + | Concrete of 'annotation list + | Any + | Concatenation of ('annotation RecordConcatenate.Middle.t, 'annotation) RecordConcatenate.t + [@@deriving compare, eq, sexp, show, hash] + + type 'annotation variadic_expression = + | Group of 'annotation record + | Broadcast of 'annotation variadic_expression * 'annotation variadic_expression + [@@deriving compare, eq, sexp, show, hash] + + val map_public_name : unit -> string + + val pp_concise + : Format.formatter -> + 'a variadic_expression -> + pp_type:(Format.formatter -> 'a -> unit) -> + unit + + val pp_concise_record + : Format.formatter -> + 'a record -> + pp_type:(Format.formatter -> 'a -> unit) -> + unit + + val concatenate + : left:'a variadic_expression -> + right:'a variadic_expression -> + 'a variadic_expression option + + val concatenate_record : left:'a record -> right:'a record -> 'a record option + + val transform_variadic_expression + : 'a variadic_expression -> + f:('a record -> 'a variadic_expression) -> + 'a variadic_expression + end = struct + let map_public_name () = "pyre_extensions.type_variable_operators.Map" let show_type_list types ~pp_type = Format.asprintf @@ -153,20 +237,27 @@ module Record = struct module RecordConcatenate = struct - let public_name = "pyre_extensions.type_variable_operators.Concatenate" + let public_name () = "pyre_extensions.type_variable_operators.Concatenate" module Middle = struct + type 'annotation variable = + | Variadic of 'annotation Variable.RecordVariadic.RecordList.record + | Expression of 'annotation OrderedTypes.variadic_expression + [@@deriving compare, eq, sexp, show, hash] + type 'annotation t = { - variable: 'annotation Variable.RecordVariadic.RecordList.record; + variable: 'annotation variable; mappers: Identifier.t list; } [@@deriving compare, eq, sexp, show, hash] - let rec show_concise = function - | { variable = { name; _ }; mappers = [] } -> name + let rec show_concise ~pp_type = function + | { variable = Variadic { name; _ }; mappers = [] } -> name + | { variable = Expression variadic_expression; mappers = [] } -> + Format.asprintf "%a" (OrderedTypes.pp_concise ~pp_type) variadic_expression | { mappers = head_mapper :: tail_mappers; _ } as mapped -> let inner = { mapped with mappers = tail_mappers } in - Format.asprintf "Map[%s, %s]" head_mapper (show_concise inner) + Format.asprintf "Map[%s, %s]" head_mapper (show_concise inner ~pp_type) let unwrap_if_bare = function @@ -202,25 +293,26 @@ module Record = struct let pp_concatenation format { middle; wrapping } ~pp_type = match wrapping with - | { head = []; tail = [] } -> Format.fprintf format "%s" (Middle.show_concise middle) + | { head = []; tail = [] } -> + Format.fprintf format "%s" (Middle.show_concise ~pp_type middle) | { head; tail = [] } -> Format.fprintf format "Concatenate[%s, %s]" (show_type_list head ~pp_type) - (Middle.show_concise middle) + (Middle.show_concise ~pp_type middle) | { head = []; tail } -> Format.fprintf format "Concatenate[%s, %s]" - (Middle.show_concise middle) + (Middle.show_concise ~pp_type middle) (show_type_list tail ~pp_type) | { head; tail } -> Format.fprintf format "Concatenate[%s, %s, %s]" (show_type_list head ~pp_type) - (Middle.show_concise middle) + (Middle.show_concise ~pp_type middle) (show_type_list tail ~pp_type) end @@ -230,7 +322,12 @@ module Record = struct | Concatenation of ('annotation RecordConcatenate.Middle.t, 'annotation) RecordConcatenate.t [@@deriving compare, eq, sexp, show, hash] - let pp_concise format variable ~pp_type = + type 'annotation variadic_expression = + | Group of 'annotation record + | Broadcast of 'annotation variadic_expression * 'annotation variadic_expression + [@@deriving compare, eq, sexp, show, hash] + + let pp_concise_record format variable ~pp_type = match variable with | Concrete types -> Format.fprintf format "%s" (show_type_list types ~pp_type) | Any -> Format.fprintf format "..." @@ -238,7 +335,20 @@ module Record = struct Format.fprintf format "%a" (RecordConcatenate.pp_concatenation ~pp_type) concatenation - let concatenate ~left ~right = + let rec pp_concise format variable ~pp_type = + match variable with + | Group variadic -> Format.fprintf format "%a" (pp_concise_record ~pp_type) variadic + | Broadcast (left, right) -> + Format.fprintf + format + "pyre_extensions.Broadcast[%a,%a]" + (pp_concise ~pp_type) + left + (pp_concise ~pp_type) + right + + + let concatenate_record ~left ~right = match left, right with | Concrete left, Concrete right -> Some (Concrete (left @ right)) (* Any can masquerade as the empty list *) @@ -252,6 +362,20 @@ module Record = struct | Concatenation ({ wrapping = { head; tail }; _ } as concatenation), Concrete right -> Some (Concatenation { concatenation with wrapping = { head; tail = tail @ right } }) | Concatenation _, Concatenation _ -> None + + + let concatenate ~left ~right = + match left, right with + | Group left, Group right -> concatenate_record ~left ~right >>| fun result -> Group result + | Broadcast _, Broadcast _ -> None + | _, _ -> None + + + let rec transform_variadic_expression variadic_expression ~f = + match variadic_expression with + | Group ordered -> f ordered + | Broadcast (left, right) -> + Broadcast (transform_variadic_expression ~f left, transform_variadic_expression ~f right) end module Callable = struct @@ -369,14 +493,14 @@ module Record = struct module Parameter = struct type 'annotation record = | Single of 'annotation - | Group of 'annotation OrderedTypes.record + | VariadicExpression of 'annotation OrderedTypes.variadic_expression | CallableParameters of 'annotation Callable.record_parameters [@@deriving compare, eq, sexp, show, hash] let is_single = function | Single single -> Some single | CallableParameters _ - | Group _ -> + | VariadicExpression _ -> None end @@ -1249,7 +1373,7 @@ let parameter_variable_type_representation = function let concretes = head @ [Primitive name] in Parametric { - name = Record.OrderedTypes.RecordConcatenate.public_name; + name = Record.OrderedTypes.RecordConcatenate.public_name (); parameters = List.map concretes ~f:(fun concrete -> Record.Parameter.Single concrete); } @@ -1265,8 +1389,8 @@ let show_callable_parameters ~pp_type = function let pp_parameters ~pp_type format = function - | [Record.Parameter.Group ordered] -> - Format.fprintf format "%a" (Record.OrderedTypes.pp_concise ~pp_type) ordered + | [Record.Parameter.VariadicExpression (Group ordered)] -> + Format.fprintf format "%a" (Record.OrderedTypes.pp_concise_record ~pp_type) ordered | parameters when List.for_all parameters ~f:(function | Single parameter -> is_unbound parameter || is_top parameter @@ -1275,8 +1399,18 @@ let pp_parameters ~pp_type format = function | parameters -> let s format = function | Record.Parameter.Single parameter -> Format.fprintf format "%a" pp_type parameter - | Group ordered_types -> - Format.fprintf format "[%a]" (Record.OrderedTypes.pp_concise ~pp_type) ordered_types + | VariadicExpression (Group ordered_types) -> + Format.fprintf + format + "[%a]" + (Record.OrderedTypes.pp_concise_record ~pp_type) + ordered_types + | VariadicExpression (Broadcast (left, right)) -> + Format.fprintf + format + "%a" + (Record.OrderedTypes.pp_concise ~pp_type) + (Broadcast (left, right)) | CallableParameters parameters -> Format.fprintf format "%s" (show_callable_parameters parameters ~pp_type) in @@ -1329,7 +1463,7 @@ let rec pp format annotation = let parameters = match tuple with | Bounded parameters -> - Format.asprintf "%a" (Record.OrderedTypes.pp_concise ~pp_type:pp) parameters + Format.asprintf "%a" (Record.OrderedTypes.pp_concise_record ~pp_type:pp) parameters | Unbounded parameter -> Format.asprintf "%a, ..." pp parameter in Format.fprintf format "typing.Tuple[%s]" parameters @@ -1420,7 +1554,7 @@ and pp_concise format annotation = Format.fprintf format "Tuple[%a]" - (Record.OrderedTypes.pp_concise ~pp_type:pp_concise) + (Record.OrderedTypes.pp_concise_record ~pp_type:pp_concise) parameters | Tuple (Unbounded parameter) -> Format.fprintf format "Tuple[%a, ...]" pp_concise parameter | Union [NoneType; parameter] @@ -1753,13 +1887,14 @@ let rec expression annotation = | Concatenation concatenation -> [concatenation_expression concatenation] in let expression_of_parameter = function - | Record.Parameter.Group ordered -> + | Record.Parameter.VariadicExpression (Group ordered) -> Node.create ~location (Expression.List (expression_of_ordered ordered)) + | Record.Parameter.VariadicExpression _ -> expression (Primitive "int") (*TODO*) | Single single -> expression single | CallableParameters parameters -> callable_parameters_expression parameters in match parameters with - | [Group ordered] -> expression_of_ordered ordered + | [VariadicExpression (Group ordered)] -> expression_of_ordered ordered | parameters -> List.map parameters ~f:expression_of_parameter in get_item_call (reverse_substitute name) parameters @@ -1826,14 +1961,14 @@ let rec expression annotation = (Parametric { name = "pyre_extensions.Length"; - parameters = [Group (Concatenation variadic)]; + parameters = [VariadicExpression (Group (Concatenation variadic))]; }) | Monomial.Variadic (Product, variadic) -> convert_annotation (Parametric { name = "pyre_extensions.Product"; - parameters = [Group (Concatenation variadic)]; + parameters = [VariadicExpression (Group (Concatenation variadic))]; }) | Monomial.Divide (dividend, quotient) -> convert_annotation @@ -1841,7 +1976,11 @@ let rec expression annotation = { name = "pyre_extensions.Divide"; parameters = - [Group (Concrete [IntExpression dividend; IntExpression quotient])]; + [ + VariadicExpression + (Group + (Concrete [IntExpression dividend; IntExpression quotient])); + ]; }) in @@ -1872,13 +2011,18 @@ and middle_annotation middle = let single_wrap ~mapper ~inner = Parametric { - name = Record.OrderedTypes.map_public_name; + name = Record.OrderedTypes.map_public_name (); parameters = [Single (Primitive mapper); Single inner]; } in match middle with - | { Record.OrderedTypes.RecordConcatenate.Middle.variable = { name; _ }; mappers = [] } -> + | { Record.OrderedTypes.RecordConcatenate.Middle.variable = Variadic { name; _ }; mappers = [] } + -> Primitive name + | { Record.OrderedTypes.RecordConcatenate.Middle.variable = Expression expression; mappers = [] } + -> + Parametric + { name = "pyre_extensions.Broadcast"; parameters = [VariadicExpression expression] } | { mappers = head_mapper :: tail_mappers; _ } -> let inner = { middle with mappers = tail_mappers } in single_wrap ~mapper:head_mapper ~inner:(middle_annotation inner) @@ -1893,7 +2037,7 @@ and concatenation_expression { middle; wrapping } = let concretes = head @ (middle_annotation :: tail) in Parametric { - name = Record.OrderedTypes.RecordConcatenate.public_name; + name = Record.OrderedTypes.RecordConcatenate.public_name (); parameters = List.map concretes ~f:(fun concrete -> Record.Parameter.Single concrete); } in @@ -1982,8 +2126,10 @@ module Transform = struct } | Parametric { name; parameters } -> let visit = function - | Record.Parameter.Group ordered -> - Record.Parameter.Group (visit_ordered_types ordered) + | Parameter.VariadicExpression expression -> + Parameter.VariadicExpression + (Record.OrderedTypes.transform_variadic_expression expression ~f:(fun group -> + Group (visit_ordered_types group))) | Single single -> Single (visit_annotation single ~state) | CallableParameters parameters -> CallableParameters (visit_parameters parameters) in @@ -2301,9 +2447,13 @@ let create_concatenation_operator_from_annotation annotation ~variable_aliases = | Parametric { name; - parameters = [Single (Primitive left_parameter); Group (Concatenation right_parameter)]; + parameters = + [ + Single (Primitive left_parameter); + VariadicExpression (Group (Concatenation right_parameter)); + ]; } - when Identifier.equal name Record.OrderedTypes.map_public_name -> + when Identifier.equal name (Record.OrderedTypes.map_public_name ()) -> let open Record.OrderedTypes.RecordConcatenate in unwrap_if_only_middle right_parameter >>= Middle.unwrap_if_bare @@ -2313,16 +2463,22 @@ let create_concatenation_operator_from_annotation annotation ~variable_aliases = in match annotation with | Parametric { name; parameters } -> ( - match Identifier.equal name Record.OrderedTypes.RecordConcatenate.public_name with + match Identifier.equal name (Record.OrderedTypes.RecordConcatenate.public_name ()) with | true -> ( let parse_as_middle = function - | Record.Parameter.Group (Concatenation potential_middle) -> + | Record.Parameter.VariadicExpression (Group (Concatenation potential_middle)) -> let open Record.OrderedTypes.RecordConcatenate in unwrap_if_only_middle potential_middle | CallableParameters _ - | Group (Concrete _) - | Group Any -> + | VariadicExpression (Group (Concrete _)) + | VariadicExpression (Group Any) -> None + | VariadicExpression expression -> + Some + { + variable = Record.OrderedTypes.RecordConcatenate.Middle.Expression expression; + mappers = []; + } | Record.Parameter.Single potentially_a_map -> create_map_operator_from_annotation potentially_a_map in @@ -2351,7 +2507,10 @@ let create_concatenation_operator_from_annotation annotation ~variable_aliases = | Some (Record.Variable.ListVariadic variable) -> Some (Record.OrderedTypes.RecordConcatenate.empty_wrap - { Record.OrderedTypes.RecordConcatenate.Middle.variable; mappers = [] }) + { + Record.OrderedTypes.RecordConcatenate.Middle.variable = Variadic variable; + mappers = []; + }) | _ -> None ) | _ -> None @@ -2398,7 +2557,7 @@ let rec create_logic ~aliases ~variable_aliases { Node.value = expression; _ } = | Some (ParameterVariadic variable) -> Some { Record.Callable.variable; head = [] } | _ -> None ) | Parametric { name; parameters } - when Identifier.equal name Record.OrderedTypes.RecordConcatenate.public_name -> ( + when Identifier.equal name (Record.OrderedTypes.RecordConcatenate.public_name ()) -> ( match List.rev parameters with | Parameter.CallableParameters (ParameterVariadicTypeVariable { variable; head = [] }) :: reversed_head -> @@ -2667,21 +2826,69 @@ let rec create_logic ~aliases ~variable_aliases { Node.value = expression; _ } = in Callable { kind; implementation; overloads } in + let rec parse_broadcast expression = + match expression with + | { + Node.value = + Expression.Call + { + callee; + arguments = + [ + { + Call.Argument.name = None; + value = { Node.value = Expression.Tuple arguments; _ }; + _; + }; + ]; + }; + _; + } + when name_is ~name:"pyre_extensions.Broadcast.__getitem__" callee -> ( + match arguments with + | [ts1; ts2] -> ( + let parse_ts ts = + match parse_broadcast ts with + | None -> ( + match ts with + | { Node.value = Expression.List elements; _ } -> + let concrete = List.map elements ~f:create_logic in + Record.OrderedTypes.Group (Concrete concrete) |> Option.some + | _ -> + create_logic ts + |> substitute_ordered_types + >>| fun record -> Record.OrderedTypes.Group record ) + | Some broadcast -> Some broadcast + in + let ts1 = parse_ts ts1 in + let ts2 = parse_ts ts2 in + match ts1, ts2 with + | Some ts1, Some ts2 -> Record.OrderedTypes.Broadcast (ts1, ts2) |> Option.some + | _, _ -> None ) + | _ -> None ) + | _ -> None + in + let create_parametric ~base ~argument = let parametric name = let parameters = let parse_parameter = function | { Node.value = Expression.List elements; _ } -> let concrete = List.map elements ~f:create_logic in - Record.Parameter.Group (Concrete concrete) + Record.Parameter.VariadicExpression (Group (Concrete concrete)) | element -> ( - let parsed = create_logic element in - match substitute_ordered_types parsed with - | Some ordered -> Record.Parameter.Group ordered - | None -> ( - match substitute_parameter_variadic parsed with - | Some variable -> CallableParameters (ParameterVariadicTypeVariable variable) - | _ -> Record.Parameter.Single parsed ) ) + let broadcast = parse_broadcast element in + match broadcast with + | Some broadcast -> Record.Parameter.VariadicExpression broadcast + | _ -> ( + let parsed = create_logic element in + match substitute_ordered_types parsed with + | Some ordered -> Record.Parameter.VariadicExpression (Group ordered) + | None -> ( + match substitute_parameter_variadic parsed with + | Some variable -> + CallableParameters (ParameterVariadicTypeVariable variable) + | _ -> Record.Parameter.Single parsed ) ) ) in match argument with | { Node.value = Expression.Tuple elements; _ } -> List.map elements ~f:parse_parameter @@ -2928,8 +3135,8 @@ let rec create_logic ~aliases ~variable_aliases { Node.value = expression; _ } = | Parametric { name = "typing.Tuple"; parameters } | Parametric { name = "tuple"; parameters } -> ( match parameters with - | [Single parameter; Group Any] -> Tuple (Unbounded parameter) - | [Group group] -> Tuple (Bounded group) + | [Single parameter; VariadicExpression (Group Any)] -> Tuple (Unbounded parameter) + | [VariadicExpression (Group group)] -> Tuple (Bounded group) | parameters -> Parameter.all_singles parameters >>| (fun singles -> Tuple (Bounded (Concrete singles))) @@ -3166,12 +3373,16 @@ let weaken_literals annotation = module OrderedTypes = struct include Record.OrderedTypes - type t = type_t record [@@deriving compare, eq, sexp, show, hash] + type t = type_t variadic_expression [@@deriving compare, eq, sexp, show, hash] + + type record_t = type_t record [@@deriving compare, eq, sexp, show, hash] type ordered_types_t = t let pp_concise = pp_concise ~pp_type + let pp_concise_record = pp_concise_record ~pp_type + module Concatenation = struct include Record.OrderedTypes.RecordConcatenate @@ -3189,29 +3400,38 @@ module OrderedTypes = struct let create ~variable ~mappers = { variable; mappers } - let rec replace_variable middle ~replacement = + let rec replace_variable middle ~replacement ~replace_variadic = match middle with - | { Middle.mappers = []; variable } -> replacement variable + | { Middle.mappers = []; variable = Variadic variable } -> replacement variable + | { Middle.mappers = []; variable = Expression variadic_expression } -> + Some (replace_variadic variadic_expression) | { Middle.mappers = head_mapper :: tail_mapper; _ } -> let inner = { middle with mappers = tail_mapper } in let apply concrete = Parametric { name = head_mapper; parameters = [Single concrete] } in - let handle_replaced = function - | Any -> Any - | Concrete concretes -> Concrete (List.map concretes ~f:apply) - | Concatenation concatenation -> - Concatenation (apply_mapping ~mapper:head_mapper concatenation) + let handle_replaced = + let f = function + | Any -> Group Any + | Concrete concretes -> Group (Concrete (List.map concretes ~f:apply)) + | Concatenation concatenation -> + Group (Concatenation (apply_mapping ~mapper:head_mapper concatenation)) + in + transform_variadic_expression ~f in - replace_variable inner ~replacement >>| handle_replaced + replace_variable inner ~replacement ~replace_variadic >>| handle_replaced let singleton_replace_variable middle ~replacement = let extract = function - | Some (Concrete [extracted]) -> extracted + | Some (Group (Concrete [extracted])) -> extracted | _ -> failwith "this was a singleton replace" in - replace_variable middle ~replacement:(fun _ -> Some (Concrete [replacement])) |> extract + replace_variable + middle + ~replacement:(fun _ -> Some (Group (Concrete [replacement]))) + ~replace_variadic:Fn.id + |> extract end let parse expression ~aliases = @@ -3230,17 +3450,22 @@ module OrderedTypes = struct let map_middle { middle; wrapping } ~f = { middle = f middle; wrapping } - let replace_variable { middle; wrapping } ~replacement = + let replace_variable { middle; wrapping } ~replacement ~replace_variadic = let merge ~inner:{ head; tail } ~outer:{ head = outer_head; tail = outer_tail } = { head = outer_head @ head; tail = tail @ outer_tail } in let actualize ~inner { head; tail } = head @ inner @ tail in - match Middle.replace_variable middle ~replacement with + match Middle.replace_variable middle ~replacement ~replace_variadic with | None -> None - | Some Any -> Some Any - | Some (Concrete inner) -> Some (Concrete (actualize ~inner wrapping)) - | Some (Concatenation { middle = inner_middle; wrapping = inner }) -> - Some (Concatenation { middle = inner_middle; wrapping = merge ~inner ~outer:wrapping }) + | Some expression -> + let transform = function + | Any -> Group Any + | Concrete inner -> Group (Concrete (actualize ~inner wrapping)) + | Concatenation { middle = inner_middle; wrapping = inner } -> + Group + (Concatenation { middle = inner_middle; wrapping = merge ~inner ~outer:wrapping }) + in + Some (transform_variadic_expression expression ~f:transform) let variable { middle = { Middle.variable; _ }; _ } = variable @@ -3267,25 +3492,20 @@ module OrderedTypes = struct None end - let union_upper_bound ordered = - match ordered with - | Concrete concretes -> union concretes - | Any -> Any - | Concatenation _ -> object_primitive + let rec union_upper_bound variadic_expression = + match variadic_expression with + | Group (Concrete concretes) -> union concretes + | Group Any -> Any + | Group (Concatenation _) -> object_primitive + | Broadcast (left, right) -> union [union_upper_bound left; union_upper_bound right] - let variable ordered_types = + let local_replace_variable ordered_types ~replacement ~replace_variadic = match ordered_types with | Concrete _ -> None | Any -> None - | Concatenation concatenation -> Some (Concatenation.variable concatenation) - - - let local_replace_variable ordered_types ~replacement = - match ordered_types with - | Concrete _ -> None - | Any -> None - | Concatenation concatenation -> Concatenation.replace_variable concatenation ~replacement + | Concatenation concatenation -> + Concatenation.replace_variable concatenation ~replacement ~replace_variadic end let split annotation = @@ -3298,7 +3518,7 @@ let split annotation = | Tuple tuple -> let parameters = match tuple with - | Bounded parameters -> [Group parameters] + | Bounded parameters -> [VariadicExpression (Group parameters)] | Unbounded parameter -> [Single parameter] in Primitive "tuple", parameters @@ -3951,12 +4171,13 @@ end = struct type t = type_t record [@@deriving compare, sexp] end) - let any = OrderedTypes.Any + let any = OrderedTypes.Group OrderedTypes.Any let self_reference variable = - OrderedTypes.Concatenation - (OrderedTypes.Concatenation.empty_wrap - { OrderedTypes.Concatenation.Middle.variable; mappers = [] }) + OrderedTypes.Group + (OrderedTypes.Concatenation + (OrderedTypes.Concatenation.empty_wrap + { OrderedTypes.Concatenation.Middle.variable = Variadic variable; mappers = [] })) let pair variable value = ListVariadicPair (variable, value) @@ -3976,50 +4197,119 @@ end = struct let namespace variable ~namespace = { variable with namespace } (* TODO(T45087986): Add more entries here as we add hosts for these variables *) - let rec local_collect = function - | Tuple (Bounded bounded) -> OrderedTypes.variable bounded |> Option.to_list + let local_collect annotation = + let rec collect_variadic_tree variadic = + match variadic with + | OrderedTypes.Concatenation + { + middle = { OrderedTypes.RecordConcatenate.Middle.variable = Variadic variable; _ }; + _; + } -> + [variable] + | Concatenation + { + middle = + { OrderedTypes.RecordConcatenate.Middle.variable = Expression expression; _ }; + _; + } -> + let rec collect_expression expression = + match expression with + | OrderedTypes.Group variadic -> collect_variadic_tree variadic + | Broadcast (left, right) -> collect_expression left @ collect_expression right + in + collect_expression expression + | _ -> [] + in + match annotation with + | Tuple (Bounded bounded) -> collect_variadic_tree bounded | Callable { implementation; overloads; _ } -> let map = function | { parameters = Defined parameters; _ } -> let collect_variadic = function | Callable.Parameter.Variable (Concatenation concatenation) -> - Some (OrderedTypes.Concatenation.variable concatenation) - | _ -> None + collect_variadic_tree (OrderedTypes.Concatenation concatenation) + | _ -> [] in - List.filter_map parameters ~f:collect_variadic + List.concat_map parameters ~f:collect_variadic | _ -> [] in implementation :: overloads |> List.concat_map ~f:map | Parametric { parameters; _ } -> let collect = function - | Record.Parameter.Group ordered -> OrderedTypes.variable ordered |> Option.to_list + | Parameter.VariadicExpression expression -> + let rec collect_variadic = function + | OrderedTypes.Broadcast (left, right) -> + collect_variadic left @ collect_variadic right + | Group ordered -> collect_variadic_tree ordered + in + collect_variadic expression | CallableParameters _ | Single _ -> [] in List.concat_map parameters ~f:collect - | IntExpression polynomial -> - List.concat_map polynomial ~f:(fun { variables; _ } -> - List.concat_map variables ~f:(fun { variable; _ } -> - match variable with - | Variadic (_, variadic) -> [OrderedTypes.Concatenation.variable variadic] - | Divide (dividend, quotient) -> - local_collect (IntExpression dividend) - @ local_collect (IntExpression quotient) - | _ -> [])) - |> List.dedup_and_sort ~compare | _ -> [] - let rec local_replace replacement = function - | Tuple (Bounded bounded) -> - OrderedTypes.local_replace_variable bounded ~replacement - >>| fun ordered_types -> Tuple (Bounded ordered_types) + let rec local_replace replacement annotation = + let rec replace_variadic = function + | OrderedTypes.Broadcast (left, right) -> ( + let left = replace_variadic left in + let right = replace_variadic right in + match left, right with + | OrderedTypes.Group (Concrete left), Group (Concrete right) -> + let are_parameters_resolved = + (* TODO: Check if it is an IntExpression with unresolved variables *) + List.for_all ~f:(fun annotation -> + match annotation with + | _ -> true) + in + if are_parameters_resolved left && are_parameters_resolved right then + let zipped, remainder = + List.zip_with_remainder (List.rev left) (List.rev right) + in + let broadcast = function + | Literal (Integer left), Literal (Integer right) -> + if left = right || right = 1 then + Literal (Integer left) + else if left = 1 then + Literal (Integer right) + else + Any + | Any, _ + | _, Any -> + Any + | _, Primitive "int" + | Primitive "int", _ -> + integer + | _, _ -> Any + in + let main_broadcast = List.map zipped ~f:broadcast in + let prefix = + match remainder with + | None -> [] + | Some (First prefix) + | Some (Second prefix) -> + prefix + in + Group (Concrete (List.rev (main_broadcast @ prefix))) + else + Broadcast (Group (Concrete left), Group (Concrete right)) + | _, _ -> Broadcast (left, right) ) + | Group ordered -> ( + match OrderedTypes.local_replace_variable ordered ~replacement ~replace_variadic with + | Some group -> group + | None -> Group ordered ) + in + match annotation with + | Tuple (Bounded bounded) -> ( + match OrderedTypes.local_replace_variable bounded ~replacement ~replace_variadic with + | Some (Group ordered_types) -> Some (Tuple (Bounded ordered_types)) + | _ -> None ) | Parametric { name; parameters } -> let replace = function - | Record.Parameter.Group ordered -> - OrderedTypes.local_replace_variable ordered ~replacement - >>| fun group -> Record.Parameter.Group group + | Parameter.VariadicExpression expression -> + Some (Parameter.VariadicExpression (replace_variadic expression)) | CallableParameters _ | Single _ -> None @@ -4037,16 +4327,24 @@ end = struct let replace_variadic = function | Callable.Parameter.Variable (Concatenation concatenation) -> let encode_ordered_types_into_parameters = function - | OrderedTypes.Any -> [Callable.Parameter.Variable (Concrete Any)] - | Concrete concretes -> - let make_anonymous annotation = - Callable.Parameter.PositionalOnly - { index = 0; annotation; default = false } - in - List.map concretes ~f:make_anonymous - | Concatenation concatenation -> [Variable (Concatenation concatenation)] + | OrderedTypes.Group z -> ( + match z with + | OrderedTypes.Any -> [Callable.Parameter.Variable (Concrete Any)] + | Concrete concretes -> + let make_anonymous annotation = + Callable.Parameter.PositionalOnly + { index = 0; annotation; default = false } + in + List.map concretes ~f:make_anonymous + | Concatenation concatenation -> + [Variable (Concatenation concatenation)] ) + | OrderedTypes.Broadcast _ -> [] + (*TODO: Not sure about Broadcast in Callables?*) in - OrderedTypes.Concatenation.replace_variable concatenation ~replacement + OrderedTypes.Concatenation.replace_variable + concatenation + ~replacement + ~replace_variadic >>| encode_ordered_types_into_parameters |> Option.value ~default:[Callable.Parameter.Variable (Concatenation concatenation)] @@ -4067,12 +4365,12 @@ end = struct |> Option.some | IntExpression polynomial -> let replace_concatenation variadic ~operation = - let replace_variadic group ~operation = + let replace_variadic_group group ~operation = match operation, group with - | _, OrderedTypes.Any -> Any - | Monomial.Length, Concrete list_types -> + | _, OrderedTypes.Group OrderedTypes.Any -> Any + | Monomial.Length, OrderedTypes.Group (Concrete list_types) -> IntExpression (Polynomial.create_from_int (List.length list_types)) - | Monomial.Product, Concrete list_types -> + | Monomial.Product, OrderedTypes.Group (Concrete list_types) -> let identity_polynomial = Polynomial.create_from_int 1 in List.fold list_types @@ -4080,11 +4378,14 @@ end = struct ~f: (merge_int_expressions ~operation:(Polynomial.multiply ~compare_t:T.compare)) - | _, Concatenation variadic -> + | _, OrderedTypes.Group (Concatenation variadic) -> IntExpression (Polynomial.create_from_variadic variadic ~operation) + | _, OrderedTypes.Broadcast _ -> + (*TODO After redefining Variadic in Polynomial*) + IntExpression (Polynomial.create_from_int (-999)) in - OrderedTypes.Concatenation.replace_variable variadic ~replacement - >>| replace_variadic ~operation + OrderedTypes.Concatenation.replace_variable variadic ~replacement ~replace_variadic + >>| replace_variadic_group ~operation in local_replace_polynomial polynomial @@ -4391,13 +4692,13 @@ end = struct let coalesce_if_all_single parameters = Parameter.all_singles parameters - >>| (fun coalesced -> [Parameter.Group (Concrete coalesced)]) + >>| (fun coalesced -> [Parameter.VariadicExpression (Group (Concrete coalesced))]) |> Option.value ~default:parameters let correct_concrete_group_into_parameters ~variable parameter = match variable, parameter with - | ParameterVariadic _, Parameter.Group (Concrete group) -> + | ParameterVariadic _, Parameter.VariadicExpression (Group (Concrete group)) -> Parameter.CallableParameters (Defined (Callable.prepend_anonymous_parameters ~head:group ~tail:[])) | _, other -> other @@ -4452,7 +4753,7 @@ end = struct let to_parameter = function | Unary variable -> Parameter.Single (Unary.self_reference variable) - | ListVariadic variable -> Parameter.Group (Variadic.List.self_reference variable) + | ListVariadic variable -> Parameter.VariadicExpression (Variadic.List.self_reference variable) | ParameterVariadic variable -> Parameter.CallableParameters (Variadic.Parameters.self_reference variable) end diff --git a/analysis/type.mli b/analysis/type.mli index fae0707f3d0..0fe6644269f 100644 --- a/analysis/type.mli +++ b/analysis/type.mli @@ -61,9 +61,14 @@ module Record : sig [@@deriving compare, eq, sexp, show, hash] end - module OrderedTypes : sig + module rec OrderedTypes : sig module RecordConcatenate : sig module Middle : sig + type 'annotation variable = + | Variadic of 'annotation Variable.RecordVariadic.RecordList.record + | Expression of 'annotation OrderedTypes.variadic_expression + [@@deriving compare, eq, sexp, show, hash] + type 'annotation t [@@deriving compare, eq, sexp, show, hash] end @@ -76,7 +81,18 @@ module Record : sig | Concatenation of ('annotation RecordConcatenate.Middle.t, 'annotation) RecordConcatenate.t [@@deriving compare, eq, sexp, show, hash] + type 'annotation variadic_expression = + | Group of 'annotation record + | Broadcast of 'annotation variadic_expression * 'annotation variadic_expression + [@@deriving compare, eq, sexp, show, hash] + val pp_concise + : Format.formatter -> + 'a variadic_expression -> + pp_type:(Format.formatter -> 'a -> unit) -> + unit + + val pp_concise_record : Format.formatter -> 'a record -> pp_type:(Format.formatter -> 'a -> unit) -> @@ -144,8 +160,9 @@ module Record : sig module Parameter : sig type 'annotation record = | Single of 'annotation - | Group of 'annotation OrderedTypes.record + | VariadicExpression of 'annotation OrderedTypes.variadic_expression | CallableParameters of 'annotation Callable.record_parameters + [@@deriving compare, eq, sexp, show, hash] end module TypedDictionary : sig @@ -602,12 +619,16 @@ module OrderedTypes : sig include Record.OrderedTypes end - type t = type_t record [@@deriving compare, eq, sexp, show, hash] + type t = type_t variadic_expression [@@deriving compare, eq, sexp, show, hash] + + type record_t = type_t record [@@deriving compare, eq, sexp, show, hash] type ordered_types_t = t val pp_concise : Format.formatter -> t -> unit + val pp_concise_record : Format.formatter -> record_t -> unit + module Concatenation : sig include module type of struct include Record.OrderedTypes.RecordConcatenate @@ -618,16 +639,11 @@ module OrderedTypes : sig include Record.OrderedTypes.RecordConcatenate.Middle end - val unwrap_if_bare - : type_t t -> - type_t Record.Variable.RecordVariadic.RecordList.record option + val unwrap_if_bare : type_t t -> type_t variable option - val create_bare : type_t Record.Variable.RecordVariadic.RecordList.record -> type_t t + val create_bare : type_t variable -> type_t t - val create - : variable:type_t Record.Variable.RecordVariadic.RecordList.record -> - mappers:Identifier.t list -> - type_t t + val create : variable:type_t variable -> mappers:Identifier.t list -> type_t t val singleton_replace_variable : type_t t -> replacement:type_t -> type_t end @@ -643,6 +659,7 @@ module OrderedTypes : sig : (type_t Middle.t, type_t) t -> replacement: (type_t Record.Variable.RecordVariadic.RecordList.record -> ordered_types_t option) -> + replace_variadic:(type_t variadic_expression -> type_t variadic_expression) -> ordered_types_t option val head : ('middle, 'outer) t -> 'outer list @@ -653,9 +670,7 @@ module OrderedTypes : sig val unwrap_if_only_middle : ('middle, 'outer) t -> 'middle option - val variable - : (type_t Middle.t, 'outer) t -> - type_t Record.Variable.RecordVariadic.RecordList.record + val variable : (type_t Middle.t, 'outer) t -> type_t Middle.variable val expression : (type_t Middle.t, type_t) t -> Expression.t @@ -680,6 +695,13 @@ module OrderedTypes : sig (* Concatenation is only defined for certain members *) val concatenate : left:t -> right:t -> t option + + val concatenate_record : left:record_t -> right:record_t -> record_t option + + val transform_variadic_expression + : 'a variadic_expression -> + f:('a record -> 'a variadic_expression) -> + 'a variadic_expression end val split : t -> t * Parameter.t list diff --git a/analysis/typeCheck.ml b/analysis/typeCheck.ml index 4cde1c55a60..1b10f4db704 100644 --- a/analysis/typeCheck.ml +++ b/analysis/typeCheck.ml @@ -525,14 +525,8 @@ module State (Context : Context) = struct | Some [] -> parent_type | Some variables -> - let variables = - List.map variables ~f:(function - | Unary variable -> Type.Parameter.Single (Type.Variable variable) - | ListVariadic variadic -> Group (Type.Variable.Variadic.List.self_reference variadic) - | ParameterVariadic parameters -> - CallableParameters (Type.Variable.Variadic.Parameters.self_reference parameters)) - in - Type.parametric parent_name variables + let variables = List.map variables ~f:Type.Variable.to_parameter in + Type.Parametric { name = parent_name; parameters = variables } | exception _ -> parent_type diff --git a/analysis/typeConstraints.ml b/analysis/typeConstraints.ml index 84f687f059e..ce2a02c40b1 100644 --- a/analysis/typeConstraints.ml +++ b/analysis/typeConstraints.ml @@ -121,20 +121,32 @@ let exists_in_bounds { unaries; callable_parameters; list_variadics; _ } ~variab | _ -> false in let exists_in_list_variadic_interval_bounds interval = - let exists = function + let exists_in_record = function | Type.OrderedTypes.Concrete types -> List.exists types ~f:contains_variable | Concatenation concatenation -> let contains = List.exists ~f:contains_variable in let in_head () = Type.OrderedTypes.Concatenation.head concatenation |> contains in let in_middle () = - Type.OrderedTypes.Concatenation.variable concatenation - |> (fun variable -> Type.Variable.ListVariadic variable) - |> List.mem variables ~equal:Type.Variable.equal + let variadics = + Type.Variable.GlobalTransforms.ListVariadic.collect_all + (Type.Parametric + { + name = ""; + parameters = + [Type.Parameter.VariadicExpression (Group (Concatenation concatenation))]; + }) + in + List.exists variadics ~f:(fun x -> + List.mem variables (Type.Variable.ListVariadic x) ~equal:Type.Variable.equal) in let in_tail () = Type.OrderedTypes.Concatenation.tail concatenation |> contains in in_head () || in_middle () || in_tail () | _ -> false in + let rec exists = function + | Type.OrderedTypes.Group variadic -> exists_in_record variadic + | Type.OrderedTypes.Broadcast (left, right) -> exists left || exists right + in match interval with | NoBounds -> false | OnlyLowerBound bound @@ -232,18 +244,31 @@ module Solution = struct ParameterVariable.Map.find callable_parameters - let instantiate_ordered_types solution = function - | Type.OrderedTypes.Concrete concretes -> - List.map concretes ~f:(instantiate solution) - |> fun concretes -> Type.OrderedTypes.Concrete concretes - | Any -> Any - | Concatenation concatenation -> - let mapped = - Type.OrderedTypes.Concatenation.map_head_and_tail concatenation ~f:(instantiate solution) - in - let replacement = instantiate_single_list_variadic_variable solution in - Type.OrderedTypes.Concatenation.replace_variable mapped ~replacement - |> Option.value ~default:(Type.OrderedTypes.Concatenation mapped) + let instantiate_ordered_types solution ordered = + let rec instantiate_variadic expression = + match expression with + | Type.OrderedTypes.Concrete concretes -> + let concretes = + List.map concretes ~f:(instantiate solution) + |> fun concretes -> Type.OrderedTypes.Concrete concretes + in + Type.OrderedTypes.Group concretes + | Any -> Type.OrderedTypes.Group Any + | Concatenation concatenation -> + let record = + let mapped = + Type.OrderedTypes.Concatenation.map_head_and_tail + concatenation + ~f:(instantiate solution) + in + let replacement = instantiate_single_list_variadic_variable solution in + Type.OrderedTypes.Concatenation.replace_variable mapped ~replacement ~replace_variadic + in + record |> Option.value ~default:(Type.OrderedTypes.Group expression) + and replace_variadic = + Type.OrderedTypes.transform_variadic_expression ~f:instantiate_variadic + in + replace_variadic ordered let instantiate_callable_parameters solution parameters = @@ -557,22 +582,32 @@ module OrderedConstraints (Order : OrderType) = struct let less_or_equal order ~left ~right = + let less_or_equal_record left right = + if Type.OrderedTypes.equal_record_t left right then + true + else + match left, right with + | _, Any + | Any, _ -> + true + | Concatenation _, _ + | _, Concatenation _ -> + false + | Concrete upper_bounds, Concrete lower_bounds -> ( + match List.zip upper_bounds lower_bounds with + | Ok bounds -> + List.for_all bounds ~f:(fun (left, right) -> + Order.always_less_or_equal order ~left ~right) + | _ -> false ) + in if Type.OrderedTypes.equal left right then true else match left, right with - | _, Any - | Any, _ -> - true - | Concatenation _, _ - | _, Concatenation _ -> + | Type.OrderedTypes.Group left, Group right -> less_or_equal_record left right + | Broadcast _, _ + | _, Broadcast _ -> false - | Concrete upper_bounds, Concrete lower_bounds -> ( - match List.zip upper_bounds lower_bounds with - | Ok bounds -> - List.for_all bounds ~f:(fun (left, right) -> - Order.always_less_or_equal order ~left ~right) - | _ -> false ) let narrowest_valid_value interval ~order ~variable:_ = @@ -595,11 +630,12 @@ module OrderedConstraints (Order : OrderType) = struct Some right else match left, right with - | Concrete left, Concrete right -> ( + | Group (Concrete left), Group (Concrete right) -> ( match List.zip left right with | Ok zipped -> List.map zipped ~f:(fun (left, right) -> Order.meet order left right) - |> fun concrete_list -> Some (Type.OrderedTypes.Concrete concrete_list) + |> fun concrete_list -> + Some (Type.OrderedTypes.Group (Type.OrderedTypes.Concrete concrete_list)) | _ -> None ) | _ -> None in @@ -612,11 +648,12 @@ module OrderedConstraints (Order : OrderType) = struct Some left else match left, right with - | Concrete left, Concrete right -> ( + | Group (Concrete left), Group (Concrete right) -> ( match List.zip left right with | Ok zipped -> List.map zipped ~f:(fun (left, right) -> Order.join order left right) - |> fun concrete_list -> Some (Type.OrderedTypes.Concrete concrete_list) + |> fun concrete_list -> + Some (Type.OrderedTypes.Group (Type.OrderedTypes.Concrete concrete_list)) | _ -> None ) | _ -> None in @@ -679,20 +716,23 @@ module OrderedConstraints (Order : OrderType) = struct | OnlyLowerBound lower -> [lower] | BothBounds { upper; lower } -> [upper; lower] in - let extract = function - | Type.OrderedTypes.Any -> [] - | Concrete types -> List.concat_map types ~f:Type.Variable.all_free_variables - | Concatenation concatenation -> + let rec extract = function + | Type.OrderedTypes.Broadcast (left, right) -> extract left @ extract right + | Type.OrderedTypes.Group Any -> [] + | Type.OrderedTypes.Group (Concrete types) -> + List.concat_map types ~f:Type.Variable.all_free_variables + | Type.OrderedTypes.Group (Concatenation concatenation) -> ( let outer = Type.OrderedTypes.Concatenation.head concatenation @ Type.OrderedTypes.Concatenation.tail concatenation |> List.concat_map ~f:Type.Variable.all_free_variables in let inner = Type.OrderedTypes.Concatenation.variable concatenation in - if Type.Variable.Variadic.List.is_free inner then - ListVariadic inner :: outer - else - outer + match inner with + | Variadic variadic when Type.Variable.Variadic.List.is_free variadic -> + ListVariadic variadic :: outer + | Expression expression -> extract expression + | _ -> outer ) in List.concat_map bounds ~f:extract end @@ -767,7 +807,7 @@ module OrderedConstraints (Order : OrderType) = struct | Type.Variable.ListVariadic variable -> { solution with - list_variadics = optional_add list_variadics variable Type.OrderedTypes.Any; + list_variadics = optional_add list_variadics variable (Type.OrderedTypes.Group Any); } in Set.to_list have_fallbacks |> List.fold ~init:solution ~f:add_fallback diff --git a/analysis/typeOrder.ml b/analysis/typeOrder.ml index 33ec5ea682b..0ab4aa2c9b3 100644 --- a/analysis/typeOrder.ml +++ b/analysis/typeOrder.ml @@ -209,8 +209,8 @@ module OrderImplementation = struct let variables = variables target in let join_parameters (left, right, variable) = match left, right, variable with - | Type.Parameter.Group _, _, _ - | _, Type.Parameter.Group _, _ + | Type.Parameter.VariadicExpression _, _, _ + | _, Type.Parameter.VariadicExpression _, _ | _, _, Type.Variable.ListVariadic _ | CallableParameters _, _, _ | _, CallableParameters _, _ diff --git a/analysis/unannotatedGlobalEnvironment.ml b/analysis/unannotatedGlobalEnvironment.ml index 8c7d1b755df..43a1263ef40 100644 --- a/analysis/unannotatedGlobalEnvironment.ml +++ b/analysis/unannotatedGlobalEnvironment.ml @@ -645,7 +645,7 @@ let missing_builtin_classes, missing_typing_classes, missing_typing_extensions_c let single_unary_generic = [Type.parametric "typing.Generic" [Single (Variable (Type.Variable.Unary.create "typing._T"))]] in - let catch_all_generic = [Type.parametric "typing.Generic" [Group Any]] in + let catch_all_generic = [Type.parametric "typing.Generic" [VariadicExpression (Group Any)]] in let callable_body = [ Statement.Assign diff --git a/pyre_extensions/__init__.py b/pyre_extensions/__init__.py index 42f3a3dc887..975b097cf7e 100644 --- a/pyre_extensions/__init__.py +++ b/pyre_extensions/__init__.py @@ -99,6 +99,7 @@ def ListVariadic(name) -> object: _A = TypeVar("_A", bound=int) _B = TypeVar("_B", bound=int) _Ts = ListVariadic("_Ts") +_Ts1 = ListVariadic("_Ts1") class Add(Generic[_A, _B], int): @@ -119,3 +120,7 @@ class Length(Generic[_Ts], int): class Product(Generic[_Ts], int): pass + + +class Broadcast(Generic[_Ts, _Ts1], object): + pass diff --git a/test/test.ml b/test/test.ml index 19e2ca7313e..ce2e043017b 100644 --- a/test/test.ml +++ b/test/test.ml @@ -1396,8 +1396,10 @@ let typeshed_stubs ?(include_helper_builtins = true) () = class Multiply(Generic[_A, _B], int): pass class Divide(Generic[_A, _B], int): pass _Ts = ListVariadic("_Ts") + _Ts1 = ListVariadic("_Ts1") class Length(Generic[_Ts], int): pass class Product(Generic[_Ts], int): pass + class Broadcast(Generic[_Ts, _Ts1], object): pass |} ); ( "pyre_extensions/type_variable_operators.pyi",