Chapter 20: Functional Programs (Functional Programs)
소개
Phase 6의 여정을 복습하자:
Chapter 17: Pattern Matching Theory에서는 패턴 매칭의 이론적 기초를 다뤘다:
- Decision tree 알고리즘 (Maranget 2008)
- Pattern matrix와 specialization/defaulting 연산
- Exhaustiveness checking과 unreachable case detection
- 컴파일 시간에 패턴을 분석하여 최적의 decision tree 생성
Chapter 18: List Operations에서는 패턴 매칭이 작동할 데이터 구조를 구현했다:
!funlang.list<T>parameterized type으로 타입 안전한 리스트 표현funlang.nil과funlang.consoperations으로 리스트 생성- TypeConverter로 tagged union
!llvm.struct<(i32, ptr)>변환 - NilOpLowering과 ConsOpLowering patterns로 LLVM dialect 생성
Chapter 19: Match Compilation에서는 모든 것을 종합했다:
funlang.matchoperation으로 패턴 매칭 표현- Multi-stage lowering: FunLang → SCF → CF → LLVM
- IRMapping으로 block argument remapping
- 실행 가능한 코드 생성
Chapter 20에서는 이 모든 것을 사용하여 실제 함수형 프로그램을 작성한다.
Phase 6의 최종 목표: 완전한 함수형 프로그래밍
Phase 4에서 우리는 클로저를 구현했다:
// Phase 4: 클로저
let makeAdder n = fun x -> x + n
let add5 = makeAdder 5
let result = add5 10 // 15
Phase 5에서 우리는 커스텀 FunLang dialect를 만들었다:
// Phase 5: FunLang operations
%closure = funlang.closure @add_impl(%n) : (i32) -> ((i32) -> i32)
%result = funlang.apply %closure(%x) : ((i32) -> i32, i32) -> i32
Phase 6에서 우리는 리스트와 패턴 매칭을 구현했다:
// Phase 6: Lists and pattern matching
%list = funlang.cons %head, %tail : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%result = funlang.match %list : !funlang.list<i32> -> i32 {
^nil:
funlang.yield %zero : i32
^cons(%h: i32, %t: !funlang.list<i32>):
funlang.yield %h : i32
}
이제 이 세 가지를 조합하면 강력한 함수형 프로그래밍이 가능하다:
// Phase 6 Complete: 클로저 + 리스트 + 패턴 매칭
let map f lst =
match lst with
| [] -> []
| head :: tail -> (f head) :: (map f tail)
let double x = x * 2
let result = map double [1, 2, 3] // [2, 4, 6]
Chapter 20의 목표: 실전 함수형 프로그램
이 장을 마치면 다음과 같은 실제 함수형 프로그램을 컴파일하고 실행할 수 있다:
1. map: 리스트의 각 원소에 함수를 적용
let map f lst =
match lst with
| [] -> []
| head :: tail -> (f head) :: (map f tail)
map (fun x -> x * 2) [1, 2, 3] // [2, 4, 6]
2. filter: 조건을 만족하는 원소만 남기기
let filter pred lst =
match lst with
| [] -> []
| head :: tail ->
if pred head then
head :: filter pred tail
else
filter pred tail
filter (fun x -> x > 2) [1, 2, 3, 4] // [3, 4]
3. fold: 리스트를 하나의 값으로 축약
let fold f acc lst =
match lst with
| [] -> acc
| head :: tail -> fold f (f acc head) tail
fold (+) 0 [1, 2, 3, 4, 5] // 15
4. 조합: 복잡한 프로그램
// 제곱의 합: [1, 2, 3] -> 14
let sum_of_squares lst =
fold (+) 0 (map (fun x -> x * x) lst)
sum_of_squares [1, 2, 3] // 1 + 4 + 9 = 14
성공 기준: 완전한 컴파일 파이프라인
각 함수형 프로그램에 대해 다음을 보여준다:
- FunLang 소스 코드: F# 스타일의 함수형 문법
- FunLang dialect MLIR: 커스텀 operations 사용
- SCF dialect MLIR: 제어 흐름으로 변환
- LLVM dialect MLIR: 최종 lowering
- 실행 결과: JIT으로 실행하여 결과 확인
이것이 바로 “실전 컴파일러“다:
- 교과서의 toy 예제가 아니라 실제 사용 가능한 프로그램
- 모든 단계를 추적 가능하고 검증 가능
- Phase 7 (최적화)로 이어지는 기반
Chapter 20의 구성
Part 1: Map and Filter (이번 섹션)
- FunLang에서 리스트 구축하기
- map 함수: 소스, 컴파일, 실행
- filter 함수: 중첩 제어 흐름
- Helper 함수: length, append
Part 2: Fold and Complete Pipeline
- fold 함수: 일반적인 리스트 combinator
- 완전한 예제: sum_of_squares
- 성능 고려사항
- 완전한 컴파일러 통합
- Phase 6 요약과 Phase 7 미리보기
이 장을 마치면 Phase 6가 완료되며, Phase 7 (최적화)로 넘어갈 준비가 된다.
FunLang에서 리스트 구축하기
FunLang AST 확장: List Expressions
지금까지 우리는 MLIR operations로 리스트를 직접 구축했다:
// 직접 MLIR 작성
%nil = funlang.nil : !funlang.list<i32>
%three = arith.constant 3 : i32
%l1 = funlang.cons %three, %nil : (i32, !funlang.list<i32>) -> !funlang.list<i32>
하지만 사용자는 FunLang 언어로 작성하고 싶어한다:
// 사용자가 원하는 문법
let empty = []
let list = [1, 2, 3]
let consed = 1 :: [2, 3]
AST 확장이 필요하다.
FunLang AST Type Definition
Ast.fs에 리스트 표현식을 추가한다:
// Ast.fs
type Expr =
| Int of int
| Var of string
| Add of Expr * Expr
| Let of string * Expr * Expr
| If of Expr * Expr * Expr
| Fun of string * Expr // Phase 4: lambda
| App of Expr * Expr // Phase 4: application
// Phase 6: List expressions
| Nil // []
| Cons of Expr * Expr // head :: tail
| List of Expr list // [1, 2, 3] - syntactic sugar
| Match of Expr * (Pattern * Expr) list // match expr with cases
and Pattern =
| PVar of string // x (bind any value)
| PNil // [] (empty list)
| PCons of Pattern * Pattern // head :: tail
| PWild // _ (wildcard)
설계 결정:
Nil: Empty list[]는 zero-argument constructorCons: Binary operator::(head와 tail)List: List literal[1, 2, 3]는 syntactic sugar (연속된 Cons로 desugaring)Match: Pattern matching expression
List Literal Desugaring
List literal은 syntactic sugar다:
// 사용자 작성
[1, 2, 3]
// Desugaring
1 :: (2 :: (3 :: []))
// AST 표현
Cons(Int 1, Cons(Int 2, Cons(Int 3, Nil)))
Desugaring 함수:
// Parser.fs or Desugar.fs
let rec desugarList (exprs: Expr list) : Expr =
match exprs with
| [] -> Nil
| head :: tail -> Cons(head, desugarList tail)
// Usage
let ast = List [Int 1; Int 2; Int 3]
let desugared = desugarList [Int 1; Int 2; Int 3]
// Result: Cons(Int 1, Cons(Int 2, Cons(Int 3, Nil)))
왜 desugaring인가?
- 간단한 컴파일: 컴파일러는
Cons와Nil만 처리하면 된다 - 중복 제거:
Listliteral과Consoperator가 같은 representation을 공유 - 확장성: 새로운 syntactic sugar 추가 시 desugaring만 변경
컴파일러 통합: compileExpr 확장
Compiler.fs의 compileExpr 함수를 확장하여 리스트를 처리한다:
// Compiler.fs
let rec compileExpr (builder: OpBuilder) (expr: Expr) (symbolTable: Map<string, Value>) : Value =
match expr with
| Int n ->
let ty = builder.GetI32Type()
builder.CreateConstantInt(ty, n)
| Var name ->
symbolTable.[name]
| Add (left, right) ->
let lhs = compileExpr builder left symbolTable
let rhs = compileExpr builder right symbolTable
builder.CreateAddI(lhs, rhs)
// ... (Phase 3-4 cases)
// Phase 6: Nil case
| Nil ->
// funlang.nil : !funlang.list<T>
// Type inference: 주변 context에서 element type 결정
let elemTy = inferElementType expr // e.g., i32
let listTy = builder.GetListType(elemTy)
builder.CreateNil(listTy)
// Phase 6: Cons case
| Cons (head, tail) ->
// funlang.cons %head, %tail : (T, !funlang.list<T>) -> !funlang.list<T>
let headVal = compileExpr builder head symbolTable
let tailVal = compileExpr builder tail symbolTable
let headTy = headVal.GetType()
let listTy = builder.GetListType(headTy)
builder.CreateCons(headVal, tailVal, listTy)
// Phase 6: Match case (covered later in this chapter)
| Match (scrutinee, cases) ->
compileMatch builder scrutinee cases symbolTable
Type inference 예제:
// FunLang source
let list = 1 :: 2 :: []
// Type inference
// - 1 is i32, so head is i32
// - Cons expects (i32, !funlang.list<i32>)
// - [] must be !funlang.list<i32>
// Compiled MLIR
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
%nil = funlang.nil : !funlang.list<i32>
%tail = funlang.cons %c2, %nil : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%list = funlang.cons %c1, %tail : (i32, !funlang.list<i32>) -> !funlang.list<i32>
예제: 리스트 컴파일
Example 1: Empty list
// FunLang
let empty = []
Compiled MLIR:
func.func @example1() -> !funlang.list<i32> {
%empty = funlang.nil : !funlang.list<i32>
return %empty : !funlang.list<i32>
}
Example 2: Single element
// FunLang
let single = [42]
// Desugared
let single = 42 :: []
Compiled MLIR:
func.func @example2() -> !funlang.list<i32> {
%c42 = arith.constant 42 : i32
%nil = funlang.nil : !funlang.list<i32>
%single = funlang.cons %c42, %nil : (i32, !funlang.list<i32>) -> !funlang.list<i32>
return %single : !funlang.list<i32>
}
Example 3: Multiple elements
// FunLang
let list = [1, 2, 3]
// Desugared
let list = 1 :: (2 :: (3 :: []))
Compiled MLIR:
func.func @example3() -> !funlang.list<i32> {
// Build from inside out: 3 :: []
%c3 = arith.constant 3 : i32
%nil = funlang.nil : !funlang.list<i32>
%l3 = funlang.cons %c3, %nil : (i32, !funlang.list<i32>) -> !funlang.list<i32>
// 2 :: [3]
%c2 = arith.constant 2 : i32
%l2 = funlang.cons %c2, %l3 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
// 1 :: [2, 3]
%c1 = arith.constant 1 : i32
%l1 = funlang.cons %c1, %l2 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
return %l1 : !funlang.list<i32>
}
Example 4: Cons operator
// FunLang
let list = 1 :: 2 :: 3 :: []
Compiled MLIR (same as Example 3):
func.func @example4() -> !funlang.list<i32> {
%c3 = arith.constant 3 : i32
%nil = funlang.nil : !funlang.list<i32>
%l3 = funlang.cons %c3, %nil : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%c2 = arith.constant 2 : i32
%l2 = funlang.cons %c2, %l3 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%c1 = arith.constant 1 : i32
%l1 = funlang.cons %c1, %l2 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
return %l1 : !funlang.list<i32>
}
Type safety:
FunLang의 타입 시스템은 heterogeneous list를 방지한다:
// Type error: element type mismatch
let bad = [1, "hello", 3]
// Error: Expected i32, found string
MLIR type은 element type을 명시한다:
!funlang.list<i32>: 32비트 정수 리스트!funlang.list<f64>: 64비트 부동소수점 리스트!funlang.list<!funlang.closure<(i32) -> i32>>: 클로저 리스트 (고차 함수)
이제 우리는 리스트를 구축할 수 있다. 다음은 리스트를 조작하는 함수를 작성할 차례다.
map 함수: 리스트 변환
map의 개념
map은 함수형 프로그래밍의 가장 기본적인 고차 함수다:
// map의 타입
map : (a -> b) -> [a] -> [b]
// map의 의미
map f [x1, x2, ..., xn] = [f x1, f x2, ..., f xn]
예제:
let double x = x * 2
map double [1, 2, 3] // [2, 4, 6]
let inc x = x + 1
map inc [10, 20, 30] // [11, 21, 31]
map (fun x -> x * x) [1, 2, 3, 4] // [1, 4, 9, 16]
FunLang 소스 코드
map 함수를 FunLang으로 작성한다:
let rec map f lst =
match lst with
| [] -> []
| head :: tail -> (f head) :: (map f tail)
동작 원리:
- Base case: Empty list → return empty list
- Recursive case:
- Apply
ftohead→ transformed head - Recursively map over
tail - Cons the results
- Apply
실행 trace:
map double [1, 2, 3]
→ double 1 :: map double [2, 3]
→ 2 :: (double 2 :: map double [3])
→ 2 :: (4 :: (double 3 :: map double []))
→ 2 :: (4 :: (6 :: []))
→ [2, 4, 6]
FunLang AST 표현
FunLang AST로 표현하면:
// let rec map f lst = ...
Let("map",
Fun("f",
Fun("lst",
Match(Var "lst",
[ (PNil, Nil)
; (PCons(PVar "head", PVar "tail"),
Cons(App(Var "f", Var "head"),
App(App(Var "map", Var "f"), Var "tail")))
]))),
// ... body that uses map ...
)
구조 분석:
- Outer Let:
map정의를 scope에 바인딩 - Curried function:
f와lst두 개의 중첩 lambda - Match expression:
lst에 대한 패턴 매칭 - Patterns:
[](PNil)과head :: tail(PCons) - Recursive call:
map f tail에서map자기 자신 호출
컴파일된 MLIR: FunLang Dialect
compileExpr가 위 AST를 컴파일하면 다음 MLIR이 생성된다:
// map : (T -> U) -> !funlang.list<T> -> !funlang.list<U>
func.func @map(%f: !funlang.closure<(i32) -> i32>,
%lst: !funlang.list<i32>) -> !funlang.list<i32> {
// match lst with ...
%result = funlang.match %lst : !funlang.list<i32> -> !funlang.list<i32> {
// Case 1: [] -> []
^nil:
%empty = funlang.nil : !funlang.list<i32>
funlang.yield %empty : !funlang.list<i32>
// Case 2: head :: tail -> (f head) :: (map f tail)
^cons(%head: i32, %tail: !funlang.list<i32>):
// f head
%transformed = funlang.apply %f(%head) : (!funlang.closure<(i32) -> i32>, i32) -> i32
// map f tail (recursive call)
%mapped_tail = func.call @map(%f, %tail)
: (!funlang.closure<(i32) -> i32>, !funlang.list<i32>) -> !funlang.list<i32>
// transformed :: mapped_tail
%new_list = funlang.cons %transformed, %mapped_tail
: (i32, !funlang.list<i32>) -> !funlang.list<i32>
funlang.yield %new_list : !funlang.list<i32>
}
return %result : !funlang.list<i32>
}
핵심 포인트:
funlang.match: 리스트를 검사하는 control flowfunlang.apply: 클로저 간접 호출 (f head)func.call @map: 재귀 호출 (named function)funlang.cons: 결과 리스트 구축- Type safety: 모든 operations가 타입 정보를 유지
Lowering Stage 1: FunLang → SCF
FunLangToSCFPass가 실행되면 funlang.match가 scf.if로 lowering된다:
func.func @map(%f: !funlang.closure<(i32) -> i32>,
%lst: !llvm.struct<(i32, ptr)>) -> !llvm.struct<(i32, ptr)> {
// Extract tag: lst->tag
%tag_ptr = llvm.getelementptr %lst[0, 0] : (!llvm.struct<(i32, ptr)>) -> !llvm.ptr
%tag = llvm.load %tag_ptr : !llvm.ptr -> i32
// Check if tag == 0 (Nil)
%c0 = arith.constant 0 : i32
%is_nil = arith.cmpi eq, %tag, %c0 : i32
// if (is_nil) then ... else ...
%result = scf.if %is_nil -> !llvm.struct<(i32, ptr)> {
// Nil case: return empty list
%nil_tag = arith.constant 0 : i32
%null_ptr = llvm.mlir.null : !llvm.ptr
%empty = llvm.mlir.undef : !llvm.struct<(i32, ptr)>
%empty1 = llvm.insertvalue %nil_tag, %empty[0] : !llvm.struct<(i32, ptr)>
%empty2 = llvm.insertvalue %null_ptr, %empty1[1] : !llvm.struct<(i32, ptr)>
scf.yield %empty2 : !llvm.struct<(i32, ptr)>
} else {
// Cons case: extract head and tail
%cons_tag = arith.constant 1 : i32
%payload_ptr = llvm.getelementptr %lst[0, 1] : (!llvm.struct<(i32, ptr)>) -> !llvm.ptr
%payload = llvm.load %payload_ptr : !llvm.ptr -> !llvm.ptr
// Cast payload to ConsCell: struct { head: i32, tail: list }
%head_ptr = llvm.getelementptr %payload[0, 0] : (!llvm.ptr) -> !llvm.ptr
%head = llvm.load %head_ptr : !llvm.ptr -> i32
%tail_ptr = llvm.getelementptr %payload[0, 1] : (!llvm.ptr) -> !llvm.ptr
%tail = llvm.load %tail_ptr : !llvm.ptr -> !llvm.struct<(i32, ptr)>
// Apply closure: f head
%transformed = funlang.apply %f(%head) : (!funlang.closure<(i32) -> i32>, i32) -> i32
// Recursive call: map f tail
%mapped_tail = func.call @map(%f, %tail)
: (!funlang.closure<(i32) -> i32>, !llvm.struct<(i32, ptr)>) -> !llvm.struct<(i32, ptr)>
// Build cons cell: transformed :: mapped_tail
%cell_size = llvm.mlir.constant(16 : i64) : i64 // sizeof(ConsCell)
%cell = llvm.call @GC_malloc(%cell_size) : (i64) -> !llvm.ptr
%cell_head_ptr = llvm.getelementptr %cell[0, 0] : (!llvm.ptr) -> !llvm.ptr
llvm.store %transformed, %cell_head_ptr : i32, !llvm.ptr
%cell_tail_ptr = llvm.getelementptr %cell[0, 1] : (!llvm.ptr) -> !llvm.ptr
llvm.store %mapped_tail, %cell_tail_ptr : !llvm.struct<(i32, ptr)>, !llvm.ptr
// Build list struct: {tag=1, payload=cell}
%cons_tag_val = arith.constant 1 : i32
%new_list = llvm.mlir.undef : !llvm.struct<(i32, ptr)>
%new_list1 = llvm.insertvalue %cons_tag_val, %new_list[0] : !llvm.struct<(i32, ptr)>
%new_list2 = llvm.insertvalue %cell, %new_list1[1] : !llvm.struct<(i32, ptr)>
scf.yield %new_list2 : !llvm.struct<(i32, ptr)>
}
return %result : !llvm.struct<(i32, ptr)>
}
변환 내용:
funlang.match→scf.if: Binary choice (Nil vs Cons)- Tag extraction:
llvm.getelementptr+llvm.load로 tag field 읽기 - Comparison:
arith.cmpi eq로 tag 검사 - Block arguments → loads: Cons case의
%head,%tail을 payload에서 추출 - GC allocation:
GC_malloc으로 새 cons cell 할당
Lowering Stage 2: SCF → CF + LLVM
SCFToControlFlowPass가 실행되면 scf.if가 cf.br, cf.cond_br로 lowering된다:
func.func @map(%f: !funlang.closure<(i32) -> i32>,
%lst: !llvm.struct<(i32, ptr)>) -> !llvm.struct<(i32, ptr)> {
^entry:
// Extract tag
%tag_ptr = llvm.getelementptr %lst[0, 0] : (!llvm.struct<(i32, ptr)>) -> !llvm.ptr
%tag = llvm.load %tag_ptr : !llvm.ptr -> i32
%c0 = arith.constant 0 : i32
%is_nil = arith.cmpi eq, %tag, %c0 : i32
// Conditional branch
cf.cond_br %is_nil, ^nil_case, ^cons_case
^nil_case:
// Return empty list
%nil_tag = arith.constant 0 : i32
%null_ptr = llvm.mlir.null : !llvm.ptr
%empty = llvm.mlir.undef : !llvm.struct<(i32, ptr)>
%empty1 = llvm.insertvalue %nil_tag, %empty[0] : !llvm.struct<(i32, ptr)>
%empty2 = llvm.insertvalue %null_ptr, %empty1[1] : !llvm.struct<(i32, ptr)>
cf.br ^exit(%empty2 : !llvm.struct<(i32, ptr)>)
^cons_case:
// Extract head and tail
%payload_ptr = llvm.getelementptr %lst[0, 1] : (!llvm.struct<(i32, ptr)>) -> !llvm.ptr
%payload = llvm.load %payload_ptr : !llvm.ptr -> !llvm.ptr
%head_ptr = llvm.getelementptr %payload[0, 0] : (!llvm.ptr) -> !llvm.ptr
%head = llvm.load %head_ptr : !llvm.ptr -> i32
%tail_ptr = llvm.getelementptr %payload[0, 1] : (!llvm.ptr) -> !llvm.ptr
%tail = llvm.load %tail_ptr : !llvm.ptr -> !llvm.struct<(i32, ptr)>
// Apply closure
%transformed = funlang.apply %f(%head) : (!funlang.closure<(i32) -> i32>, i32) -> i32
// Recursive call
%mapped_tail = func.call @map(%f, %tail)
: (!funlang.closure<(i32) -> i32>, !llvm.struct<(i32, ptr)>) -> !llvm.struct<(i32, ptr)>
// Allocate cons cell
%cell_size = llvm.mlir.constant(16 : i64) : i64
%cell = llvm.call @GC_malloc(%cell_size) : (i64) -> !llvm.ptr
%cell_head_ptr = llvm.getelementptr %cell[0, 0] : (!llvm.ptr) -> !llvm.ptr
llvm.store %transformed, %cell_head_ptr : i32, !llvm.ptr
%cell_tail_ptr = llvm.getelementptr %cell[0, 1] : (!llvm.ptr) -> !llvm.ptr
llvm.store %mapped_tail, %cell_tail_ptr : !llvm.struct<(i32, ptr)>, !llvm.ptr
// Build result
%cons_tag = arith.constant 1 : i32
%new_list = llvm.mlir.undef : !llvm.struct<(i32, ptr)>
%new_list1 = llvm.insertvalue %cons_tag, %new_list[0] : !llvm.struct<(i32, ptr)>
%new_list2 = llvm.insertvalue %cell, %new_list1[1] : !llvm.struct<(i32, ptr)>
cf.br ^exit(%new_list2 : !llvm.struct<(i32, ptr)>)
^exit(%result: !llvm.struct<(i32, ptr)>):
return %result : !llvm.struct<(i32, ptr)>
}
CFG 구조:
[entry]
|
(is_nil?)
/ \
[nil] [cons]
\ /
[exit]
테스트 프로그램: map (fun x -> x * 2) [1, 2, 3]
완전한 프로그램을 컴파일하고 실행해보자:
// FunLang source
let double = fun x -> x * 2
let rec map f lst =
match lst with
| [] -> []
| head :: tail -> (f head) :: (map f tail)
let result = map double [1, 2, 3]
// Expected: [2, 4, 6]
Compiled MLIR (simplified):
module {
// Helper: double function as closure implementation
func.func @double_impl(%x: i32) -> i32 {
%c2 = arith.constant 2 : i32
%result = arith.muli %x, %c2 : i32
return %result : i32
}
// map function (as defined above)
func.func @map(%f: !funlang.closure<(i32) -> i32>,
%lst: !llvm.struct<(i32, ptr)>) -> !llvm.struct<(i32, ptr)> {
// ... (as shown in previous section)
}
// Main entry point
func.func @main() -> !llvm.struct<(i32, ptr)> {
// Create closure: double
%double_fn = llvm.mlir.addressof @double_impl : !llvm.ptr
%null_env = llvm.mlir.null : !llvm.ptr // no captures
%closure_size = llvm.mlir.constant(16 : i64) : i64
%closure_mem = llvm.call @GC_malloc(%closure_size) : (i64) -> !llvm.ptr
%fn_ptr_field = llvm.getelementptr %closure_mem[0, 0] : (!llvm.ptr) -> !llvm.ptr
llvm.store %double_fn, %fn_ptr_field : !llvm.ptr, !llvm.ptr
%env_ptr_field = llvm.getelementptr %closure_mem[0, 1] : (!llvm.ptr) -> !llvm.ptr
llvm.store %null_env, %env_ptr_field : !llvm.ptr, !llvm.ptr
%double = llvm.load %closure_mem : !llvm.ptr -> !funlang.closure<(i32) -> i32>
// Create list: [1, 2, 3]
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
%c3 = arith.constant 3 : i32
%nil = funlang.nil : !funlang.list<i32>
%l3 = funlang.cons %c3, %nil : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%l2 = funlang.cons %c2, %l3 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%l1 = funlang.cons %c1, %l2 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
// Call map
%result = func.call @map(%double, %l1)
: (!funlang.closure<(i32) -> i32>, !funlang.list<i32>) -> !funlang.list<i32>
return %result : !llvm.struct<(i32, ptr)>
}
}
실행 trace:
map double [1, 2, 3]
→ double 1 :: map double [2, 3]
→ 2 :: (double 2 :: map double [3])
→ 2 :: (4 :: (double 3 :: map double []))
→ 2 :: (4 :: (6 :: []))
→ [2, 4, 6]
Memory layout (heap):
Closure (double):
+0: fn_ptr -> @double_impl
+8: env_ptr -> NULL
List [2, 4, 6]:
+0: tag=1, payload -> ConsCell1
ConsCell1:
+0: head=2
+8: tail -> {tag=1, payload -> ConsCell2}
ConsCell2:
+0: head=4
+8: tail -> {tag=1, payload -> ConsCell3}
ConsCell3:
+0: head=6
+8: tail -> {tag=0, payload=NULL} // Nil
검증: JIT 실행
// Compiler.fs
let testMapDouble() =
let ctx = MLIRContext.Create()
let module = compileProgram ctx mapDoubleSource
// Apply lowering passes
let pm = PassManager.Create(ctx)
pm.AddPass("convert-funlang-to-scf")
pm.AddPass("convert-scf-to-cf")
pm.AddPass("convert-funlang-to-llvm")
pm.Run(module)
// JIT execute
let engine = ExecutionEngine.Create(module)
let result = engine.Invoke("main", [||])
// Verify result: [2, 4, 6]
let list = result :?> ListValue
assert (list.Count = 3)
assert (list.[0] = 2)
assert (list.[1] = 4)
assert (list.[2] = 6)
printfn "map double [1, 2, 3] = [2, 4, 6] ✓"
Output:
map double [1, 2, 3] = [2, 4, 6] ✓
성공! map 함수가 완전히 작동한다.
filter 함수: 조건부 리스트 필터링
filter의 개념
filter는 조건을 만족하는 원소만 남긴다:
// filter의 타입
filter : (a -> bool) -> [a] -> [a]
// filter의 의미
filter pred [x1, x2, ..., xn] = [xi | pred xi = true]
예제:
let is_positive x = x > 0
filter is_positive [-2, -1, 0, 1, 2] // [1, 2]
let is_even x = x % 2 == 0
filter is_even [1, 2, 3, 4, 5, 6] // [2, 4, 6]
filter (fun x -> x > 2) [1, 2, 3, 4] // [3, 4]
FunLang 소스 코드
filter 함수를 FunLang으로 작성한다:
let rec filter pred lst =
match lst with
| [] -> []
| head :: tail ->
if pred head then
head :: filter pred tail
else
filter pred tail
동작 원리:
- Base case: Empty list → return empty list
- Recursive case:
- 조건 검사:
pred head - True이면:
head를 결과에 포함 - False이면:
head를 건너뛰고 tail만 재귀 처리
- 조건 검사:
실행 trace:
filter (fun x -> x > 2) [1, 2, 3, 4]
→ (1 > 2)? No → filter pred [2, 3, 4]
→ (2 > 2)? No → filter pred [3, 4]
→ (3 > 2)? Yes → 3 :: filter pred [4]
→ (4 > 2)? Yes → 3 :: (4 :: filter pred [])
→ 3 :: (4 :: [])
→ [3, 4]
map vs filter 비교
| 특성 | map | filter |
|---|---|---|
| 타입 | (a -> b) -> [a] -> [b] | (a -> bool) -> [a] -> [a] |
| 결과 크기 | Input과 동일 | Input 이하 |
| 조건 분기 | 없음 (항상 변환) | 있음 (if-else) |
| 원소 변환 | 있음 (f x) | 없음 (원소 그대로) |
| MLIR 복잡도 | Moderate | Higher (nested control flow) |
컴파일된 MLIR: FunLang Dialect
// filter : (T -> i1) -> !funlang.list<T> -> !funlang.list<T>
func.func @filter(%pred: !funlang.closure<(i32) -> i1>,
%lst: !funlang.list<i32>) -> !funlang.list<i32> {
// match lst with ...
%result = funlang.match %lst : !funlang.list<i32> -> !funlang.list<i32> {
// Case 1: [] -> []
^nil:
%empty = funlang.nil : !funlang.list<i32>
funlang.yield %empty : !funlang.list<i32>
// Case 2: head :: tail -> if (pred head) then ... else ...
^cons(%head: i32, %tail: !funlang.list<i32>):
// pred head
%should_keep = funlang.apply %pred(%head)
: (!funlang.closure<(i32) -> i1>, i32) -> i1
// Recursive call (always needed)
%filtered_tail = func.call @filter(%pred, %tail)
: (!funlang.closure<(i32) -> i1>, !funlang.list<i32>) -> !funlang.list<i32>
// if should_keep then head :: filtered_tail else filtered_tail
%new_list = scf.if %should_keep -> !funlang.list<i32> {
// Keep head
%kept = funlang.cons %head, %filtered_tail
: (i32, !funlang.list<i32>) -> !funlang.list<i32>
scf.yield %kept : !funlang.list<i32>
} else {
// Skip head
scf.yield %filtered_tail : !funlang.list<i32>
}
funlang.yield %new_list : !funlang.list<i32>
}
return %result : !funlang.list<i32>
}
핵심 포인트:
- Nested control flow:
funlang.match안에scf.if - Predicate 호출:
funlang.apply %pred(%head)는 boolean 반환 - Conditional cons: True일 때만
funlang.cons - Recursive call position: if 밖에서 호출 (항상 필요)
Nested Control Flow 분석
filter는 두 단계의 제어 흐름을 가진다:
Level 1: Pattern matching
match lst:
Nil → []
Cons → [Level 2]
Level 2: Conditional inclusion
if pred head:
True → head :: filtered_tail
False → filtered_tail
Combined CFG:
[entry]
|
(is_nil?)
/ \
[nil] [cons]
| |
| (pred head?)
| / \
| [keep] [skip]
| \ /
| [merge]
\ /
[exit]
Lowering Stage 1: FunLang → SCF
func.func @filter(%pred: !funlang.closure<(i32) -> i1>,
%lst: !llvm.struct<(i32, ptr)>) -> !llvm.struct<(i32, ptr)> {
// Extract tag
%tag_ptr = llvm.getelementptr %lst[0, 0] : (!llvm.struct<(i32, ptr)>) -> !llvm.ptr
%tag = llvm.load %tag_ptr : !llvm.ptr -> i32
%c0 = arith.constant 0 : i32
%is_nil = arith.cmpi eq, %tag, %c0 : i32
// Level 1: match
%result = scf.if %is_nil -> !llvm.struct<(i32, ptr)> {
// Nil case
%nil_tag = arith.constant 0 : i32
%null_ptr = llvm.mlir.null : !llvm.ptr
%empty = llvm.mlir.undef : !llvm.struct<(i32, ptr)>
%empty1 = llvm.insertvalue %nil_tag, %empty[0] : !llvm.struct<(i32, ptr)>
%empty2 = llvm.insertvalue %null_ptr, %empty1[1] : !llvm.struct<(i32, ptr)>
scf.yield %empty2 : !llvm.struct<(i32, ptr)>
} else {
// Cons case: extract head and tail
%payload_ptr = llvm.getelementptr %lst[0, 1] : (!llvm.struct<(i32, ptr)>) -> !llvm.ptr
%payload = llvm.load %payload_ptr : !llvm.ptr -> !llvm.ptr
%head_ptr = llvm.getelementptr %payload[0, 0] : (!llvm.ptr) -> !llvm.ptr
%head = llvm.load %head_ptr : !llvm.ptr -> i32
%tail_ptr = llvm.getelementptr %payload[0, 1] : (!llvm.ptr) -> !llvm.ptr
%tail = llvm.load %tail_ptr : !llvm.ptr -> !llvm.struct<(i32, ptr)>
// Apply predicate
%should_keep = funlang.apply %pred(%head)
: (!funlang.closure<(i32) -> i1>, i32) -> i1
// Recursive call
%filtered_tail = func.call @filter(%pred, %tail)
: (!funlang.closure<(i32) -> i1>, !llvm.struct<(i32, ptr)>) -> !llvm.struct<(i32, ptr)>
// Level 2: if pred
%new_list = scf.if %should_keep -> !llvm.struct<(i32, ptr)> {
// Keep: allocate cons cell
%cell_size = llvm.mlir.constant(16 : i64) : i64
%cell = llvm.call @GC_malloc(%cell_size) : (i64) -> !llvm.ptr
%cell_head_ptr = llvm.getelementptr %cell[0, 0] : (!llvm.ptr) -> !llvm.ptr
llvm.store %head, %cell_head_ptr : i32, !llvm.ptr
%cell_tail_ptr = llvm.getelementptr %cell[0, 1] : (!llvm.ptr) -> !llvm.ptr
llvm.store %filtered_tail, %cell_tail_ptr : !llvm.struct<(i32, ptr)>, !llvm.ptr
%cons_tag = arith.constant 1 : i32
%kept = llvm.mlir.undef : !llvm.struct<(i32, ptr)>
%kept1 = llvm.insertvalue %cons_tag, %kept[0] : !llvm.struct<(i32, ptr)>
%kept2 = llvm.insertvalue %cell, %kept1[1] : !llvm.struct<(i32, ptr)>
scf.yield %kept2 : !llvm.struct<(i32, ptr)>
} else {
// Skip: return filtered_tail directly
scf.yield %filtered_tail : !llvm.struct<(i32, ptr)>
}
scf.yield %new_list : !llvm.struct<(i32, ptr)>
}
return %result : !llvm.struct<(i32, ptr)>
}
Nested scf.if analysis:
- Outer if: 리스트가 empty인지 검사
- Inner if: Head를 keep할지 skip할지 결정
- Region nesting: Inner if는 outer if의 else branch 안에 있다
- Type consistency: 모든 branch가 같은 타입 반환
테스트 프로그램: filter (fun x -> x > 2) [1, 2, 3, 4]
// FunLang source
let is_greater_than_2 = fun x -> x > 2
let rec filter pred lst =
match lst with
| [] -> []
| head :: tail ->
if pred head then
head :: filter pred tail
else
filter pred tail
let result = filter is_greater_than_2 [1, 2, 3, 4]
// Expected: [3, 4]
Compiled MLIR (main function):
func.func @main() -> !llvm.struct<(i32, ptr)> {
// Create predicate closure: fun x -> x > 2
%pred_fn = llvm.mlir.addressof @is_greater_than_2_impl : !llvm.ptr
%null_env = llvm.mlir.null : !llvm.ptr
%closure_size = llvm.mlir.constant(16 : i64) : i64
%closure_mem = llvm.call @GC_malloc(%closure_size) : (i64) -> !llvm.ptr
%fn_ptr_field = llvm.getelementptr %closure_mem[0, 0] : (!llvm.ptr) -> !llvm.ptr
llvm.store %pred_fn, %fn_ptr_field : !llvm.ptr, !llvm.ptr
%env_ptr_field = llvm.getelementptr %closure_mem[0, 1] : (!llvm.ptr) -> !llvm.ptr
llvm.store %null_env, %env_ptr_field : !llvm.ptr, !llvm.ptr
%pred = llvm.load %closure_mem : !llvm.ptr -> !funlang.closure<(i32) -> i1>
// Create list: [1, 2, 3, 4]
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
%c3 = arith.constant 3 : i32
%c4 = arith.constant 4 : i32
%nil = funlang.nil : !funlang.list<i32>
%l4 = funlang.cons %c4, %nil : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%l3 = funlang.cons %c3, %l4 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%l2 = funlang.cons %c2, %l3 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%l1 = funlang.cons %c1, %l2 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
// Call filter
%result = func.call @filter(%pred, %l1)
: (!funlang.closure<(i32) -> i1>, !funlang.list<i32>) -> !funlang.list<i32>
return %result : !llvm.struct<(i32, ptr)>
}
// Predicate implementation
func.func @is_greater_than_2_impl(%x: i32) -> i1 {
%c2 = arith.constant 2 : i32
%result = arith.cmpi sgt, %x, %c2 : i32
return %result : i1
}
실행 trace:
filter pred [1, 2, 3, 4]
→ (1 > 2)? No → filter pred [2, 3, 4]
→ (2 > 2)? No → filter pred [3, 4]
→ (3 > 2)? Yes → 3 :: filter pred [4]
→ (4 > 2)? Yes → 3 :: (4 :: filter pred [])
→ 3 :: (4 :: [])
→ [3, 4]
검증:
let testFilterGreaterThan2() =
let ctx = MLIRContext.Create()
let module = compileProgram ctx filterSource
let pm = PassManager.Create(ctx)
pm.AddPass("convert-funlang-to-scf")
pm.AddPass("convert-scf-to-cf")
pm.AddPass("convert-funlang-to-llvm")
pm.Run(module)
let engine = ExecutionEngine.Create(module)
let result = engine.Invoke("main", [||])
let list = result :?> ListValue
assert (list.Count = 2)
assert (list.[0] = 3)
assert (list.[1] = 4)
printfn "filter (fun x -> x > 2) [1, 2, 3, 4] = [3, 4] ✓"
Output:
filter (fun x -> x > 2) [1, 2, 3, 4] = [3, 4] ✓
성공! filter 함수도 완전히 작동한다.
Helper 함수: length와 append
map과 filter 외에도 유용한 리스트 함수가 많다. 두 가지 기본 helper를 구현한다.
length 함수
FunLang 소스:
let rec length lst =
match lst with
| [] -> 0
| head :: tail -> 1 + length tail
타입: [a] -> int
예제:
length [] // 0
length [1] // 1
length [1, 2, 3] // 3
Compiled MLIR:
func.func @length(%lst: !funlang.list<i32>) -> i32 {
%result = funlang.match %lst : !funlang.list<i32> -> i32 {
^nil:
%zero = arith.constant 0 : i32
funlang.yield %zero : i32
^cons(%head: i32, %tail: !funlang.list<i32>):
%tail_len = func.call @length(%tail) : (!funlang.list<i32>) -> i32
%one = arith.constant 1 : i32
%len = arith.addi %one, %tail_len : i32
funlang.yield %len : i32
}
return %result : i32
}
특징:
head값은 무시 (타입만 필요)- 재귀 호출로 tail length 계산
- 결과:
1 + tail_length
append 함수
FunLang 소스:
let rec append xs ys =
match xs with
| [] -> ys
| head :: tail -> head :: (append tail ys)
타입: [a] -> [a] -> [a]
예제:
append [] [1, 2] // [1, 2]
append [1, 2] [] // [1, 2]
append [1, 2] [3, 4] // [1, 2, 3, 4]
Compiled MLIR:
func.func @append(%xs: !funlang.list<i32>,
%ys: !funlang.list<i32>) -> !funlang.list<i32> {
%result = funlang.match %xs : !funlang.list<i32> -> !funlang.list<i32> {
^nil:
// Base case: [] ++ ys = ys
funlang.yield %ys : !funlang.list<i32>
^cons(%head: i32, %tail: !funlang.list<i32>):
// Recursive case: (h :: t) ++ ys = h :: (t ++ ys)
%appended = func.call @append(%tail, %ys)
: (!funlang.list<i32>, !funlang.list<i32>) -> !funlang.list<i32>
%new_list = funlang.cons %head, %appended
: (i32, !funlang.list<i32>) -> !funlang.list<i32>
funlang.yield %new_list : !funlang.list<i32>
}
return %result : !funlang.list<i32>
}
특징:
- Base case: 첫 번째 리스트가 empty이면 두 번째 리스트 반환
- Recursive case: 첫 번째 리스트의 head를 보존하고 tail 재귀 처리
- 시간 복잡도: O(|xs|) - 첫 번째 리스트 길이에 비례
실행 trace:
append [1, 2] [3, 4]
→ 1 :: append [2] [3, 4]
→ 1 :: (2 :: append [] [3, 4])
→ 1 :: (2 :: [3, 4])
→ [1, 2, 3, 4]
테스트: Helper 함수
let testHelpers() =
// Test length
let len1 = length [] // 0
let len2 = length [1] // 1
let len3 = length [1, 2, 3] // 3
assert (len1 = 0)
assert (len2 = 1)
assert (len3 = 3)
printfn "length tests passed ✓"
// Test append
let app1 = append [] [1, 2] // [1, 2]
let app2 = append [1, 2] [] // [1, 2]
let app3 = append [1, 2] [3, 4] // [1, 2, 3, 4]
assert (listEqual app1 [1, 2])
assert (listEqual app2 [1, 2])
assert (listEqual app3 [1, 2, 3, 4])
printfn "append tests passed ✓"
Output:
length tests passed ✓
append tests passed ✓
이제 우리는 기본 함수형 프로그래밍 toolkit을 갖췄다:
map: 변환filter: 필터링length: 크기 측정append: 결합
다음 섹션에서는 가장 강력한 combinator인 **fold**를 구현한다.
fold 함수: 일반적인 리스트 Combinator
fold의 개념
fold (또는 reduce)는 리스트를 하나의 값으로 축약하는 가장 일반적인 combinator다:
// fold의 타입
fold : (acc -> a -> acc) -> acc -> [a] -> acc
// fold의 의미
fold f acc [x1, x2, ..., xn] = f (... (f (f acc x1) x2) ...) xn
fold는 모든 리스트 연산의 기초다:
// sum: 모든 원소의 합
let sum lst = fold (+) 0 lst
sum [1, 2, 3, 4, 5] // 15
// product: 모든 원소의 곱
let product lst = fold (*) 1 lst
product [1, 2, 3, 4] // 24
// length: map과 filter도 fold로 구현 가능
let length lst = fold (fun acc _ -> acc + 1) 0 lst
length [1, 2, 3] // 3
왜 fold가 가장 강력한가?
| 함수 | fold로 구현 가능? | 예제 |
|---|---|---|
sum | ✓ | fold (+) 0 |
product | ✓ | fold (*) 1 |
length | ✓ | fold (fun acc _ -> acc + 1) 0 |
map | ✓ | fold (fun acc x -> acc ++ [f x]) [] |
filter | ✓ | fold (fun acc x -> if p x then acc ++ [x] else acc) [] |
reverse | ✓ | fold (fun acc x -> x :: acc) [] |
fold는 universal list combinator다. 다른 모든 리스트 함수를 fold로 표현할 수 있다.
FunLang 소스 코드
fold 함수를 FunLang으로 작성한다:
let rec fold f acc lst =
match lst with
| [] -> acc
| head :: tail -> fold f (f acc head) tail
동작 원리:
- Base case: Empty list → return accumulator (결과)
- Recursive case:
- Apply
ftoaccandhead→ new accumulator - Recursively fold over
tailwith new accumulator
- Apply
실행 trace:
fold (+) 0 [1, 2, 3, 4, 5]
→ fold (+) (0 + 1) [2, 3, 4, 5]
→ fold (+) 1 [2, 3, 4, 5]
→ fold (+) (1 + 2) [3, 4, 5]
→ fold (+) 3 [3, 4, 5]
→ fold (+) (3 + 3) [4, 5]
→ fold (+) 6 [4, 5]
→ fold (+) (6 + 4) [5]
→ fold (+) 10 [5]
→ fold (+) (10 + 5) []
→ fold (+) 15 []
→ 15
Accumulator 패턴:
Accumulator는 중간 결과를 저장하는 변수다:
- 초기값:
acc = 0(sum의 경우) - 갱신:
acc = f acc head(각 원소마다 업데이트) - 최종값: 리스트가 empty일 때 accumulator 반환
fold vs map/filter 비교
| 특성 | map | filter | fold |
|---|---|---|---|
| 타입 | (a -> b) -> [a] -> [b] | (a -> bool) -> [a] -> [a] | (acc -> a -> acc) -> acc -> [a] -> acc |
| 입력 | 리스트 | 리스트 | 리스트 + 초기값 |
| 출력 | 리스트 | 리스트 | 단일 값 |
| 함수 인자 | 1개 (원소) | 1개 (원소) | 2개 (누적값, 원소) |
| 일반성 | 특수 | 특수 | 일반 (map/filter 구현 가능) |
컴파일된 MLIR: FunLang Dialect
// fold : (acc -> T -> acc) -> acc -> !funlang.list<T> -> acc
func.func @fold(%f: !funlang.closure<(i32, i32) -> i32>,
%acc: i32,
%lst: !funlang.list<i32>) -> i32 {
// match lst with ...
%result = funlang.match %lst : !funlang.list<i32> -> i32 {
// Case 1: [] -> acc
^nil:
funlang.yield %acc : i32
// Case 2: head :: tail -> fold f (f acc head) tail
^cons(%head: i32, %tail: !funlang.list<i32>):
// f acc head
%new_acc = funlang.apply %f(%acc, %head)
: (!funlang.closure<(i32, i32) -> i32>, i32, i32) -> i32
// fold f new_acc tail (tail recursion!)
%final = func.call @fold(%f, %new_acc, %tail)
: (!funlang.closure<(i32, i32) -> i32>, i32, !funlang.list<i32>) -> i32
funlang.yield %final : i32
}
return %result : i32
}
핵심 포인트:
- Three arguments: 클로저
f, 누적값acc, 리스트lst - Binary closure:
f는 두 인자 (acc,head)를 받는다 - Tail recursion: 재귀 호출이 함수의 마지막 operation (최적화 가능!)
- Accumulator threading:
acc→new_acc→final로 흐름
Tail Recursion 분석
fold는 tail recursive다:
// Tail recursive (good)
let rec fold f acc lst =
match lst with
| [] -> acc
| head :: tail -> fold f (f acc head) tail
// ^^^ Recursive call is the LAST operation
// NOT tail recursive (map, filter)
let rec map f lst =
match lst with
| [] -> []
| head :: tail -> (f head) :: (map f tail)
// ^^^ Recursive call is NOT the last (cons follows)
Tail recursion의 장점:
- Stack frame 재사용: 각 재귀 호출이 새 stack frame을 생성하지 않음
- 메모리 효율: O(1) stack space (vs O(n) for non-tail)
- 컴파일러 최적화: Loop로 변환 가능
LLVM optimization pass가 tail call을 감지하면:
// Before optimization (recursive)
%result = func.call @fold(%f, %new_acc, %tail) : (...) -> i32
// After optimization (loop)
// Stack frame 재사용, jump로 변환
Common Fold Patterns
1. Sum (합계)
let sum lst = fold (fun acc x -> acc + x) 0 lst
// Or simply: fold (+) 0 lst
sum [1, 2, 3, 4, 5] // 15
Compiled MLIR:
func.func @sum(%lst: !funlang.list<i32>) -> i32 {
// Create add closure
%add = funlang.closure @add_impl() : () -> ((i32, i32) -> i32)
// Initial accumulator
%zero = arith.constant 0 : i32
// Call fold
%result = func.call @fold(%add, %zero, %lst)
: (!funlang.closure<(i32, i32) -> i32>, i32, !funlang.list<i32>) -> i32
return %result : i32
}
func.func @add_impl(%acc: i32, %x: i32) -> i32 {
%result = arith.addi %acc, %x : i32
return %result : i32
}
2. Product (곱셈)
let product lst = fold (*) 1 lst
product [1, 2, 3, 4] // 24
3. Length (길이)
let length lst = fold (fun acc _ -> acc + 1) 0 lst
length [1, 2, 3] // 3
이전에 재귀로 구현한 length와 같은 결과지만, fold를 사용하면 더 일반적이다.
4. Reverse (역순)
let reverse lst = fold (fun acc x -> x :: acc) [] lst
reverse [1, 2, 3] // [3, 2, 1]
Trace:
fold cons [] [1, 2, 3]
→ fold cons (1 :: []) [2, 3]
→ fold cons [1] [2, 3]
→ fold cons (2 :: [1]) [3]
→ fold cons [2, 1] [3]
→ fold cons (3 :: [2, 1]) []
→ fold cons [3, 2, 1] []
→ [3, 2, 1]
5. Maximum (최댓값)
let max_list lst =
match lst with
| [] -> error "empty list"
| head :: tail -> fold (fun acc x -> if x > acc then x else acc) head tail
max_list [3, 1, 4, 1, 5, 9, 2] // 9
테스트 프로그램: fold (+) 0 [1, 2, 3, 4, 5]
// FunLang source
let add = fun acc x -> acc + x
let rec fold f acc lst =
match lst with
| [] -> acc
| head :: tail -> fold f (f acc head) tail
let result = fold add 0 [1, 2, 3, 4, 5]
// Expected: 15
Compiled MLIR (main function):
func.func @main() -> i32 {
// Create add closure
%add_fn = llvm.mlir.addressof @add_impl : !llvm.ptr
%null_env = llvm.mlir.null : !llvm.ptr
%closure_size = llvm.mlir.constant(16 : i64) : i64
%closure_mem = llvm.call @GC_malloc(%closure_size) : (i64) -> !llvm.ptr
%fn_ptr_field = llvm.getelementptr %closure_mem[0, 0] : (!llvm.ptr) -> !llvm.ptr
llvm.store %add_fn, %fn_ptr_field : !llvm.ptr, !llvm.ptr
%env_ptr_field = llvm.getelementptr %closure_mem[0, 1] : (!llvm.ptr) -> !llvm.ptr
llvm.store %null_env, %env_ptr_field : !llvm.ptr, !llvm.ptr
%add = llvm.load %closure_mem : !llvm.ptr -> !funlang.closure<(i32, i32) -> i32>
// Initial accumulator
%zero = arith.constant 0 : i32
// Create list: [1, 2, 3, 4, 5]
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
%c3 = arith.constant 3 : i32
%c4 = arith.constant 4 : i32
%c5 = arith.constant 5 : i32
%nil = funlang.nil : !funlang.list<i32>
%l5 = funlang.cons %c5, %nil : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%l4 = funlang.cons %c4, %l5 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%l3 = funlang.cons %c3, %l4 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%l2 = funlang.cons %c2, %l3 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
%l1 = funlang.cons %c1, %l2 : (i32, !funlang.list<i32>) -> !funlang.list<i32>
// Call fold
%result = func.call @fold(%add, %zero, %l1)
: (!funlang.closure<(i32, i32) -> i32>, i32, !funlang.list<i32>) -> i32
return %result : i32
}
func.func @add_impl(%acc: i32, %x: i32) -> i32 {
%result = arith.addi %acc, %x : i32
return %result : i32
}
검증:
let testFoldSum() =
let ctx = MLIRContext.Create()
let module = compileProgram ctx foldSumSource
let pm = PassManager.Create(ctx)
pm.AddPass("convert-funlang-to-scf")
pm.AddPass("convert-scf-to-cf")
pm.AddPass("convert-funlang-to-llvm")
pm.Run(module)
let engine = ExecutionEngine.Create(module)
let result = engine.Invoke("main", [||])
assert (result = 15)
printfn "fold (+) 0 [1, 2, 3, 4, 5] = 15 ✓"
Output:
fold (+) 0 [1, 2, 3, 4, 5] = 15 ✓
성공! fold 함수도 완전히 작동한다.
완전한 예제: Sum of Squares
이제 모든 것을 조합하여 실전 함수형 프로그램을 작성한다.
문제 정의
주어진 숫자 리스트의 제곱의 합을 계산한다:
sum_of_squares [1, 2, 3] = 1² + 2² + 3² = 1 + 4 + 9 = 14
FunLang 소스 코드
// Helper: square function
let square = fun x -> x * x
// Helper: add function
let add = fun acc x -> acc + x
// map: transform each element
let rec map f lst =
match lst with
| [] -> []
| head :: tail -> (f head) :: (map f tail)
// fold: reduce to single value
let rec fold f acc lst =
match lst with
| [] -> acc
| head :: tail -> fold f (f acc head) tail
// Composition: sum of squares
let sum_of_squares lst =
fold add 0 (map square lst)
// Test
let result = sum_of_squares [1, 2, 3]
// Expected: 14
함수 조합 분석:
[1, 2, 3]
↓ map square
[1, 4, 9]
↓ fold add 0
14
이것이 바로 함수형 프로그래밍의 핵심이다:
- 작은 함수들 (
square,add,map,fold) - 조합하여 복잡한 동작 (
sum_of_squares) - 선언적 스타일: “무엇을” 계산할지 명확
전체 컴파일 파이프라인 (9 단계)
이 프로그램을 end-to-end로 컴파일하는 과정을 모두 추적한다.
Stage 1: FunLang Source (사용자 작성)
let sum_of_squares lst =
fold add 0 (map square lst)
Stage 2: FunLang AST (Parser 출력)
Let("sum_of_squares",
Fun("lst",
App(App(App(Var "fold", Var "add"),
Int 0),
App(App(Var "map", Var "square"),
Var "lst"))),
...)
Stage 3: FunLang MLIR (Compiler.fs 출력)
func.func @sum_of_squares(%lst: !funlang.list<i32>) -> i32 {
// square closure (defined elsewhere)
%square = ... : !funlang.closure<(i32) -> i32>
// add closure (defined elsewhere)
%add = ... : !funlang.closure<(i32, i32) -> i32>
// map square lst
%squared_list = func.call @map(%square, %lst)
: (!funlang.closure<(i32) -> i32>, !funlang.list<i32>) -> !funlang.list<i32>
// fold add 0 squared_list
%zero = arith.constant 0 : i32
%result = func.call @fold(%add, %zero, %squared_list)
: (!funlang.closure<(i32, i32) -> i32>, i32, !funlang.list<i32>) -> i32
return %result : i32
}
Stage 4: FunLang → SCF Lowering (FunLangToSCFPass)
funlang.match operations이 scf.if로 변환된다:
// @map function (simplified)
func.func @map(...) -> ... {
%is_nil = ... : i1
%result = scf.if %is_nil -> ... {
// Nil case
scf.yield %empty : ...
} else {
// Cons case
%transformed = funlang.apply %f(%head) : ...
%mapped_tail = func.call @map(...) : ...
%new_list = funlang.cons %transformed, %mapped_tail : ...
scf.yield %new_list : ...
}
return %result : ...
}
Stage 5: FunLang Ops → LLVM (FunLangToLLVMPass)
funlang.cons, funlang.nil, funlang.apply 등이 LLVM operations로 변환:
// funlang.cons lowering
%cell_size = llvm.mlir.constant(16 : i64) : i64
%cell = llvm.call @GC_malloc(%cell_size) : (i64) -> !llvm.ptr
%head_ptr = llvm.getelementptr %cell[0, 0] : (!llvm.ptr) -> !llvm.ptr
llvm.store %head, %head_ptr : i32, !llvm.ptr
%tail_ptr = llvm.getelementptr %cell[0, 1] : (!llvm.ptr) -> !llvm.ptr
llvm.store %tail, %tail_ptr : !llvm.struct<(i32, ptr)>, !llvm.ptr
%cons_tag = arith.constant 1 : i32
%list = llvm.mlir.undef : !llvm.struct<(i32, ptr)>
%list1 = llvm.insertvalue %cons_tag, %list[0] : !llvm.struct<(i32, ptr)>
%list2 = llvm.insertvalue %cell, %list1[1] : !llvm.struct<(i32, ptr)>
Stage 6: SCF → CF Lowering (SCFToControlFlowPass)
scf.if → cf.cond_br, cf.br:
func.func @map(...) -> ... {
^entry:
%is_nil = ... : i1
cf.cond_br %is_nil, ^nil_case, ^cons_case
^nil_case:
%empty = ...
cf.br ^exit(%empty : ...)
^cons_case:
%transformed = ...
%mapped_tail = func.call @map(...) : ...
%new_list = ...
cf.br ^exit(%new_list : ...)
^exit(%result: ...):
return %result : ...
}
Stage 7: Func → LLVM (ConvertFuncToLLVMPass)
func.func → llvm.func, func.call → llvm.call:
llvm.func @map(%f: !llvm.ptr, %lst: !llvm.struct<(i32, ptr)>) -> !llvm.struct<(i32, ptr)> {
...
%result = llvm.call @map(%f, %tail) : (!llvm.ptr, !llvm.struct<(i32, ptr)>) -> !llvm.struct<(i32, ptr)>
...
}
Stage 8: LLVM Dialect → LLVM IR (Translate to LLVM IR)
MLIR LLVM dialect를 실제 LLVM IR로 변환:
define { i32, i8* } @map({ i8*, i8* }* %f, { i32, i8* } %lst) {
entry:
%0 = extractvalue { i32, i8* } %lst, 0
%1 = icmp eq i32 %0, 0
br i1 %1, label %nil_case, label %cons_case
nil_case:
%2 = insertvalue { i32, i8* } undef, i32 0, 0
%3 = insertvalue { i32, i8* } %2, i8* null, 1
br label %exit
cons_case:
%4 = extractvalue { i32, i8* } %lst, 1
%5 = bitcast i8* %4 to { i32, { i32, i8* } }*
%6 = getelementptr { i32, { i32, i8* } }, { i32, { i32, i8* } }* %5, i32 0, i32 0
%7 = load i32, i32* %6
%8 = getelementptr { i32, { i32, i8* } }, { i32, { i32, i8* } }* %5, i32 0, i32 1
%9 = load { i32, i8* }, { i32, i8* }* %8
; ... (apply closure, recursive call, cons)
br label %exit
exit:
%result = phi { i32, i8* } [ %3, %nil_case ], [ %new_list, %cons_case ]
ret { i32, i8* } %result
}
Stage 9: LLVM IR → Machine Code (JIT 또는 AOT)
LLVM backend가 target architecture의 machine code 생성:
; x86-64 assembly (simplified)
map:
push rbp
mov rbp, rsp
; Extract tag
mov eax, dword ptr [rsi]
test eax, eax
je .LBB0_1 ; Nil case
; Cons case
mov rdi, qword ptr [rsi + 8]
mov ecx, dword ptr [rdi] ; head
mov rsi, qword ptr [rdi + 8] ; tail
; ... (apply f, recursive call)
jmp .LBB0_2
.LBB0_1:
; Return empty list
xor eax, eax
xor edx, edx
.LBB0_2:
pop rbp
ret
실행 및 검증
let testSumOfSquares() =
let ctx = MLIRContext.Create()
let module = compileProgram ctx sumOfSquaresSource
// Apply all passes
let pm = PassManager.Create(ctx)
pm.AddPass("convert-funlang-to-scf")
pm.AddPass("convert-scf-to-cf")
pm.AddPass("convert-funlang-to-llvm")
pm.AddPass("convert-func-to-llvm")
pm.Run(module)
// JIT compile and execute
let engine = ExecutionEngine.Create(module)
let result = engine.Invoke("main", [||])
// Verify
assert (result = 14)
printfn "sum_of_squares [1, 2, 3] = 14 ✓"
// Detailed trace
printfn "Pipeline trace:"
printfn " [1, 2, 3]"
printfn " → map square"
printfn " [1, 4, 9]"
printfn " → fold add 0"
printfn " 14 ✓"
Output:
sum_of_squares [1, 2, 3] = 14 ✓
Pipeline trace:
[1, 2, 3]
→ map square
[1, 4, 9]
→ fold add 0
14 ✓
완전한 컴파일러가 작동한다!
9단계의 변환을 거쳐 FunLang 소스 코드가 실행 가능한 machine code가 되었다.
성능 고려사항
Stack Usage in Recursive List Functions
리스트 함수는 재귀적이므로 stack 사용량이 중요하다.
Stack depth by function:
| 함수 | Stack depth | 이유 |
|---|---|---|
map | O(n) | Non-tail recursive (cons 후에 return) |
filter | O(n) | Non-tail recursive (cons 후에 return) |
fold | O(1) | Tail recursive (최적화 가능) |
length | O(n) | Non-tail recursive |
append | O(n) | Non-tail recursive |
Non-tail recursion example (map):
let rec map f lst =
match lst with
| [] -> []
| head :: tail -> (f head) :: (map f tail)
// ^^^ Cons operation AFTER recursive call
// Stack frame must be preserved until map returns
Call stack for map square [1, 2, 3]:
map [1, 2, 3]
map [2, 3]
map [3]
map []
return []
cons 9 []
return [9]
cons 4 [9]
return [4, 9]
cons 1 [4, 9]
return [1, 4, 9]
각 frame은 다음을 저장해야 한다:
- Return address
headvalue (cons를 위해)tailpointer
Tail recursion example (fold):
let rec fold f acc lst =
match lst with
| [] -> acc
| head :: tail -> fold f (f acc head) tail
// ^^^ Recursive call is LAST operation
// Stack frame can be REUSED
Call stack for fold add 0 [1, 2, 3]:
fold 0 [1, 2, 3]
fold 1 [2, 3] // Same stack frame, acc updated
fold 3 [3] // Same stack frame, acc updated
fold 6 [] // Same stack frame, acc updated
return 6
Only ONE stack frame!
Tail Call Optimization (TCO)
LLVM은 tail call을 감지하여 최적화할 수 있다.
Before TCO:
define i32 @fold(...) {
; ...
%new_acc = add i32 %acc, %head
%result = call i32 @fold(..., %new_acc, %tail)
ret i32 %result
}
After TCO:
define i32 @fold(...) {
entry:
br label %loop
loop:
; ...
%new_acc = add i32 %acc, %head
; Update arguments and jump (no new stack frame)
br label %loop
}
TCO 활성화:
// PassManager.fs
let pm = PassManager.Create(ctx)
// Add standard LLVM optimization passes
pm.AddPass("inline") // Inline small functions
pm.AddPass("simplifycfg") // Simplify control flow
pm.AddPass("tailcallelim") // Tail call elimination
pm.AddPass("mem2reg") // Promote memory to registers
pm.Run(module)
결과:
fold는 loop로 변환되어 O(1) stack 사용- 큰 리스트 (100,000+ elements)도 stack overflow 없이 처리 가능
GC Pressure
리스트 연산은 많은 메모리를 할당한다.
Allocation counts:
// Create list [1, 2, 3]
// - 3 cons cells = 3 * 16 bytes = 48 bytes
// map square [1, 2, 3]
// - Input: 3 cells (48 bytes)
// - Output: 3 NEW cells (48 bytes)
// - Total alive: 96 bytes (both lists live)
// fold add 0 (map square [1, 2, 3])
// - Input: 3 cells (48 bytes) from map
// - Output: i32 (4 bytes) - no new list!
// - GC can collect input list after fold
Allocation pattern by function:
| 함수 | 할당량 | 설명 |
|---|---|---|
map | O(n) cons cells | 새 리스트 생성 |
filter | O(k) cons cells (k ≤ n) | 조건 만족하는 원소만 |
fold | O(1) | 단일 값만 반환 |
append | O(n) cons cells | 첫 번째 리스트 복사 |
GC optimization:
// BAD: 중간 리스트가 메모리에 남는다
let result1 = map f1 lst
let result2 = map f2 result1
let result3 = map f3 result2
// result1, result2, result3 모두 메모리에 존재
// GOOD: Fusion으로 중간 리스트 제거 (Phase 7에서 다룸)
let result = map (f3 << f2 << f1) lst
// 단일 pass, 중간 리스트 없음
Phase 7 Preview: Optimization Opportunities
Phase 7에서 다룰 최적화:
1. List Fusion
// Before: 두 번 순회
map f (map g lst)
// After fusion: 한 번만 순회
map (f << g) lst
2. Deforestation
// Before: 중간 리스트 생성
fold h z (map f lst)
// After deforestation: 직접 계산
fold (fun acc x -> h acc (f x)) z lst
3. Tail Recursion Modulo Cons
// map을 tail recursive로 변환
let map f lst =
let rec loop acc lst =
match lst with
| [] -> reverse acc
| head :: tail -> loop ((f head) :: acc) tail
loop [] lst
4. Parallel Map
큰 리스트에 대해 map을 병렬화:
// Sequential
%result = scf.for %i = 0 to %n step 1 iter_args(%acc = %init) -> ... {
%elem = load %lst[%i]
%transformed = apply %f(%elem)
...
}
// Parallel (MLIR scf.parallel)
scf.parallel (%i) = (0) to (%n) step (1) {
%elem = load %lst[%i]
%transformed = apply %f(%elem)
store %transformed, %result[%i]
}
이러한 최적화는 Phase 7에서 MLIR transformation passes로 구현할 것이다.
완전한 컴파일러 통합
이제 모든 것을 통합하여 완전한 FunLang 컴파일러를 구축한다.
FunLang AST Type Extensions
최종 AST 정의:
// Ast.fs
module Ast
type Expr =
// Phase 1-2: Basics
| Int of int
| Float of float
| Bool of bool
| Var of string
| Add of Expr * Expr
| Sub of Expr * Expr
| Mul of Expr * Expr
| Div of Expr * Expr
| Lt of Expr * Expr
| Gt of Expr * Expr
| Eq of Expr * Expr
// Phase 3: Control flow and functions
| Let of string * Expr * Expr
| If of Expr * Expr * Expr
| LetRec of string * Expr * Expr
// Phase 4: Closures and higher-order functions
| Fun of string * Expr // lambda
| App of Expr * Expr // application
// Phase 6: Lists and pattern matching
| Nil // []
| Cons of Expr * Expr // head :: tail
| List of Expr list // [1, 2, 3] (syntactic sugar)
| Match of Expr * (Pattern * Expr) list
and Pattern =
| PVar of string // x (variable binding)
| PNil // [] (empty list)
| PCons of Pattern * Pattern // head :: tail (cons pattern)
| PWild // _ (wildcard)
| PInt of int // 42 (literal match)
| PBool of bool // true/false
type Program = Expr
Compiler.fs: compileExpr Complete Implementation
// Compiler.fs
module Compiler
open MLIR
open Ast
let rec compileExpr (builder: OpBuilder) (expr: Expr) (symbolTable: Map<string, Value>) : Value =
match expr with
// Phase 1-2: Arithmetic
| Int n ->
let ty = builder.GetI32Type()
builder.CreateConstantInt(ty, n)
| Float f ->
let ty = builder.GetF64Type()
builder.CreateConstantFloat(ty, f)
| Bool b ->
let ty = builder.GetI1Type()
builder.CreateConstantBool(ty, b)
| Var name ->
symbolTable.[name]
| Add (left, right) ->
let lhs = compileExpr builder left symbolTable
let rhs = compileExpr builder right symbolTable
builder.CreateAddI(lhs, rhs)
| Mul (left, right) ->
let lhs = compileExpr builder left symbolTable
let rhs = compileExpr builder right symbolTable
builder.CreateMulI(lhs, rhs)
// ... (other arithmetic ops)
// Phase 3: Let and If
| Let (name, value, body) ->
let val_result = compileExpr builder value symbolTable
let newSymbolTable = symbolTable.Add(name, val_result)
compileExpr builder body newSymbolTable
| If (cond, thenExpr, elseExpr) ->
let condVal = compileExpr builder cond symbolTable
let resultTy = inferType thenExpr symbolTable
builder.CreateScfIf(condVal, resultTy, fun thenBuilder ->
let thenResult = compileExpr thenBuilder thenExpr symbolTable
thenBuilder.CreateScfYield(thenResult)
, fun elseBuilder ->
let elseResult = compileExpr elseBuilder elseExpr symbolTable
elseBuilder.CreateScfYield(elseResult)
)
| LetRec (name, func, body) ->
// Create named function for recursion
let funcName = sprintf "_%s" name
let funcOp = compileFunctionDefinition builder funcName func symbolTable
let funcRef = builder.CreateFuncRef(funcOp)
let newSymbolTable = symbolTable.Add(name, funcRef)
compileExpr builder body newSymbolTable
// Phase 4: Closures
| Fun (param, body) ->
// Analyze free variables
let freeVars = analyzeFreeVars (Fun(param, body)) symbolTable
// Create closure implementation function
let implName = sprintf "_lambda_%d" (freshId())
let implFunc = createClosureImpl builder implName param body freeVars symbolTable
// Capture environment
let captures = freeVars |> List.map (fun v -> symbolTable.[v])
// Create closure object
builder.CreateClosure(implFunc, captures)
| App (func, arg) ->
let funcVal = compileExpr builder func symbolTable
let argVal = compileExpr builder arg symbolTable
builder.CreateApply(funcVal, argVal)
// Phase 6: Lists
| Nil ->
let elemTy = inferElementType expr symbolTable
let listTy = builder.GetListType(elemTy)
builder.CreateNil(listTy)
| Cons (head, tail) ->
let headVal = compileExpr builder head symbolTable
let tailVal = compileExpr builder tail symbolTable
let headTy = headVal.GetType()
let listTy = builder.GetListType(headTy)
builder.CreateCons(headVal, tailVal, listTy)
| List exprs ->
// Desugar to nested Cons
let desugared = desugarList exprs
compileExpr builder desugared symbolTable
| Match (scrutinee, cases) ->
compileMatch builder scrutinee cases symbolTable
and compileMatch (builder: OpBuilder) (scrutinee: Expr) (cases: (Pattern * Expr) list) (symbolTable: Map<string, Value>) : Value =
let scrutineeVal = compileExpr builder scrutinee symbolTable
let resultTy = inferType (snd cases.[0]) symbolTable
// Create funlang.match operation
builder.CreateMatch(scrutineeVal, resultTy, fun matchBuilder ->
cases |> List.map (fun (pattern, body) ->
match pattern with
| PNil ->
// Nil case: no block arguments
matchBuilder.CreateNilCase(fun caseBuilder ->
let result = compileExpr caseBuilder body symbolTable
caseBuilder.CreateYield(result)
)
| PCons (PVar headName, PVar tailName) ->
// Cons case: bind head and tail
let headTy = inferPatternType pattern symbolTable
let listTy = builder.GetListType(headTy)
matchBuilder.CreateConsCase(headTy, listTy, fun caseBuilder headArg tailArg ->
let newSymbolTable =
symbolTable
.Add(headName, headArg)
.Add(tailName, tailArg)
let result = compileExpr caseBuilder body newSymbolTable
caseBuilder.CreateYield(result)
)
| _ -> failwith "Unsupported pattern"
)
)
and desugarList (exprs: Expr list) : Expr =
match exprs with
| [] -> Nil
| head :: tail -> Cons(head, desugarList tail)
Type Inference for List Types
리스트 타입 추론:
// TypeInfer.fs
let rec inferType (expr: Expr) (symbolTable: Map<string, Value>) : MLIRType =
match expr with
| Int _ -> builder.GetI32Type()
| Float _ -> builder.GetF64Type()
| Bool _ -> builder.GetI1Type()
| Var name ->
let value = symbolTable.[name]
value.GetType()
| Nil ->
// Need context to infer element type
// If context is unavailable, default to i32
builder.GetListType(builder.GetI32Type())
| Cons (head, tail) ->
let headTy = inferType head symbolTable
builder.GetListType(headTy)
| List (head :: _) ->
let headTy = inferType head symbolTable
builder.GetListType(headTy)
| Match (scrutinee, cases) ->
// Result type is the type of first case body
inferType (snd cases.[0]) symbolTable
| Fun (param, body) ->
// Function type: paramTy -> returnTy
// Need type annotation or inference
let paramTy = inferParamType param
let returnTy = inferType body symbolTable
builder.GetFunctionType(paramTy, returnTy)
| _ -> failwith "Type inference not implemented"
End-to-End Compilation Function
// Pipeline.fs
let compileProgram (source: string) : MLIRModule =
// 1. Parse
let ast = Parser.parse source
// 2. Desugar
let desugared = Desugar.desugar ast
// 3. Type check
TypeChecker.check desugared
// 4. Compile to MLIR
let ctx = MLIRContext.Create()
let module = MLIRModule.Create(ctx)
let builder = OpBuilder.Create(ctx)
let mainFunc = builder.CreateFunc("main", [], inferType desugared Map.empty, fun funcBuilder ->
let result = Compiler.compileExpr funcBuilder desugared Map.empty
funcBuilder.CreateReturn(result)
)
module.AddFunction(mainFunc)
// 5. Apply lowering passes
let pm = PassManager.Create(ctx)
pm.AddPass("convert-funlang-to-scf")
pm.AddPass("convert-scf-to-cf")
pm.AddPass("convert-funlang-to-llvm")
pm.AddPass("convert-func-to-llvm")
pm.Run(module)
module
// Execute
let execute (module: MLIRModule) : obj =
let engine = ExecutionEngine.Create(module)
engine.Invoke("main", [||])
// Complete pipeline
let run (source: string) : obj =
let module = compileProgram source
execute module
Example Usage
// Main.fs
[<EntryPoint>]
let main argv =
let source = """
let square = fun x -> x * x
let add = fun acc x -> acc + x
let rec map f lst =
match lst with
| [] -> []
| head :: tail -> (f head) :: (map f tail)
let rec fold f acc lst =
match lst with
| [] -> acc
| head :: tail -> fold f (f acc head) tail
let sum_of_squares lst =
fold add 0 (map square lst)
sum_of_squares [1, 2, 3]
"""
let result = Pipeline.run source
printfn "Result: %A" result // Result: 14
0
Output:
Result: 14
완전한 컴파일러가 작동한다!
Common Errors and Debugging
함수형 프로그램 작성 시 자주 발생하는 오류와 해결 방법.
1. Infinite Recursion
오류:
let rec bad_map f lst =
match lst with
| [] -> []
| head :: tail -> (f head) :: (bad_map f lst) // BUG: lst instead of tail
증상:
Stack overflow
Segmentation fault
Infinite loop
해결:
- 재귀 호출이 “smaller” input을 사용하는지 확인
- Base case가 반드시 도달 가능한지 확인
// Correct
| head :: tail -> (f head) :: (map f tail) // ✓ tail is smaller
2. Type Mismatch
오류:
let bad_fold f acc lst =
match lst with
| [] -> 0 // BUG: should return acc, not 0
| head :: tail -> fold f (f acc head) tail
증상:
Type error: Expected i32, found i64
Type mismatch in match branches
해결:
- 모든 match branch가 같은 타입 반환하는지 확인
- Accumulator 타입이 일관되는지 확인
// Correct
| [] -> acc // ✓ Same type as recursive case
3. Wrong Accumulator Type
오류:
// Want to reverse a list
let reverse lst = fold (fun acc x -> acc :: x) [] lst // BUG: wrong cons order
증상:
Type error: Cannot cons list to element
Expected: element :: list
Found: list :: element
해결:
- Cons operator는
element :: list순서 - Accumulator 타입 확인
// Correct
let reverse lst = fold (fun acc x -> x :: acc) [] lst // ✓ x :: acc
4. Stack Overflow
오류:
// Large list
let big_list = [1..100000]
let result = map square big_list // Stack overflow!
증상:
Segmentation fault (core dumped)
Stack overflow at recursion depth 100000
해결:
- Tail recursive 버전 사용
- TCO 활성화
- Iteration으로 변환 (Phase 7)
// Tail recursive version
let map_tailrec f lst =
let rec loop acc lst =
match lst with
| [] -> reverse acc
| head :: tail -> loop ((f head) :: acc) tail
loop [] lst
5. Debugging Strategies
전략 1: Trace execution
let rec map f lst =
printfn "map called with list of length %d" (length lst)
match lst with
| [] ->
printfn " -> returning []"
[]
| head :: tail ->
printfn " -> transforming %A" head
let transformed = f head
printfn " -> recursing on tail"
let mapped_tail = map f tail
printfn " -> cons %A onto result" transformed
transformed :: mapped_tail
전략 2: Unit tests
let test_map() =
assert (map square [] = [])
assert (map square [1] = [1])
assert (map square [1, 2] = [1, 4])
assert (map square [1, 2, 3] = [1, 4, 9])
printfn "map tests passed ✓"
전략 3: MLIR inspection
let module = compileProgram source
printfn "%s" (module.ToString()) // Print MLIR before lowering
let pm = PassManager.Create(ctx)
pm.EnableIRPrinting() // Print after each pass
pm.AddPass("convert-funlang-to-scf")
pm.Run(module)
전략 4: GDB debugging
# Compile with debug info
mlir-opt --debug-only=funlang-to-scf input.mlir
# Run under GDB
gdb --args mlir-opt ...
(gdb) break FunLangToSCFPass::runOnOperation
(gdb) run
리터럴 패턴 예제: fizzbuzz
지금까지 리스트에 대한 constructor pattern (Nil, Cons)을 다뤘다. 이제 리터럴 패턴을 사용하는 실전 예제를 살펴본다.
FizzBuzz 문제
FizzBuzz 규칙:
- 3의 배수: “Fizz”
- 5의 배수: “Buzz”
- 15의 배수: “FizzBuzz”
- 그 외: 숫자 그대로
FunLang 구현:
let fizzbuzz n =
match (n % 3, n % 5) with
| (0, 0) -> "FizzBuzz"
| (0, _) -> "Fizz"
| (_, 0) -> "Buzz"
| (_, _) -> string_of_int n
패턴 분석:
| Row | n % 3 | n % 5 | Result |
|---|---|---|---|
| 1 | 0 | 0 | “FizzBuzz” |
| 2 | 0 | _ | “Fizz” |
| 3 | _ | 0 | “Buzz” |
| 4 | _ | _ | n |
컴파일된 MLIR: 리터럴 패턴
func.func @fizzbuzz(%n: i32) -> !llvm.ptr<i8> {
// Compute remainders
%c3 = arith.constant 3 : i32
%c5 = arith.constant 5 : i32
%c0 = arith.constant 0 : i32
%mod3 = arith.remsi %n, %c3 : i32
%mod5 = arith.remsi %n, %c5 : i32
// Pattern matching: sequential arith.cmpi chain
%is_div3 = arith.cmpi eq, %mod3, %c0 : i32
%result = scf.if %is_div3 -> !llvm.ptr<i8> {
// First column is 0 (n % 3 == 0)
%is_div5 = arith.cmpi eq, %mod5, %c0 : i32
%inner = scf.if %is_div5 -> !llvm.ptr<i8> {
// Case (0, 0): FizzBuzz
scf.yield %fizzbuzz_str : !llvm.ptr<i8>
} else {
// Case (0, _): Fizz
scf.yield %fizz_str : !llvm.ptr<i8>
}
scf.yield %inner : !llvm.ptr<i8>
} else {
// First column is not 0 (n % 3 != 0)
%is_div5_2 = arith.cmpi eq, %mod5, %c0 : i32
%inner2 = scf.if %is_div5_2 -> !llvm.ptr<i8> {
// Case (_, 0): Buzz
scf.yield %buzz_str : !llvm.ptr<i8>
} else {
// Case (_, _): n as string
%str = func.call @int_to_string(%n) : (i32) -> !llvm.ptr<i8>
scf.yield %str : !llvm.ptr<i8>
}
scf.yield %inner2 : !llvm.ptr<i8>
}
return %result : !llvm.ptr<i8>
}
핵심 관찰:
arith.cmpi eq: 리터럴 0과의 비교- Nested
scf.if: Decision tree 구조 - Wildcard
_: else branch로 fallthrough (테스트 없음)
classify 함수: 숫자 분류
숫자를 여러 카테고리로 분류하는 예제:
let classify n =
match n with
| 0 -> "zero"
| 1 -> "one"
| 2 -> "two"
| _ -> if n < 0 then "negative" else "many"
컴파일된 MLIR:
func.func @classify(%n: i32) -> !llvm.ptr<i8> {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
// Sequential literal comparisons
%is_zero = arith.cmpi eq, %n, %c0 : i32
%result = scf.if %is_zero -> !llvm.ptr<i8> {
scf.yield %zero_str : !llvm.ptr<i8>
} else {
%is_one = arith.cmpi eq, %n, %c1 : i32
%r1 = scf.if %is_one -> !llvm.ptr<i8> {
scf.yield %one_str : !llvm.ptr<i8>
} else {
%is_two = arith.cmpi eq, %n, %c2 : i32
%r2 = scf.if %is_two -> !llvm.ptr<i8> {
scf.yield %two_str : !llvm.ptr<i8>
} else {
// Default case with guard
%is_neg = arith.cmpi slt, %n, %c0 : i32
%r3 = scf.if %is_neg -> !llvm.ptr<i8> {
scf.yield %negative_str : !llvm.ptr<i8>
} else {
scf.yield %many_str : !llvm.ptr<i8>
}
scf.yield %r3 : !llvm.ptr<i8>
}
scf.yield %r2 : !llvm.ptr<i8>
}
scf.yield %r1 : !llvm.ptr<i8>
}
return %result : !llvm.ptr<i8>
}
최적화: Dense Range Switch
리터럴이 0, 1, 2 연속일 때 scf.index_switch 최적화 가능:
// Optimized: range check + index_switch
%in_range = arith.cmpi ult, %n, %c3 : i32
%result = scf.if %in_range -> !llvm.ptr<i8> {
%idx = arith.index_cast %n : i32 to index
%r = scf.index_switch %idx : index -> !llvm.ptr<i8>
case 0 { scf.yield %zero_str : !llvm.ptr<i8> }
case 1 { scf.yield %one_str : !llvm.ptr<i8> }
case 2 { scf.yield %two_str : !llvm.ptr<i8> }
default { scf.yield %unreachable : !llvm.ptr<i8> }
scf.yield %r : !llvm.ptr<i8>
} else {
// n >= 3: check if negative
%is_neg = arith.cmpi slt, %n, %c0 : i32
%r2 = scf.if %is_neg -> !llvm.ptr<i8> {
scf.yield %negative_str : !llvm.ptr<i8>
} else {
scf.yield %many_str : !llvm.ptr<i8>
}
scf.yield %r2 : !llvm.ptr<i8>
}
최적화 효과:
- Before: O(n) sequential comparisons
- After: O(1) jump table for dense range
Wildcard Default Case 최적화
Wildcard _는 테스트를 생성하지 않는다:
match x with
| 0 -> handle_zero()
| 1 -> handle_one()
| _ -> handle_default() // No comparison needed!
%is_zero = arith.cmpi eq, %x, %c0 : i32
scf.if %is_zero {
// case 0
} else {
%is_one = arith.cmpi eq, %x, %c1 : i32
scf.if %is_one {
// case 1
} else {
// _ case: NO arith.cmpi, just fallthrough
// All other cases exhausted, this is the default
}
}
핵심 원칙:
- 마지막 else branch는 이전 모든 테스트가 실패한 경우
- 추가 비교 없이 바로 default 코드 실행
- 이것이 wildcard의 zero-cost abstraction
리터럴 + Constructor 혼합 예제
리스트와 숫자를 함께 매칭:
let take_first_n lst n =
match (lst, n) with
| (_, 0) -> []
| ([], _) -> []
| (head :: tail, n) -> head :: take_first_n tail (n - 1)
컴파일된 MLIR:
func.func @take_first_n(%lst: !funlang.list<i32>, %n: i32) -> !funlang.list<i32> {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
// Check n == 0 first (literal pattern)
%is_n_zero = arith.cmpi eq, %n, %c0 : i32
%result = scf.if %is_n_zero -> !funlang.list<i32> {
// Case (_, 0): return empty
%empty = funlang.nil : !funlang.list<i32>
scf.yield %empty : !funlang.list<i32>
} else {
// Check list constructor (constructor pattern)
%struct = builtin.unrealized_conversion_cast %lst : ... to !llvm.struct<(i32, ptr)>
%tag = llvm.extractvalue %struct[0] : !llvm.struct<(i32, ptr)>
%tag_index = arith.index_cast %tag : i32 to index
%inner = scf.index_switch %tag_index : index -> !funlang.list<i32>
case 0 {
// Case ([], _): return empty
%empty = funlang.nil : !funlang.list<i32>
scf.yield %empty : !funlang.list<i32>
}
case 1 {
// Case (head :: tail, n): recursive
%data = llvm.extractvalue %struct[1] : !llvm.struct<(i32, ptr)>
%head = llvm.load %data : !llvm.ptr -> i32
%tail_ptr = llvm.getelementptr %data[1] : (!llvm.ptr) -> !llvm.ptr
%tail = llvm.load %tail_ptr : !llvm.ptr -> !funlang.list<i32>
%n_minus_1 = arith.subi %n, %c1 : i32
%rest = func.call @take_first_n(%tail, %n_minus_1) : (...) -> !funlang.list<i32>
%new_list = funlang.cons %head, %rest : ...
scf.yield %new_list : !funlang.list<i32>
}
default { scf.yield %unreachable : !funlang.list<i32> }
scf.yield %inner : !funlang.list<i32>
}
return %result : !funlang.list<i32>
}
혼합 패턴 lowering 전략:
- Literal column first:
arith.cmpi+scf.if - Constructor column inside:
scf.index_switch - Wildcard: test 없이 fallthrough
검증 및 테스트
let testFizzBuzz() =
// Test fizzbuzz
assert (fizzbuzz 3 = "Fizz")
assert (fizzbuzz 5 = "Buzz")
assert (fizzbuzz 15 = "FizzBuzz")
assert (fizzbuzz 7 = "7")
printfn "fizzbuzz tests passed"
// Test classify
assert (classify 0 = "zero")
assert (classify 1 = "one")
assert (classify 2 = "two")
assert (classify 42 = "many")
assert (classify (-5) = "negative")
printfn "classify tests passed"
// Test take_first_n
assert (take_first_n [1, 2, 3, 4, 5] 3 = [1, 2, 3])
assert (take_first_n [1, 2, 3] 0 = [])
assert (take_first_n [] 5 = [])
printfn "take_first_n tests passed"
Output:
fizzbuzz tests passed
classify tests passed
take_first_n tests passed
Key Takeaways
- 리터럴 패턴:
arith.cmpi eq+scf.ifchain - Constructor 패턴:
scf.index_switch로 O(1) dispatch - Wildcard: else branch로 fallthrough (테스트 없음)
- Dense range:
scf.index_switch로 최적화 가능 - 혼합 패턴: 각 column의 패턴 타입에 맞는 dispatch 사용
튜플 예제: zip과 unzip (Tuple Examples: zip and unzip)
Chapter 18에서 !funlang.tuple<T1, T2, ...> 타입과 funlang.make_tuple 연산을, Chapter 19에서 튜플 패턴 매칭을 구현했다. 이제 튜플을 활용하는 실제 프로그램을 작성하고 컴파일해보자.
zip 함수: 두 리스트를 쌍의 리스트로
zip의 개념:
두 리스트를 받아 각 위치의 원소들을 튜플로 묶은 리스트를 반환한다.
// zip의 타입
zip : [a] -> [b] -> [(a, b)]
// zip의 동작
zip [1, 2, 3] ["a", "b", "c"] = [(1, "a"), (2, "b"), (3, "c")]
// 길이가 다르면 짧은 쪽에 맞춤
zip [1, 2] ["a", "b", "c"] = [(1, "a"), (2, "b")]
FunLang 구현:
let rec zip xs ys =
match xs with
| [] -> []
| x :: xs' ->
match ys with
| [] -> []
| y :: ys' -> make_tuple(x, y) :: zip xs' ys'
동작 원리:
- 첫 번째 리스트가 비어있으면 빈 리스트 반환
- 두 번째 리스트가 비어있으면 빈 리스트 반환
- 둘 다 원소가 있으면:
- 각 head로 튜플 생성:
make_tuple(x, y) - tail들로 재귀 호출:
zip xs' ys' - 결과를 cons:
pair :: rest
- 각 head로 튜플 생성:
zip 함수 컴파일: FunLang MLIR
// zip : !funlang.list<i32> -> !funlang.list<f64> -> !funlang.list<!funlang.tuple<i32, f64>>
func.func @zip(%xs: !funlang.list<i32>, %ys: !funlang.list<f64>)
-> !funlang.list<!funlang.tuple<i32, f64>> {
// 첫 번째 리스트 패턴 매칭
%result = funlang.match %xs : !funlang.list<i32>
-> !funlang.list<!funlang.tuple<i32, f64>> {
^nil:
// xs가 비어있으면 빈 리스트 반환
%empty = funlang.nil : !funlang.list<!funlang.tuple<i32, f64>>
funlang.yield %empty : !funlang.list<!funlang.tuple<i32, f64>>
^cons(%x: i32, %xs_tail: !funlang.list<i32>):
// xs = x :: xs_tail, 이제 ys 패턴 매칭
%inner = funlang.match %ys : !funlang.list<f64>
-> !funlang.list<!funlang.tuple<i32, f64>> {
^nil:
// ys가 비어있으면 빈 리스트 반환
%empty2 = funlang.nil : !funlang.list<!funlang.tuple<i32, f64>>
funlang.yield %empty2 : !funlang.list<!funlang.tuple<i32, f64>>
^cons(%y: f64, %ys_tail: !funlang.list<f64>):
// ys = y :: ys_tail
// 튜플 생성: (x, y)
%pair = funlang.make_tuple(%x, %y) : !funlang.tuple<i32, f64>
// 재귀 호출: zip xs_tail ys_tail
%rest = func.call @zip(%xs_tail, %ys_tail)
: (!funlang.list<i32>, !funlang.list<f64>)
-> !funlang.list<!funlang.tuple<i32, f64>>
// cons: pair :: rest
%cons_result = funlang.cons %pair, %rest
: !funlang.list<!funlang.tuple<i32, f64>>
funlang.yield %cons_result : !funlang.list<!funlang.tuple<i32, f64>>
}
funlang.yield %inner : !funlang.list<!funlang.tuple<i32, f64>>
}
return %result : !funlang.list<!funlang.tuple<i32, f64>>
}
핵심 포인트:
- 중첩 패턴 매칭: 먼저 xs를 매칭하고, Cons case 안에서 ys를 매칭
- make_tuple 사용:
funlang.make_tuple(%x, %y)로 쌍 생성 - 결과 타입:
!funlang.list<!funlang.tuple<i32, f64>>- 튜플의 리스트
fst와 snd 함수: 튜플 원소 추출
fst와 snd의 정의:
// 첫 번째 원소 추출
let fst pair = match pair with (x, _) -> x
// 두 번째 원소 추출
let snd pair = match pair with (_, y) -> y
MLIR 구현:
// fst : !funlang.tuple<i32, f64> -> i32
func.func @fst(%pair: !funlang.tuple<i32, f64>) -> i32 {
%result = funlang.match %pair : !funlang.tuple<i32, f64> -> i32 {
^case(%x: i32, %y: f64):
funlang.yield %x : i32
}
return %result : i32
}
// snd : !funlang.tuple<i32, f64> -> f64
func.func @snd(%pair: !funlang.tuple<i32, f64>) -> f64 {
%result = funlang.match %pair : !funlang.tuple<i32, f64> -> f64 {
^case(%x: i32, %y: f64):
funlang.yield %y : f64
}
return %result : f64
}
Lowering 결과:
// fst after lowering - 분기 없이 직접 추출
func.func @fst(%pair: !llvm.struct<(i32, f64)>) -> i32 {
%x = llvm.extractvalue %pair[0] : !llvm.struct<(i32, f64)>
return %x : i32
}
// snd after lowering
func.func @snd(%pair: !llvm.struct<(i32, f64)>) -> f64 {
%y = llvm.extractvalue %pair[1] : !llvm.struct<(i32, f64)>
return %y : f64
}
핵심:
- 튜플 패턴 매칭은 scf.index_switch 없이 바로 extractvalue로 lowering
- 와일드카드
_는 해당 위치의 extractvalue를 생략 (dead code elimination)
unzip 함수: 쌍의 리스트를 두 리스트로
unzip의 개념:
zip의 역연산. 튜플 리스트를 두 개의 리스트로 분리한다.
// unzip의 타입
unzip : [(a, b)] -> ([a], [b])
// unzip의 동작
unzip [(1, "a"), (2, "b")] = ([1, 2], ["a", "b"])
FunLang 구현:
let rec unzip pairs =
match pairs with
| [] -> ([], [])
| p :: ps ->
let (x, y) = p in
let (xs, ys) = unzip ps in
(x :: xs, y :: ys)
MLIR 구현:
// unzip : !funlang.list<!funlang.tuple<i32, f64>>
// -> !funlang.tuple<!funlang.list<i32>, !funlang.list<f64>>
func.func @unzip(%pairs: !funlang.list<!funlang.tuple<i32, f64>>)
-> !funlang.tuple<!funlang.list<i32>, !funlang.list<f64>> {
%result = funlang.match %pairs
: !funlang.list<!funlang.tuple<i32, f64>>
-> !funlang.tuple<!funlang.list<i32>, !funlang.list<f64>> {
^nil:
// 빈 리스트 → ([], [])
%empty_ints = funlang.nil : !funlang.list<i32>
%empty_floats = funlang.nil : !funlang.list<f64>
%empty_pair = funlang.make_tuple(%empty_ints, %empty_floats)
: !funlang.tuple<!funlang.list<i32>, !funlang.list<f64>>
funlang.yield %empty_pair
: !funlang.tuple<!funlang.list<i32>, !funlang.list<f64>>
^cons(%p: !funlang.tuple<i32, f64>, %ps: !funlang.list<!funlang.tuple<i32, f64>>):
// p = (x, y), 튜플 분해
%xy = funlang.match %p : !funlang.tuple<i32, f64>
-> !funlang.tuple<i32, f64> {
^case(%x: i32, %y: f64):
funlang.yield %p : !funlang.tuple<i32, f64>
}
// 실제로는 직접 extractvalue 사용
%x = ... extractvalue [0] ...
%y = ... extractvalue [1] ...
// 재귀: unzip ps
%rest = func.call @unzip(%ps) : ...
%xs = ... fst rest ...
%ys = ... snd rest ...
// 결과: (x :: xs, y :: ys)
%new_xs = funlang.cons %x, %xs : !funlang.list<i32>
%new_ys = funlang.cons %y, %ys : !funlang.list<f64>
%result_pair = funlang.make_tuple(%new_xs, %new_ys)
: !funlang.tuple<!funlang.list<i32>, !funlang.list<f64>>
funlang.yield %result_pair
: !funlang.tuple<!funlang.list<i32>, !funlang.list<f64>>
}
return %result : !funlang.tuple<!funlang.list<i32>, !funlang.list<f64>>
}
Point 조작 예제: 2D 좌표
Point 타입:
// Point = (int, int) 튜플
type point = int * int
let origin = (0, 0)
let p1 = (3, 4)
기본 연산들:
// 오른쪽으로 이동
let move_right pt =
match pt with (x, y) -> (x + 1, y)
// 위로 이동
let move_up pt =
match pt with (x, y) -> (x, y + 1)
// 두 점 사이의 거리 (맨해튼)
let manhattan_distance p1 p2 =
match (p1, p2) with ((x1, y1), (x2, y2)) ->
abs(x2 - x1) + abs(y2 - y1)
// 점 리스트의 중심점
let centroid points =
let sum_pts = fold (fun (sx, sy) (x, y) -> (sx + x, sy + y)) (0, 0) points
let n = length points
match sum_pts with (sx, sy) -> (sx / n, sy / n)
MLIR 구현 - move_right:
// move_right : !funlang.tuple<i32, i32> -> !funlang.tuple<i32, i32>
func.func @move_right(%pt: !funlang.tuple<i32, i32>) -> !funlang.tuple<i32, i32> {
%result = funlang.match %pt : !funlang.tuple<i32, i32> -> !funlang.tuple<i32, i32> {
^case(%x: i32, %y: i32):
%c1 = arith.constant 1 : i32
%new_x = arith.addi %x, %c1 : i32
%new_pt = funlang.make_tuple(%new_x, %y) : !funlang.tuple<i32, i32>
funlang.yield %new_pt : !funlang.tuple<i32, i32>
}
return %result : !funlang.tuple<i32, i32>
}
Lowering 결과:
func.func @move_right(%pt: !llvm.struct<(i32, i32)>) -> !llvm.struct<(i32, i32)> {
%x = llvm.extractvalue %pt[0] : !llvm.struct<(i32, i32)>
%y = llvm.extractvalue %pt[1] : !llvm.struct<(i32, i32)>
%c1 = arith.constant 1 : i32
%new_x = arith.addi %x, %c1 : i32
%0 = llvm.mlir.undef : !llvm.struct<(i32, i32)>
%1 = llvm.insertvalue %new_x, %0[0] : !llvm.struct<(i32, i32)>
%result = llvm.insertvalue %y, %1[1] : !llvm.struct<(i32, i32)>
return %result : !llvm.struct<(i32, i32)>
}
중첩 튜플 - manhattan_distance:
// manhattan_distance : !funlang.tuple<i32, i32> -> !funlang.tuple<i32, i32> -> i32
func.func @manhattan_distance(%p1: !funlang.tuple<i32, i32>, %p2: !funlang.tuple<i32, i32>) -> i32 {
// 두 점을 하나의 튜플로 묶어서 패턴 매칭
%combined = funlang.make_tuple(%p1, %p2)
: !funlang.tuple<!funlang.tuple<i32, i32>, !funlang.tuple<i32, i32>>
// 중첩 튜플 분해
%result = funlang.match %combined
: !funlang.tuple<!funlang.tuple<i32, i32>, !funlang.tuple<i32, i32>> -> i32 {
^case(%pt1: !funlang.tuple<i32, i32>, %pt2: !funlang.tuple<i32, i32>):
// 첫 번째 점 분해
%xy1 = funlang.match %pt1 : !funlang.tuple<i32, i32> -> !funlang.tuple<i32, i32> {
^case(%x1: i32, %y1: i32):
funlang.yield %pt1 : !funlang.tuple<i32, i32>
}
// 실제로는 extractvalue 연쇄
// %x1 = extractvalue %pt1[0]
// %y1 = extractvalue %pt1[1]
// %x2 = extractvalue %pt2[0]
// %y2 = extractvalue %pt2[1]
// 거리 계산
// %dx = abs(x2 - x1)
// %dy = abs(y2 - y1)
// %result = dx + dy
...
funlang.yield %distance : i32
}
return %result : i32
}
튜플 + 고차 함수 결합
튜플을 사용한 map_with_index:
// 리스트의 각 원소에 인덱스와 함께 함수 적용
let map_with_index f lst =
let indexed = zip [0..length lst - 1] lst
map (fun (i, x) -> f i x) indexed
enumerate 함수:
// 리스트에 인덱스를 붙여서 튜플 리스트로
let rec enumerate_from n lst =
match lst with
| [] -> []
| x :: xs -> (n, x) :: enumerate_from (n + 1) xs
let enumerate = enumerate_from 0
// 사용 예
enumerate ["a", "b", "c"] // [(0, "a"), (1, "b"), (2, "c")]
partition 함수 (튜플 반환):
// 리스트를 조건에 따라 두 리스트로 분리
let rec partition pred lst =
match lst with
| [] -> ([], [])
| x :: xs ->
let (yes, no) = partition pred xs
if pred x then
(x :: yes, no)
else
(yes, x :: no)
// 사용 예
partition (fun x -> x > 0) [-1, 2, -3, 4] // ([2, 4], [-1, -3])
MLIR 구현 - partition:
func.func @partition(%pred: !funlang.closure<(i32) -> i1>,
%lst: !funlang.list<i32>)
-> !funlang.tuple<!funlang.list<i32>, !funlang.list<i32>> {
%result = funlang.match %lst : !funlang.list<i32>
-> !funlang.tuple<!funlang.list<i32>, !funlang.list<i32>> {
^nil:
%empty1 = funlang.nil : !funlang.list<i32>
%empty2 = funlang.nil : !funlang.list<i32>
%pair = funlang.make_tuple(%empty1, %empty2)
: !funlang.tuple<!funlang.list<i32>, !funlang.list<i32>>
funlang.yield %pair : !funlang.tuple<!funlang.list<i32>, !funlang.list<i32>>
^cons(%x: i32, %xs: !funlang.list<i32>):
// 재귀: partition pred xs
%rest = func.call @partition(%pred, %xs) : ...
%yes = ... fst rest ...
%no = ... snd rest ...
// pred x 평가
%test = funlang.apply %pred(%x) : (i32) -> i1
// 조건부 cons
%new_pair = scf.if %test -> !funlang.tuple<!funlang.list<i32>, !funlang.list<i32>> {
%new_yes = funlang.cons %x, %yes : !funlang.list<i32>
%pair = funlang.make_tuple(%new_yes, %no)
: !funlang.tuple<!funlang.list<i32>, !funlang.list<i32>>
scf.yield %pair : !funlang.tuple<!funlang.list<i32>, !funlang.list<i32>>
} else {
%new_no = funlang.cons %x, %no : !funlang.list<i32>
%pair = funlang.make_tuple(%yes, %new_no)
: !funlang.tuple<!funlang.list<i32>, !funlang.list<i32>>
scf.yield %pair : !funlang.tuple<!funlang.list<i32>, !funlang.list<i32>>
}
funlang.yield %new_pair : !funlang.tuple<!funlang.list<i32>, !funlang.list<i32>>
}
return %result : !funlang.tuple<!funlang.list<i32>, !funlang.list<i32>>
}
Summary: 튜플 예제
구현한 함수들:
| 함수 | 타입 | 설명 |
|---|---|---|
| zip | [a] -> [b] -> [(a,b)] | 두 리스트를 쌍으로 묶기 |
| fst | (a,b) -> a | 첫 번째 원소 추출 |
| snd | (a,b) -> b | 두 번째 원소 추출 |
| unzip | [(a,b)] -> ([a], [b]) | 쌍 리스트를 두 리스트로 분리 |
| move_right | point -> point | 좌표 변환 |
| manhattan_distance | point -> point -> int | 두 점 사이 거리 |
| enumerate | [a] -> [(int, a)] | 인덱스 붙이기 |
| partition | (a -> bool) -> [a] -> ([a], [a]) | 조건에 따라 분리 |
핵심 패턴:
- make_tuple로 튜플 생성:
funlang.make_tuple(%a, %b) - 패턴 매칭으로 분해:
^case(%x, %y):또는 extractvalue 직접 사용 - 중첩 가능: 튜플 안에 리스트, 리스트 안에 튜플
- 다중 반환값: 함수에서 튜플 반환하여 여러 값 리턴
- 고차 함수와 결합: map, fold 등과 함께 사용
Lowering 특성:
- 튜플 패턴: 분기 없이 extractvalue 체인
- 리스트 패턴: scf.index_switch 사용
- 중첩: 외부에서 내부로 순차 처리
Phase 6 Complete Summary
축하한다! Phase 6를 완료했다.
Chapter 17-20 복습
Chapter 17: Pattern Matching Theory
- Decision tree 알고리즘으로 패턴 매칭을 효율적으로 컴파일
- Exhaustiveness checking으로 빠진 case 감지
- Unreachable case detection으로 중복 제거
Chapter 18: List Operations
!funlang.list<T>parameterized type- Tagged union representation:
!llvm.struct<(i32, ptr)> funlang.nil과funlang.consoperations- TypeConverter와 lowering patterns
Chapter 19: Match Compilation
funlang.matchoperation 정의- Multi-stage lowering: FunLang → SCF → CF → LLVM
- IRMapping으로 block argument remapping
- Region-based IR structure
Chapter 20: Functional Programs (this chapter)
- FunLang AST extensions for lists
- Compiler integration (compileExpr, type inference)
- Core list functions: map, filter, fold, length, append
- Complete example: sum_of_squares
- End-to-end compilation pipeline (9 stages)
- Performance analysis and optimization preview
What You Can Now Compile
Phase 6 종료 시점에 컴파일 가능한 프로그램:
// 1. List construction
let list = [1, 2, 3, 4, 5]
// 2. Pattern matching
let rec sum lst =
match lst with
| [] -> 0
| head :: tail -> head + sum tail
// 3. Higher-order functions
let map f lst = ...
let filter pred lst = ...
let fold combiner acc lst = ...
// 4. Function composition
let sum_of_squares lst =
fold (+) 0 (map (fun x -> x * x) lst)
// 5. Complex functional programs
let process data =
data
|> filter is_valid
|> map transform
|> fold aggregate initial
// 6. Nested data structures
let nested = [[1, 2], [3, 4], [5, 6]]
let flattened = fold append [] nested
이것은 실제 함수형 언어와 동등한 표현력이다!
Technical Achievements
Phase 6에서 구현한 기술:
- Parameterized types:
!funlang.list<T>with element type parameter - Tagged unions: Efficient runtime representation of ADTs
- Pattern matching: Decision tree compilation for performance
- Multi-stage lowering: Progressive refinement through dialects
- Type conversion: Consistent type mapping across lowering stages
- Region-based IR: Structured control flow with scoped bindings
- Tail recursion: Optimization opportunity for fold
- GC integration: Automatic memory management for lists
- Complete pipeline: Source → AST → MLIR → LLVM IR → Machine code
Phase 7 Preview: Optimization
Phase 7에서 다룰 내용:
1. List Fusion
중간 리스트 제거:
// Before
map f (map g lst) // Two passes, intermediate list
// After fusion
map (f << g) lst // One pass, no intermediate
2. Deforestation
Tree 구조 중간 생성 제거:
// Before
fold h z (map f lst) // Creates intermediate list
// After deforestation
fold (fun acc x -> h acc (f x)) z lst // Direct
3. Inlining
Small 함수 inline:
// Before
%result = func.call @square(%x) : (i32) -> i32
// After inlining
%result = arith.muli %x, %x : i32
4. Loop Unrolling
재귀를 explicit loop로 변환:
// Before (recursive)
func.func @map(...) {
%result = funlang.match %lst : ... {
^nil: ...
^cons(...): %mapped = func.call @map(...) ...
}
}
// After (loop)
func.func @map(...) {
scf.for %i = 0 to %n step 1 iter_args(%acc = %init) -> ... {
%elem = load %lst[%i]
%transformed = apply %f(%elem)
...
}
}
5. Parallel Map
데이터 병렬성 활용:
scf.parallel (%i) = (0) to (%n) step (1) {
%elem = load %lst[%i]
%result = apply %f(%elem)
store %result, %output[%i]
}
6. Constant Folding
컴파일 시간에 계산:
// Before
let result = sum [1, 2, 3, 4, 5]
// After constant folding
let result = 15 // Computed at compile time
이러한 최적화는 MLIR의 transformation passes로 구현되며, Phase 7에서 자세히 다룬다.
Congratulations!
Phase 6 완료를 축하한다!
이제 여러분은:
- ✓ 완전한 함수형 프로그래밍 언어를 컴파일할 수 있다
- ✓ 리스트, 패턴 매칭, 고차 함수를 지원한다
- ✓ Multi-stage lowering pipeline을 이해한다
- ✓ End-to-end 컴파일 (source to machine code)을 할 수 있다
- ✓ 성능 특성과 최적화 기회를 안다
다음 단계: Phase 7 (Optimization)에서 더 빠르고 효율적인 코드 생성을 배운다.
Happy functional programming! 🎉