sammine-lang
Loading...
Searching...
No Matches
BiTypeChecker.h
Go to the documentation of this file.
1#pragma once
2
3#include "ast/Ast.h"
4#include "ast/ASTProperties.h"
5#include "ast/AstBase.h"
6#include "typecheck/Monomorphizer.h"
7#include "typecheck/Types.h"
9#include <set>
10#include <unordered_map>
11
16namespace sammine_lang {
17
18namespace AST {
19
// TypingContext is LexicalContext instantiated with <Type, AST::FuncDefAST *>.
// It declares no members of its own; everything comes from the base template.
20class TypingContext : public LexicalContext<Type, AST::FuncDefAST *> {};
21class BiTypeCheckerVisitor : public ScopedASTVisitor,
22 public TypeCheckerVisitor {
26
27 // We're gonna provide look up in different
28
29public:
31 // INFO: x, y, z
33
34 // INFO: i64, f64 bla bla bla
36 TypeMapOrdering type_map_ordering;
37
38 // Generic function support
39 std::unordered_map<std::string, FuncDefAST *> generic_func_defs;
40 std::vector<std::unique_ptr<FuncDefAST>> monomorphized_defs;
41 std::set<std::string> instantiated_functions;
42
43 // Generic enum support
44 std::unordered_map<std::string, EnumDefAST *> generic_enum_defs;
45 std::vector<std::unique_ptr<EnumDefAST>> monomorphized_enum_defs;
46 std::set<std::string> instantiated_enums;
47
48 // Unification and substitution helpers
49 bool unify(const Type &pattern, const Type &concrete,
50 std::unordered_map<std::string, Type> &bindings);
51 Type substitute(const Type &type,
52 const std::unordered_map<std::string, Type> &bindings) const;
53 bool contains_type_param(const Type &type, const std::string &param_name);
54
55 virtual void enter_new_scope() override {
56 id_to_type.push_context();
57 typename_to_type.push_context();
58
59 typename_to_type.registerNameT("i32", Type::I32_t());
60 typename_to_type.registerNameT("i64", Type::I64_t());
61 typename_to_type.registerNameT("u32", Type::U32_t());
62 typename_to_type.registerNameT("u64", Type::U64_t());
63 typename_to_type.registerNameT("f64", Type::F64_t());
64 typename_to_type.registerNameT("bool", Type::Bool());
65 typename_to_type.registerNameT("char", Type::Char());
66 typename_to_type.registerNameT("unit", Type::Unit());
67 }
68 virtual void exit_new_scope() override {
69 id_to_type.pop();
70 typename_to_type.pop();
71 }
73 this->enter_new_scope();
74 }
75
76 std::optional<Type> get_type_from_id(const std::string &str) const {
77
78 const auto &id_name_top = id_to_type.top();
79 if (id_name_top.queryName(str) == nameNotFound) {
80 sammine_util::abort(
81 fmt::format("Name '{}' not found, this should not happen", str));
82 }
83 return id_name_top.get_from_name(str);
84 }
85
86 std::optional<Type> get_type_from_id_parent(const std::string &str) const {
87
88 const auto &id_name_top = *id_to_type.top().parent_scope;
89 if (id_name_top.queryName(str) == nameNotFound) {
90 sammine_util::abort(
91 fmt::format("Name '{}' not found, this should not happen", str));
92 }
93 return id_name_top.get_from_name(str);
94 }
95
96 std::optional<Type> get_typename_type(const std::string &str) const {
97 const auto &typename_top = typename_to_type.top();
98 if (typename_top.recursiveQueryName(str) == nameNotFound) {
99 return std::nullopt;
100 }
101 return typename_top.recursive_get_from_name(str);
102 }
103
105 std::optional<Type> try_get_callee_type(const std::string &str) const {
106 const auto &top = id_to_type.top();
107 if (top.recursiveQueryName(str) == nameFound)
108 return top.recursive_get_from_name(str);
109 return std::nullopt;
110 }
111
112 // visit overrides
113 virtual void visit(ProgramAST *ast) override;
114 virtual void visit(VarDefAST *ast) override;
115 virtual void visit(ExternAST *ast) override;
116 virtual void visit(FuncDefAST *ast) override;
117 virtual void visit(StructDefAST *ast) override;
118 virtual void visit(EnumDefAST *ast) override;
119 virtual void visit(TypeAliasDefAST *ast) override;
120 virtual void visit(PrototypeAST *ast) override;
121 virtual void visit(CallExprAST *ast) override;
122 virtual void visit(ReturnExprAST *ast) override;
123 virtual void visit(BinaryExprAST *ast) override;
124 virtual void visit(NumberExprAST *ast) override;
125 virtual void visit(StringExprAST *ast) override;
126 virtual void visit(BoolExprAST *ast) override;
127 virtual void visit(CharExprAST *ast) override;
128 virtual void visit(UnitExprAST *ast) override;
129 virtual void visit(VariableExprAST *ast) override;
130 virtual void visit(BlockAST *ast) override;
131 virtual void visit(IfExprAST *ast) override;
132 virtual void visit(TypedVarAST *ast) override;
133 virtual void visit(DerefExprAST *ast) override;
134 virtual void visit(AddrOfExprAST *ast) override;
135 virtual void visit(AllocExprAST *ast) override;
136 virtual void visit(FreeExprAST *ast) override;
137 virtual void visit(ArrayLiteralExprAST *ast) override;
138 virtual void visit(IndexExprAST *ast) override;
139 virtual void visit(LenExprAST *ast) override;
140 virtual void visit(UnaryNegExprAST *ast) override;
141 virtual void visit(StructLiteralExprAST *ast) override;
142 virtual void visit(FieldAccessExprAST *ast) override;
143 virtual void visit(CaseExprAST *ast) override;
144 virtual void visit(WhileExprAST *ast) override;
145 virtual void visit(TupleLiteralExprAST *ast) override;
146 virtual void visit(TypeClassDeclAST *ast) override;
147 virtual void visit(TypeClassInstanceAST *ast) override;
148
149 // Type class data structures
151 std::string name;
152 std::string type_param;
153 std::vector<PrototypeAST *> methods;
154 };
155
157 std::string class_name;
158 Type concrete_type;
159 std::unordered_map<std::string, sammine_util::MonomorphizedName>
160 method_mangled_names;
161 };
162
163 std::unordered_map<std::string, TypeClassInfo> type_class_defs;
164 std::unordered_map<std::string, TypeClassInstanceInfo> type_class_instances;
165 std::unordered_map<std::string, std::string> method_to_class;
166
167 // Enum variant constructors: variant_name → (enum_type, variant_index)
168 std::unordered_map<std::string, std::pair<Type, size_t>> variant_constructors;
169
170 // Pre-register a function signature so later definitions can reference it
171 void pre_register_function(PrototypeAST *ast);
172
173 // Two-pass typeclass registration (called before full type checking)
174 void register_typeclass_decl(TypeClassDeclAST *ast);
175 void register_typeclass_instance(TypeClassInstanceAST *ast);
176 void register_builtin_op_instances();
177
178 // Call expression synthesis helpers
179 std::optional<Type> synthesize_typeclass_call(CallExprAST *ast);
180 Type synthesize_generic_call(CallExprAST *ast);
181 Type synthesize_normal_call(CallExprAST *ast);
182
183 // Binary expression synthesis helper
184 Type synthesize_binary_operator(BinaryExprAST *ast, const Type &lhs_type,
185 const Type &rhs_type);
186
187 // VarDef array checking helper
188 bool check_array_literal_against_annotation(VarDefAST *ast,
189 ArrayLiteralExprAST *arr_lit,
190 const ArrayType &arr_type);
191
192 // pre order
193
194 virtual void preorder_walk(ProgramAST *ast) override;
195 virtual void preorder_walk(VarDefAST *ast) override;
196 virtual void preorder_walk(ExternAST *ast) override;
197 virtual void preorder_walk(FuncDefAST *ast) override;
198 virtual void preorder_walk(StructDefAST *ast) override;
199 virtual void preorder_walk(EnumDefAST *ast) override;
200 virtual void preorder_walk(TypeAliasDefAST *ast) override;
201 virtual void preorder_walk(PrototypeAST *ast) override;
202 virtual void preorder_walk(CallExprAST *ast) override;
203 virtual void preorder_walk(ReturnExprAST *ast) override;
204 virtual void preorder_walk(BinaryExprAST *ast) override;
205 virtual void preorder_walk(NumberExprAST *ast) override;
206 virtual void preorder_walk(StringExprAST *ast) override;
207 virtual void preorder_walk(BoolExprAST *ast) override;
208 virtual void preorder_walk(CharExprAST *ast) override;
209 virtual void preorder_walk(UnitExprAST *ast) override;
210 virtual void preorder_walk(VariableExprAST *ast) override;
211 virtual void preorder_walk(BlockAST *ast) override;
212 virtual void preorder_walk(IfExprAST *ast) override;
213 virtual void preorder_walk(TypedVarAST *ast) override;
214 virtual void preorder_walk(DerefExprAST *ast) override;
215 virtual void preorder_walk(AddrOfExprAST *ast) override;
216 virtual void preorder_walk(AllocExprAST *ast) override;
217 virtual void preorder_walk(FreeExprAST *ast) override;
218 virtual void preorder_walk(ArrayLiteralExprAST *ast) override;
219 virtual void preorder_walk(IndexExprAST *ast) override;
220 virtual void preorder_walk(LenExprAST *ast) override;
221 virtual void preorder_walk(UnaryNegExprAST *ast) override;
222 virtual void preorder_walk(StructLiteralExprAST *ast) override;
223 virtual void preorder_walk(FieldAccessExprAST *ast) override;
224 virtual void preorder_walk(CaseExprAST *ast) override;
225 virtual void preorder_walk(WhileExprAST *ast) override;
226 virtual void preorder_walk(TupleLiteralExprAST *ast) override;
227 virtual void preorder_walk(TypeClassDeclAST *ast) override;
228 virtual void preorder_walk(TypeClassInstanceAST *ast) override;
229
230 // post order
231 virtual void postorder_walk(ProgramAST *ast) override;
232 virtual void postorder_walk(VarDefAST *ast) override;
233 virtual void postorder_walk(ExternAST *ast) override;
234 virtual void postorder_walk(FuncDefAST *ast) override;
235 virtual void postorder_walk(StructDefAST *ast) override;
236 virtual void postorder_walk(EnumDefAST *ast) override;
237 virtual void postorder_walk(TypeAliasDefAST *ast) override;
238 virtual void postorder_walk(PrototypeAST *ast) override;
239 virtual void postorder_walk(CallExprAST *ast) override;
240 virtual void postorder_walk(ReturnExprAST *ast) override;
241 virtual void postorder_walk(BinaryExprAST *ast) override;
242 virtual void postorder_walk(NumberExprAST *ast) override;
243 virtual void postorder_walk(StringExprAST *ast) override;
244 virtual void postorder_walk(BoolExprAST *ast) override;
245 virtual void postorder_walk(CharExprAST *ast) override;
246 virtual void postorder_walk(UnitExprAST *ast) override;
247 virtual void postorder_walk(VariableExprAST *ast) override;
248 virtual void postorder_walk(BlockAST *ast) override;
249 virtual void postorder_walk(IfExprAST *ast) override;
250 virtual void postorder_walk(TypedVarAST *ast) override;
251 virtual void postorder_walk(DerefExprAST *ast) override;
252 virtual void postorder_walk(AddrOfExprAST *ast) override;
253 virtual void postorder_walk(AllocExprAST *ast) override;
254 virtual void postorder_walk(FreeExprAST *ast) override;
255 virtual void postorder_walk(ArrayLiteralExprAST *ast) override;
256 virtual void postorder_walk(IndexExprAST *ast) override;
257 virtual void postorder_walk(LenExprAST *ast) override;
258 virtual void postorder_walk(UnaryNegExprAST *ast) override;
259 virtual void postorder_walk(StructLiteralExprAST *ast) override;
260 virtual void postorder_walk(FieldAccessExprAST *ast) override;
261 virtual void postorder_walk(CaseExprAST *ast) override;
262 virtual void postorder_walk(WhileExprAST *ast) override;
263 virtual void postorder_walk(TupleLiteralExprAST *ast) override;
264 virtual void postorder_walk(TypeClassDeclAST *ast) override;
265 virtual void postorder_walk(TypeClassInstanceAST *ast) override;
266
267 virtual Type synthesize(ProgramAST *ast) override;
268 virtual Type synthesize(VarDefAST *ast) override;
269 virtual Type synthesize(ExternAST *ast) override;
270 virtual Type synthesize(FuncDefAST *ast) override;
271 virtual Type synthesize(StructDefAST *ast) override;
272 virtual Type synthesize(EnumDefAST *ast) override;
273 virtual Type synthesize(TypeAliasDefAST *ast) override;
274 virtual Type synthesize(PrototypeAST *ast) override;
275 virtual Type synthesize(CallExprAST *ast) override;
276 virtual Type synthesize(ReturnExprAST *ast) override;
277 virtual Type synthesize(BinaryExprAST *ast) override;
278 virtual Type synthesize(NumberExprAST *ast) override;
279 virtual Type synthesize(UnitExprAST *ast) override;
280 virtual Type synthesize(StringExprAST *ast) override;
281 virtual Type synthesize(BoolExprAST *ast) override;
282 virtual Type synthesize(CharExprAST *ast) override;
283 virtual Type synthesize(VariableExprAST *ast) override;
284 virtual Type synthesize(BlockAST *ast) override;
285 virtual Type synthesize(IfExprAST *ast) override;
286 virtual Type synthesize(TypedVarAST *ast) override;
287 virtual Type synthesize(DerefExprAST *ast) override;
288 virtual Type synthesize(AddrOfExprAST *ast) override;
289 virtual Type synthesize(AllocExprAST *ast) override;
290 virtual Type synthesize(FreeExprAST *ast) override;
291 virtual Type synthesize(ArrayLiteralExprAST *ast) override;
292 virtual Type synthesize(IndexExprAST *ast) override;
293 virtual Type synthesize(LenExprAST *ast) override;
294 virtual Type synthesize(UnaryNegExprAST *ast) override;
295 virtual Type synthesize(StructLiteralExprAST *ast) override;
296 virtual Type synthesize(FieldAccessExprAST *ast) override;
297 virtual Type synthesize(CaseExprAST *ast) override;
298 virtual Type synthesize(WhileExprAST *ast) override;
299 virtual Type synthesize(TupleLiteralExprAST *ast) override;
300 virtual Type synthesize(TypeClassDeclAST *ast) override;
301 virtual Type synthesize(TypeClassInstanceAST *ast) override;
302
// Converts a parsed type expression into a `Type`.
//
// Results by expression kind (checked in this order with llvm::dyn_cast):
//   * null pointer          -> Type::NonExistent()
//   * SimpleTypeExprAST     -> the type registered under the mangled name, or
//                              Type::Poisoned() plus an error if the module is
//                              unresolved or the name is unknown
//   * PointerTypeExprAST    -> Type::Pointer(pointee) with `is_linear` copied
//   * ArrayTypeExprAST      -> Type::Array(element, size)
//   * FunctionTypeExprAST   -> Type::Function(params..., return)
//   * TupleTypeExprAST      -> Type::Tuple(elements...)
//   * GenericTypeExprAST    -> the instantiated generic enum type
//   * anything else         -> Type::NonExistent()
// A poisoned component always makes the composite result poisoned.
//
// Side effects: may call add_error; the generic branch may also visit a cloned
// enum definition with this visitor, insert into `instantiated_enums`, and
// append to `monomorphized_enum_defs`.
303 Type resolve_type_expr(TypeExprAST *type_expr) {
304 if (!type_expr)
305 return Type::NonExistent();
306
// --- Plain named type ---
307 if (auto *simple = llvm::dyn_cast<SimpleTypeExprAST>(type_expr)) {
308 if (simple->name.is_unresolved()) {
309 this->add_error(type_expr->location,
310 fmt::format("Module '{}' is not imported",
311 simple->name.get_module()));
312 return Type::Poisoned();
313 }
314 auto mangled = simple->name.mangled();
315 auto get_type_opt = this->get_typename_type(mangled);
316 if (!get_type_opt.has_value()) {
317 this->add_error(type_expr->location,
318 fmt::format("Type '{}' not found in the current scope.",
319 simple->name.mangled()));
320 return Type::Poisoned();
321 }
322 return get_type_opt.value();
323 }
324
// --- Pointer: resolve the pointee, then carry over the linearity flag ---
325 if (auto *ptr = llvm::dyn_cast<PointerTypeExprAST>(type_expr)) {
326 auto pointee = resolve_type_expr(ptr->pointee.get());
327 if (pointee.is_poisoned()) return pointee;
328 auto result = Type::Pointer(pointee);
329 result.is_linear = ptr->is_linear;
330 return result;
331 }
332
// --- Array: a poisoned element type is returned as-is ---
333 if (auto *arr = llvm::dyn_cast<ArrayTypeExprAST>(type_expr)) {
334 auto elem = resolve_type_expr(arr->element.get());
335 return elem.is_poisoned() ? elem : Type::Array(elem, arr->size) ;
336 }
337
// --- Function: parameter types first, return type appended last ---
338 if (auto *fn = llvm::dyn_cast<FunctionTypeExprAST>(type_expr)) {
339 std::vector<Type> total_types;
340 for (auto &param : fn->paramTypes) {
341 auto pt = resolve_type_expr(param.get());
342 if (pt.is_poisoned())
343 return Type::Poisoned();
344 total_types.push_back(pt);
345 }
346 auto ret = resolve_type_expr(fn->returnType.get());
347 if (ret.is_poisoned())
348 return Type::Poisoned();
349 total_types.push_back(ret);
350 return Type::Function(std::move(total_types));
351 }
352
// --- Tuple: resolve every element in order ---
353 if (auto *tup = llvm::dyn_cast<TupleTypeExprAST>(type_expr)) {
354 std::vector<Type> elem_types;
355 for (auto &et : tup->element_types) {
356 auto resolved = resolve_type_expr(et.get());
357 if (resolved.is_poisoned())
358 return Type::Poisoned();
359 elem_types.push_back(resolved);
360 }
361 return Type::Tuple(std::move(elem_types));
362 }
363
// --- Generic type application, e.g. Name<Arg, ...> ---
364 if (auto *gen = llvm::dyn_cast<GenericTypeExprAST>(type_expr)) {
365 auto base_mangled = gen->base_name.mangled();
366
367 // Look up the generic enum definition (only `generic_enum_defs` is consulted)
368 auto it = generic_enum_defs.find(base_mangled);
369 if (it == generic_enum_defs.end()) {
370 this->add_error(type_expr->location,
371 fmt::format("'{}' is not a generic type",
372 gen->base_name.mangled()));
373 return Type::Poisoned();
374 }
375
// Arity check: the number of type arguments must equal the number of
// declared type parameters.
376 auto *generic_def = it->second;
377 if (gen->type_args.size() != generic_def->type_params.size()) {
378 this->add_error(
379 type_expr->location,
380 fmt::format("Generic type '{}' expects {} type argument(s), got {}",
381 gen->base_name.mangled(),
382 generic_def->type_params.size(),
383 gen->type_args.size()));
384 return Type::Poisoned();
385 }
386
387 // Resolve type arguments, building both the substitution map and the
// textual "<A, B>" suffix that is passed on together with `gen->base_name`.
388 Monomorphizer::SubstitutionMap bindings;
389 std::string type_args = "<";
390 bool has_unresolved_type_param = false;
391 for (size_t i = 0; i < gen->type_args.size(); i++) {
392 auto resolved = resolve_type_expr(gen->type_args[i].get());
393 if (resolved.is_poisoned())
394 return Type::Poisoned();
// Only a top-level TypeParam argument sets the flag; nested type
// parameters inside `resolved` are not inspected here.
395 if (resolved.type_kind == TypeKind::TypeParam)
396 has_unresolved_type_param = true;
397 bindings[generic_def->type_params[i]] = resolved;
398 if (i > 0) type_args += ", ";
399 type_args += resolved.to_string();
400 }
401 type_args += ">";
// NOTE(review): the statement that declares `mono` (the call whose argument
// list ends on the next line) is not visible in this listing.
403 gen->base_name, type_args);
404 auto mangled = mono.mangled();
405
406 // If a type argument is itself a type parameter (e.g. Option<T> inside
407 // a generic function), the enum is not instantiated here: a type already
408 // registered under the mangled name is reused, otherwise Poisoned is returned.
409 if (has_unresolved_type_param) {
410 // Check if the typename is already registered (from a previous pass)
411 auto existing = this->get_typename_type(mangled);
412 if (existing.has_value())
413 return existing.value();
414 // Not registered yet: return Poisoned (no error is recorded on this path)
415 return Type::Poisoned();
416 }
417
418 // Already instantiated? Falls through to re-instantiate if the lookup fails.
419 if (instantiated_enums.contains(mangled)) {
420 auto existing = this->get_typename_type(mangled);
421 if (existing.has_value())
422 return existing.value();
423 }
424
425 // Instantiate the generic enum: clone it with the bindings, visit the clone
// with this visitor, then record it and keep the clone alive in
// `monomorphized_enum_defs`.
426 auto cloned = Monomorphizer::instantiate_enum(generic_def, mono,
427 bindings);
428 cloned->accept_vis(this);
429 instantiated_enums.insert(mangled);
430 monomorphized_enum_defs.push_back(std::move(cloned));
431
// The instantiated type is looked up by its mangled name; Poisoned if it
// is still not registered.
432 auto result = this->get_typename_type(mangled);
433 return result.has_value() ? result.value() : Type::Poisoned();
434 }
435
// No known type-expression kind matched.
436 return Type::NonExistent();
437 }
438};
439// --- Numeric literal type inference helpers (shared across .cpp files) ---
440
444 if (t.type_kind == TypeKind::Integer)
445 return Type::I32_t();
446 if (t.type_kind == TypeKind::Flt)
447 return Type::F64_t();
448 return t;
449}
450
454inline void resolve_literal_type(ExprAST *expr, const Type &target) {
455 if (!expr || !expr->get_type().is_polymorphic_numeric())
456 return;
457
458 expr->set_type(target);
459
460 if (auto *unary = llvm::dyn_cast<UnaryNegExprAST>(expr)) {
461 resolve_literal_type(unary->operand.get(), target);
462 } else if (auto *binary = llvm::dyn_cast<BinaryExprAST>(expr)) {
463 resolve_literal_type(binary->LHS.get(), target);
464 resolve_literal_type(binary->RHS.get(), target);
465 } else if (auto *if_expr = llvm::dyn_cast<IfExprAST>(expr)) {
466 if (if_expr->thenBlockAST && !if_expr->thenBlockAST->Statements.empty()) {
467 auto *last_then = if_expr->thenBlockAST->Statements.back().get();
468 resolve_literal_type(last_then, target);
469 if_expr->thenBlockAST->set_type(target);
470 }
471 if (if_expr->elseBlockAST && !if_expr->elseBlockAST->Statements.empty()) {
472 auto *last_else = if_expr->elseBlockAST->Statements.back().get();
473 resolve_literal_type(last_else, target);
474 if_expr->elseBlockAST->set_type(target);
475 }
476 }
477 // For all other expression types (NumberExprAST, CallExprAST, IndexExprAST,
478 // etc.), the type is already set above — no children to recurse into.
479}
480
481} // namespace AST
482} // namespace sammine_lang
Defines the AST Abstract class for printing out AST Nodes.
Defines the AST Node classes (ProgramAST, StructDefAST, FuncDefAST) and a visitor interface for trave...
void resolve_literal_type(ExprAST *expr, const Type &target)
Definition BiTypeChecker.h:454
Type default_polymorphic_type(const Type &t)
Definition BiTypeChecker.h:443
A simple scoping class, doesn't differentiate between different names, like variable name,...
Defines the core Type system for Sammine.
Definition Types.h:68
Definition ASTProperties.h:40
Definition BiTypeChecker.h:22
ASTProperties & props_
Definition BiTypeChecker.h:30
std::optional< Type > try_get_callee_type(const std::string &str) const
Try to find a name in current scope + parent scopes (non-aborting).
Definition BiTypeChecker.h:105
An AST to simulate a { } code block.
Definition Ast.h:317
Definition Ast.h:485
Definition Ast.h:548
Definition Ast.h:798
Definition Ast.h:495
Definition Ast.h:378
Definition Ast.h:302
A Function Definition that has the prototype and definition in terms of a block.
Definition Ast.h:291
Definition Ast.h:667
Definition Ast.h:326
Definition Ast.h:598
Definition Ast.h:707
Definition AstBase.h:239
Definition Ast.h:197
A prototype to present "func func_name(...) -> type;".
Definition Ast.h:236
Definition AstBase.h:277
Definition Ast.h:205
Definition BiTypeChecker.h:20
Definition Ast.h:581
A variable definition: "var x = expression;" or "let (a, b) = expr;".
Definition Ast.h:420
Definition Types.h:398
Definition Types.h:139
static MonomorphizedName generic(QualifiedName base, std::string type_args)
Generic function/enum: identity<i32>, Option<i32>.
Definition MonomorphizedName.h:21