Type Inference

2023-Jan-31

Dabbling with Haskell inspired me to learn about type inference. So this past week I spent some time reading the literature on type systems and implemented the Hindley-Milner algorithm for a toy expression language in Rust. In this post, I will give an implementation-focused explanation of the algorithm.

1 Hindley-Milner

The type inference algorithm takes an AST as input and outputs the type of each expression in the tree. Much like a detective who solves a case by collecting all the available evidence and then relating the pieces to reach a conclusion, the algorithm first maps each expression to its expected type and then resolves the types of all expressions using the relations between those mappings. Formally, the first step is called constraint generation and the second is called unification.

1.1 Constraint Generation

A constraint is a mapping between an expression and its expected type. To generate constraints for an AST, the algorithm recursively visits every expression in the tree and maps it to an expected type. Primitive expressions like numbers, booleans and identifiers can be mapped to their respective types directly, since any symbol parsed as, for example, a number is guaranteed to be of type Number. It gets more interesting when generating constraints for nested expressions.
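To make this concrete, here is a minimal sketch of how constraints for primitive expressions could be represented in Rust. The names Expr, Term, Constraint and generate are hypothetical, chosen for illustration, and the AST is trimmed to the primitive cases only.

```rust
// Hypothetical AST for the toy language, trimmed to primitive expressions.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Number(i64),
    Bool(bool),
    Ident(String),
}

// A term is either a concrete type or a placeholder meaning
// "the type of this expression".
#[derive(Debug, Clone, PartialEq)]
enum Term {
    Number,
    Bool,
    Expr(Expr),
}

// A constraint pairs an expression placeholder with its expected type.
type Constraint = (Term, Term);

// Primitives map directly to their types. An identifier maps to itself
// until its surrounding context constrains it further.
fn generate(expr: &Expr) -> Vec<Constraint> {
    let expected = match expr {
        Expr::Number(_) => Term::Number,
        Expr::Bool(_) => Term::Bool,
        Expr::Ident(_) => Term::Expr(expr.clone()),
    };
    vec![(Term::Expr(expr.clone()), expected)]
}
```

With this representation, generate(&Expr::Number(1)) yields the single constraint 1 = Number.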

Constraints for nested expressions are generated in two steps:

  1. Each sub-expression is constrained to an expected type based on its value. For example, the sub-expression 1 in a conditional expression will be mapped to type Number.
  2. All sub-expressions and the root expression are mapped to their expected types based on their position in the AST. For example, if the sub-expression 1 appears in the condition part of a conditional expression, it will be mapped to type Bool.

1.1.1 Binary Add Expression

For the binary add expression, generate constraints for the left and right operands and then map each operand and the add expression to type number. For example, the following are the constraints generated for the incorrect expression 1 + false.

1         = Number
false     = Bool
1         = Number
false     = Number
1 + false = Number
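The two steps for the add expression can be sketched in Rust as follows. This is an illustration under assumed names (Expr, Term, of, generate), with the AST trimmed to what the example needs; it produces the five constraints listed above in the same order.

```rust
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Number(i64),
    Bool(bool),
    Add(Box<Expr>, Box<Expr>),
}

#[derive(Debug, Clone, PartialEq)]
enum Term {
    Number,
    Bool,
    Expr(Expr), // placeholder: "the type of this expression"
}

type Constraint = (Term, Term);

// Hypothetical helper: the placeholder term for an expression.
fn of(expr: &Expr) -> Term {
    Term::Expr(expr.clone())
}

fn generate(expr: &Expr) -> Vec<Constraint> {
    match expr {
        Expr::Number(_) => vec![(of(expr), Term::Number)],
        Expr::Bool(_) => vec![(of(expr), Term::Bool)],
        Expr::Add(left, right) => {
            // Step 1: constraints from the operands' own values.
            let mut constraints = generate(left);
            constraints.extend(generate(right));
            // Step 2: constraints from their position in an addition.
            constraints.push((of(left), Term::Number));
            constraints.push((of(right), Term::Number));
            constraints.push((of(expr), Term::Number));
            constraints
        }
    }
}
```

Note that generation itself never fails: the conflict between false = Bool and false = Number is only discovered later, during unification.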

1.1.2 Conditional Expression

For a conditional expression, there are three expectations about its types:

  1. The condition should be of type boolean.
  2. Both branches should be of the same type [1].
  3. The conditional expression should be of the same type as its branches.

To express these expectations in terms of constraints, start by generating constraints for the condition and the two branches. Then, map the condition to type boolean for expectation 1, and map the conditional expression to both the then branch and the else branch. Because the conditional expression must equal each branch, the two branches must also equal each other, which covers expectations 2 and 3. The constraints generated for the expression if true {1} else {2} are:

true                 = Bool
1                    = Number
2                    = Number
true                 = Bool
if true {1} else {2} = 1
if true {1} else {2} = 2
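A sketch of the same rule in Rust, again with hypothetical names and an AST trimmed to the variants this example needs. The last three constraints encode the three expectations.

```rust
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Number(i64),
    Bool(bool),
    If(Box<Expr>, Box<Expr>, Box<Expr>), // condition, then branch, else branch
}

#[derive(Debug, Clone, PartialEq)]
enum Term {
    Number,
    Bool,
    Expr(Expr), // placeholder: "the type of this expression"
}

type Constraint = (Term, Term);

// Hypothetical helper: the placeholder term for an expression.
fn of(expr: &Expr) -> Term {
    Term::Expr(expr.clone())
}

fn generate(expr: &Expr) -> Vec<Constraint> {
    match expr {
        Expr::Number(_) => vec![(of(expr), Term::Number)],
        Expr::Bool(_) => vec![(of(expr), Term::Bool)],
        Expr::If(cond, then, otherwise) => {
            let mut constraints = generate(cond);
            constraints.extend(generate(then));
            constraints.extend(generate(otherwise));
            // Expectation 1: the condition is a boolean.
            constraints.push((of(cond), Term::Bool));
            // Expectations 2 and 3: the whole expression equals each branch,
            // which forces the branches to agree with each other.
            constraints.push((of(expr), of(then)));
            constraints.push((of(expr), of(otherwise)));
            constraints
        }
    }
}
```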

1.1.3 Functions

A function definition is constrained as an arrow type with the parameters as its domain and the body as its range. As you might expect by now, start by generating constraints for the parameters and the body to determine their types, then map the function definition to that arrow type. For the expression lambda(x) x + 2, the constraints are:

x                 = x
2                 = Number
x                 = Number
2                 = Number
x + 2             = Number
(lambda(x) x + 2) = x -> x + 2

Similarly, a function call is constrained as an arrow type with the arguments as its domain and the call expression as its range, because the expected type of a call expression is the type returned by calling the function with the given arguments. So, generate constraints for the function being called and for the arguments, and then map the function expression to an arrow type with the arguments as domain and the call expression as range. For example, calling the function defined above with argument 10 will generate:

x                 = x
2                 = Number
x                 = Number
2                 = Number
x + 2             = Number
(lambda(x) x + 2) = x -> x + 2
10                = Number
(lambda(x) x + 2) = 10 -> Call<(lambda(x) x + 2)(10)>
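Both rules can be sketched in Rust as below. The names are hypothetical, and the sketch makes two simplifying assumptions: functions take exactly one parameter, and no separate constraint is generated for the parameter itself, so the x = x entry comes from the use of x in the body.

```rust
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Number(i64),
    Ident(String),
    Add(Box<Expr>, Box<Expr>),
    Lambda(String, Box<Expr>),  // single parameter, body
    Call(Box<Expr>, Box<Expr>), // function, single argument
}

#[derive(Debug, Clone, PartialEq)]
enum Term {
    Number,
    Expr(Expr), // placeholder: "the type of this expression"
    Arrow(Box<Term>, Box<Term>),
}

type Constraint = (Term, Term);

// Hypothetical helper: the placeholder term for an expression.
fn of(expr: &Expr) -> Term {
    Term::Expr(expr.clone())
}

fn generate(expr: &Expr) -> Vec<Constraint> {
    match expr {
        Expr::Number(_) => vec![(of(expr), Term::Number)],
        Expr::Ident(_) => vec![(of(expr), of(expr))],
        Expr::Add(left, right) => {
            let mut constraints = generate(left);
            constraints.extend(generate(right));
            constraints.push((of(left), Term::Number));
            constraints.push((of(right), Term::Number));
            constraints.push((of(expr), Term::Number));
            constraints
        }
        Expr::Lambda(param, body) => {
            let mut constraints = generate(body);
            // definition = parameter -> body
            let domain = of(&Expr::Ident(param.clone()));
            let arrow = Term::Arrow(Box::new(domain), Box::new(of(body)));
            constraints.push((of(expr), arrow));
            constraints
        }
        Expr::Call(func, arg) => {
            let mut constraints = generate(func);
            constraints.extend(generate(arg));
            // function = argument -> call expression
            let arrow = Term::Arrow(Box::new(of(arg)), Box::new(of(expr)));
            constraints.push((of(func), arrow));
            constraints
        }
    }
}
```

The function expression ends up constrained to two different arrow types, one from its definition and one from its call site. Reconciling the two is what later ties the argument to the parameter and the call expression to the body.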

1.2 Unification

The meat of type inference happens during unification. In this step, the algorithm iterates through the list of constraints and outputs either a type error, if the constraints conflict, or the type of each expression.

1.2.1 Substitution

The core idea of unification is a simple concept called substitution. Assume we have the following set of constraints for a function definition (lambda(x) x + 2).

1. x                 = Number
2. 2                 = Number
3. x + 2             = Number
4. (lambda(x) x + 2) = x -> x + 2

After unification, the RHS of constraint #4 will be substituted with types asserted by previous constraints. The parameter x will be substituted with type Number as stated by constraint #1 and the body x + 2 will also be substituted with type Number as stated by constraint #3.

(lambda(x) x + 2) :: Number -> Number
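Substitution itself is a small recursive function. The sketch below is an illustration, not a definitive implementation: Term and substitute are hypothetical names, and to keep it short a placeholder is identified by the expression's source text (Var) rather than by an AST node.

```rust
#[derive(Debug, Clone, PartialEq)]
enum Term {
    Number,
    Bool,
    Var(String), // placeholder: "the type of this expression"
    Arrow(Box<Term>, Box<Term>),
}

// Replace every occurrence of `from` inside `term` with `to`,
// descending into the domain and range of arrow types.
fn substitute(term: &Term, from: &Term, to: &Term) -> Term {
    if term == from {
        return to.clone();
    }
    match term {
        Term::Arrow(domain, range) => Term::Arrow(
            Box::new(substitute(domain, from, to)),
            Box::new(substitute(range, from, to)),
        ),
        _ => term.clone(),
    }
}
```

Applying substitute to the RHS of constraint #4, first with x replaced by Number and then with x + 2 replaced by Number, gives the arrow type Number -> Number.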

1.2.2 Occurs Check

We substitute a term A with a term B so that expressions are progressively mapped to more specific types. But if term B contains term A, the substitution never settles: every replacement of A reintroduces another A, so unification can loop forever. To prevent this, before each substitution check that term B does not contain term A anywhere in its nested structure, and report a type error if it does.
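A minimal sketch of the check in Rust, using hypothetical names and the same simplification of identifying placeholders by source text:

```rust
#[derive(Debug, Clone, PartialEq)]
enum Term {
    Number,
    Bool,
    Var(String), // placeholder: "the type of this expression"
    Arrow(Box<Term>, Box<Term>),
}

// Returns true if `needle` appears anywhere inside `haystack`,
// including inside nested arrow types.
fn occurs(needle: &Term, haystack: &Term) -> bool {
    if needle == haystack {
        return true;
    }
    match haystack {
        Term::Arrow(domain, range) => occurs(needle, domain) || occurs(needle, range),
        _ => false,
    }
}
```

For example, a constraint like f = f -> Number fails the check, because substituting f with f -> Number would produce f -> Number -> Number, and so on without end.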

1.2.3 Algorithm

The goal of unification is to make the required substitution for each constraint. The algorithm to accomplish this is best explained through pseudocode:

substitution list = []
for LHS and RHS in each constraint:
    if LHS == RHS:
        continue
    else if LHS is not a type:
        if LHS occurs inside RHS:
            it's a type error as the occurs check failed
        substitute all occurrences of LHS with RHS in both constraint list and substitution list
        add the substitution LHS -> RHS to substitution list
    else if RHS is not a type:
        if RHS occurs inside LHS:
            it's a type error as the occurs check failed
        substitute all occurrences of RHS with LHS in both constraint list and substitution list
        add the substitution RHS -> LHS to substitution list
    else if both LHS and RHS are arrow types:
        create a new constraint mapping domain of LHS with domain of RHS
        create a new constraint mapping range of LHS with range of RHS
        add both constraints to constraint list
    else:
        it's a type error as LHS and RHS cannot be unified

Notes:

  1. When replacing all occurrences of LHS with RHS or vice versa, ensure that the replacement is applied to both the LHS and RHS of every entry in the constraint list and the substitution list.
  2. The implementation can take advantage of the nested structure of ASTs and implement unification and all its supporting functions recursively.
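The pseudocode translates fairly directly into Rust. The following is a sketch under stated assumptions rather than a definitive implementation: all names are hypothetical, placeholders are identified by source text (Var), arrow types have a single domain, and the occurs check from the previous section is performed before each substitution.

```rust
use std::collections::VecDeque;

#[derive(Debug, Clone, PartialEq)]
enum Term {
    Number,
    Bool,
    Var(String), // placeholder: "the type of this expression"
    Arrow(Box<Term>, Box<Term>),
}

fn substitute(term: &Term, from: &Term, to: &Term) -> Term {
    if term == from {
        return to.clone();
    }
    match term {
        Term::Arrow(d, r) => Term::Arrow(
            Box::new(substitute(d, from, to)),
            Box::new(substitute(r, from, to)),
        ),
        _ => term.clone(),
    }
}

fn occurs(needle: &Term, haystack: &Term) -> bool {
    if needle == haystack {
        return true;
    }
    match haystack {
        Term::Arrow(d, r) => occurs(needle, d) || occurs(needle, r),
        _ => false,
    }
}

// Returns the substitution list, or a message describing the type error.
fn unify(constraints: Vec<(Term, Term)>) -> Result<Vec<(Term, Term)>, String> {
    let mut queue: VecDeque<(Term, Term)> = constraints.into();
    let mut subst: Vec<(Term, Term)> = Vec::new();
    while let Some((lhs, rhs)) = queue.pop_front() {
        if lhs == rhs {
            continue;
        }
        match (lhs, rhs) {
            // One side is a placeholder: substitute it with the other side.
            (Term::Var(name), other) | (other, Term::Var(name)) => {
                let var = Term::Var(name);
                if occurs(&var, &other) {
                    return Err(format!("{:?} occurs in {:?}", var, other));
                }
                // Apply to both sides of every remaining constraint
                // and every recorded substitution.
                for (l, r) in queue.iter_mut().chain(subst.iter_mut()) {
                    *l = substitute(l, &var, &other);
                    *r = substitute(r, &var, &other);
                }
                subst.push((var, other));
            }
            // Two arrow types: unify domains and ranges pairwise.
            (Term::Arrow(d1, r1), Term::Arrow(d2, r2)) => {
                queue.push_back((*d1, *d2));
                queue.push_back((*r1, *r2));
            }
            (l, r) => return Err(format!("cannot unify {:?} with {:?}", l, r)),
        }
    }
    Ok(subst)
}
```

Fed the eight constraints from the function call example, this sketch resolves the lambda to Number -> Number and the call expression to Number. Fed the constraints for 1 + false, it returns an error when it reaches the conflict between Bool and Number.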

2 Conclusion

It took me about a week to understand and implement this algorithm. In the end, it felt magical to programmatically infer all the types for an untyped language.

Refer to the following resources to learn more about this algorithm:

Footnotes:

[1] Statically typed languages have this requirement because the algorithm cannot determine which branch will be executed at runtime, but it still has to assign a type to the entire expression. In contrast, dynamically typed languages are more flexible, as they can give the conditional expression the type of whichever branch gets executed.