Type Inference
2023-Jan-31
Dabbling with Haskell inspired me to learn about type inference. So this past week I spent some time reading through the literature on type systems and implemented the Hindley-Milner algorithm for a toy expression language in Rust. In this post, I will give an implementation-focused explanation of the algorithm.
1 Hindley-Milner
The type inference algorithm takes the AST as input and outputs the type for each expression in the tree. Much like a detective who solves a case by collecting all the available evidence and then relating the pieces to reach a conclusion, the algorithm first creates a mapping between each expression and its expected type, and then uses the relations between those mappings to find the type of every expression. The first step is called constraint generation and the second is called unification.
1.1 Constraint Generation
A constraint is a mapping between an expression and its expected type. To generate constraints for an AST, the algorithm recursively visits every expression in the tree and maps it to an expected type. Primitive expressions like numbers, booleans and identifiers, can be straightforwardly mapped to their respective types since any symbol parsed as, for example, a number is guaranteed to be of type number. It gets more interesting when generating constraints for nested expressions.
Constraints for nested expressions are generated in two steps:
- Each sub-expression is constrained to an expected type based on its value. For example, the sub-expression 1 in a conditional expression will be mapped to type Number.
- All sub-expressions and the root expression are mapped to their expected types based on their position in the AST. For example, if there is a sub-expression 1 in the condition part of a conditional expression, it will be mapped to type Bool.
1.1.1 Binary Add Expression
For the binary add expression, generate constraints for the left and right operands and then map each operand and the add expression to type number. For example, the following are the constraints generated for the incorrect expression 1 + false:
    1 = Number
    false = Bool
    1 = Number
    false = Number
    1 + false = Number
1.1.2 Conditional Expression
For conditional expressions, there are three expectations for its types:
- Condition should be of type boolean.
- Both branches should be of the same type [1].
- The type of the conditional expression should be the same as the type of its branches.
To express these expectations in terms of constraints, start by generating constraints for the condition and the two branches. Then map the condition to type boolean (expectation 1) and map the conditional expression to both its then branch and its else branch. Because the conditional expression must have the same type as each branch, these last two constraints together capture expectations 2 and 3. The constraints generated for the expression if true {1} else {2} are:
    true = Bool
    1 = Number
    2 = Number
    true = Bool
    if true {1} else {2} = 1
    if true {1} else {2} = 2
1.1.3 Functions
A function definition is constrained as an arrow type with its parameters as the domain and its body as the range. As before, start by generating constraints for the parameters and the body to determine their types. Then map the function definition to an arrow type from the parameters to the body. For the expression lambda(x) x + 2, the constraints are:
    x = x
    2 = Number
    x = Number
    2 = Number
    x + 2 = Number
    (lambda(x) x + 2) = x -> x + 2
Similarly, a function call is also constrained as an arrow type, with the arguments as the domain and the call expression as the range, because the expected type of a call expression is the type returned by calling the function with the given arguments. So, generate constraints for the function being called and for the arguments, and then map the function expression to an arrow type from the arguments to the call expression. For example, calling the function defined above with the argument 10 will generate:
    x = x
    2 = Number
    x = Number
    2 = Number
    x + 2 = Number
    (lambda(x) x + 2) = x -> x + 2
    10 = Number
    (lambda(x) x + 2) = 10 -> Call<(lambda(x) x + 2)(10)>
1.2 Unification
The meat of type inference happens during unification. In this step, the algorithm iterates through the list of constraints and outputs either a type error, if any, or the type for each expression.
1.2.1 Substitution
The core idea of unification is a simple concept called substitution. Assume we have the following set of constraints for a function definition (lambda(x) x + 2).
    1. x = Number
    2. 2 = Number
    3. x + 2 = Number
    4. (lambda(x) x + 2) = x -> x + 2
After unification, the RHS of constraint #4 will be substituted with types asserted by previous constraints. The parameter x will be substituted with type Number as stated by constraint #1 and the body x + 2 will also be substituted with type Number as stated by constraint #3.
    (lambda(x) x + 2) :: Number -> Number
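The substitution step itself can be sketched as a small recursive function. The `substitute` helper and the tuple encoding are assumptions made for illustration: an arrow type is a tuple tagged "->", and any other tuple is an expression whose type is still unknown.

```python
def substitute(term, target, replacement):
    """Replace every occurrence of `target` inside `term` with `replacement`."""
    if term == target:
        return replacement
    if isinstance(term, tuple) and term[0] == "->":
        _, domain, range_ = term
        return ("->",
                substitute(domain, target, replacement),
                substitute(range_, target, replacement))
    # Base types and unrelated expressions are left untouched.
    return term

x = ("id", "x")
body = ("add", x, ("num", 2))
arrow = ("->", x, body)                      # RHS of constraint #4: x -> x + 2

arrow = substitute(arrow, x, "Number")       # apply constraint #1
arrow = substitute(arrow, body, "Number")    # apply constraint #3
print(arrow)  # prints ('->', 'Number', 'Number')
```

In this sketch an expression such as x + 2 is treated as one opaque placeholder, so replacing x does not rewrite the inside of the body. The body is replaced as a whole once its own constraint is applied.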
1.2.2 Occurs Check
We substitute a term A with a term B so that expressions are progressively mapped to more specific types. But if term B itself contains term A, every substitution reintroduces A, so the process would never terminate and the resulting type would be infinitely nested. To prevent this, check before each substitution that term B does not contain term A anywhere in its nested structure.
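A minimal sketch of such a check, assuming the same hypothetical encoding in which an arrow type is a tuple tagged "->" and is the only nested structure a term can have:

```python
def occurs(var, term):
    """Return True if `var` appears anywhere inside `term`."""
    if term == var:
        return True
    if isinstance(term, tuple) and term[0] == "->":
        _, domain, range_ = term
        return occurs(var, domain) or occurs(var, range_)
    return False

f = ("id", "f")
print(occurs(f, ("->", f, "Number")))         # f = f -> Number would nest forever
print(occurs(f, ("->", "Number", "Number")))  # safe to substitute
```

A constraint of the first shape arises, for example, when a function is applied to itself, since its type would have to contain its own arrow type as the domain.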
1.2.3 Algorithm
The goal of unification is to make the required substitution for each constraint. The algorithm to accomplish this is easiest to explain through pseudocode:
    substitution list = []
    for LHS and RHS in each constraint:
        if LHS == RHS:
            continue
        else if LHS is not a type:
            substitute all occurrences of LHS with RHS in both constraint and substitution list
            add the substitution LHS -> RHS to substitution list
        else if RHS is not a type:
            substitute all occurrences of RHS with LHS in both constraint and substitution list
            add the substitution RHS -> LHS to substitution list
        else if both LHS and RHS are arrow types:
            create a new constraint mapping domain of LHS with domain of RHS
            create a new constraint mapping range of LHS with range of RHS
            add both constraints to constraint list
        else:
            it's a type error as LHS and RHS cannot be unified
Notes:
- When replacing all occurrences of LHS with RHS or vice versa, ensure that the replacement occurs in both the LHS and RHS of every entry in the constraint list and the substitution list.
- The implementation can take advantage of the nested structure of ASTs and implement unification and all supporting functions recursively.
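The pseudocode translates fairly directly into Python. The version below is a sketch under assumptions, not the Rust implementation described in this post: base types are strings, arrow types are tuples tagged "->", and every other tuple is an expression standing in for an unknown type. It also runs the occurs check before each substitution and treats a failed check as a type error, which the pseudocode does not spell out. `InferenceError` is a name made up for the example.

```python
class InferenceError(Exception):
    """Raised when two terms cannot be unified."""

def is_arrow(t):
    return isinstance(t, tuple) and t[0] == "->"

def is_var(t):
    # Any expression (a tuple that is not an arrow) is "not a type" yet.
    return isinstance(t, tuple) and not is_arrow(t)

def occurs(var, term):
    if term == var:
        return True
    return is_arrow(term) and (occurs(var, term[1]) or occurs(var, term[2]))

def substitute(term, var, replacement):
    if term == var:
        return replacement
    if is_arrow(term):
        return ("->", substitute(term[1], var, replacement),
                      substitute(term[2], var, replacement))
    return term

def unify(constraints):
    """Return a dict mapping each expression to its inferred type."""
    pending = list(constraints)
    subst = []  # (expression, type) pairs found so far
    while pending:
        lhs, rhs = pending.pop(0)
        if lhs == rhs:
            continue
        if is_var(rhs) and not is_var(lhs):
            lhs, rhs = rhs, lhs  # always substitute the unknown side
        if is_var(lhs):
            if occurs(lhs, rhs):
                raise InferenceError(f"{lhs!r} occurs in {rhs!r}")
            pending = [(substitute(a, lhs, rhs), substitute(b, lhs, rhs))
                       for a, b in pending]
            subst = [(v, substitute(t, lhs, rhs)) for v, t in subst]
            subst.append((lhs, rhs))
        elif is_arrow(lhs) and is_arrow(rhs):
            pending.append((lhs[1], rhs[1]))  # domains must agree
            pending.append((lhs[2], rhs[2]))  # ranges must agree
        else:
            raise InferenceError(f"cannot unify {lhs!r} with {rhs!r}")
    return dict(subst)

x, two = ("id", "x"), ("num", 2)
body = ("add", x, two)
lam = ("lam", x, body)
types = unify([(x, "Number"), (two, "Number"), (body, "Number"),
               (lam, ("->", x, body))])
print(types[lam])  # prints ('->', 'Number', 'Number')
```

Feeding it the constraints for 1 + false from earlier raises `InferenceError`, because false ends up constrained to both Bool and Number.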
2 Conclusion
It took me about a week to understand and implement this algorithm. In the end, it felt magical to programmatically infer all the types for an untyped language.
Refer to the following resources to learn more about this algorithm:
- Chapter 30 in Programming Languages: Application and Interpretation (first edition) for step by step examples of constraint generation and unification.
- Chapter 15.3.2 in Programming Languages: Application and Interpretation (second edition) for an explanation of the algorithm using Typed Racket.
- Unification by Eli Bendersky talks a bit about efficiency.
- type-unify.rkt for a complete implementation in Typed Racket.
Footnotes:
[1] Statically typed languages have this requirement because the algorithm cannot determine which branch will be executed at runtime, yet it still has to assign a type to the entire expression. In contrast, dynamically typed languages are more flexible: the type of a conditional expression is simply the type of whichever branch gets executed.