diff --git a/README.md b/README.md index 0174bb5..4664533 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,9 @@ function forExample(end: int): int { ``` +## Tail Call Optimized +Recursive tail calls are optimized at compile time. + ## Statically typed KLang statically verifies the integrity of your code. These checks include: - Type checking @@ -123,6 +126,7 @@ KLang statically verifies the integrity of your code. These checks include: ### Data Types - Integer "int" - Boolean "bool" +- Floats "float" ### Examples You can declare types for parameters, return values and variables diff --git a/src/main/java/de/hsrm/compiler/Klang/ContextAnalysis.java b/src/main/java/de/hsrm/compiler/Klang/ContextAnalysis.java index 61275a6..33e8b6a 100644 --- a/src/main/java/de/hsrm/compiler/Klang/ContextAnalysis.java +++ b/src/main/java/de/hsrm/compiler/Klang/ContextAnalysis.java @@ -18,6 +18,7 @@ public class ContextAnalysis extends KlangBaseVisitor { Map funcs; Map structs; Type currentDeclaredReturnType; + String currentFunctionDefinitionName; private void checkNumeric(Node lhs, Node rhs, int line, int col) { if (!lhs.type.isNumericType() || !rhs.type.isNumericType()) { @@ -246,6 +247,16 @@ public class ContextAnalysis extends KlangBaseVisitor { public Node visitReturn_statement(KlangParser.Return_statementContext ctx) { Expression expression = (Expression) this.visit(ctx.expression()); ReturnStatement result = new ReturnStatement(expression); + + // Check if this expression is a tail recursion + if (expression instanceof FunctionCall) { + var funCall = (FunctionCall) expression; + if (funCall.name.equals(this.currentFunctionDefinitionName)) { + // Flag this function call + funCall.isTailRecursive = true; + } + } + result.type = expression.type; result.line = ctx.start.getLine(); result.col = ctx.start.getCharPositionInLine(); @@ -720,6 +731,7 @@ public class ContextAnalysis extends KlangBaseVisitor { int col = ctx.start.getCharPositionInLine(); Type returnType = Type.getByName(ctx.returnType.type().getText()); this.currentDeclaredReturnType = returnType; + this.currentFunctionDefinitionName = name; if (!returnType.isPrimitiveType() && this.structs.get(returnType.getName()) == null) { String error = "Type " + returnType.getName() + " not defined."; diff --git a/src/main/java/de/hsrm/compiler/Klang/nodes/expressions/FunctionCall.java b/src/main/java/de/hsrm/compiler/Klang/nodes/expressions/FunctionCall.java index feec2af..2375701 100644 --- a/src/main/java/de/hsrm/compiler/Klang/nodes/expressions/FunctionCall.java +++ b/src/main/java/de/hsrm/compiler/Klang/nodes/expressions/FunctionCall.java @@ -6,6 +6,7 @@ public class FunctionCall extends Expression { public String name; public Expression[] arguments; + public boolean isTailRecursive = false; public FunctionCall(String name, Expression[] arguments) { this.name = name; diff --git a/src/main/java/de/hsrm/compiler/Klang/visitors/GenASM.java b/src/main/java/de/hsrm/compiler/Klang/visitors/GenASM.java index 1bc49dc..126845d 100644 --- a/src/main/java/de/hsrm/compiler/Klang/visitors/GenASM.java +++ b/src/main/java/de/hsrm/compiler/Klang/visitors/GenASM.java @@ -113,6 +113,8 @@ public class GenASM implements Visitor { String[] registers = { "%rdi", "%rsi", "%rdx", "%rcx", "%r8", "%r9" }; String[] floatRegisters = { "%xmm0", "%xmm1", "%xmm2", "%xmm3", "%xmm4", "%xmm5", "%xmm6", "%xmm7" }; private int lCount = 0; // Invariante: lCount ist benutzt + private int currentFunctionStartLabel = 0; + private Parameter[] currentFunctionParams; private void intToFloat(String src, String dst) { this.ex.write(" cvtsi2sd " + src + ", " + dst + "\n"); @@ -566,6 +568,11 @@ public class GenASM implements Visitor { if (e.expression != null) { e.expression.welcome(this); int offset = this.env.get(e.name); + + if (e.expression.type.equals(Type.getFloatType())) { + this.ex.write(" movq %xmm0, %rax\n"); + } + this.ex.write(" movq %rax, " + offset + "(%rbp)\n"); } return null; @@ -606,11 +613,15 @@ public class GenASM implements Visitor { @Override public Void visit(FunctionDefinition e) { + int lblStart = ++lCount; + this.currentFunctionStartLabel = lblStart; + this.currentFunctionParams = e.parameters; this.ex.write(".globl " + e.name + "\n"); this.ex.write(".type " + e.name + ", @function\n"); this.ex.write(e.name + ":\n"); this.ex.write(" pushq %rbp\n"); this.ex.write(" movq %rsp, %rbp\n"); + this.ex.write(".L" + lblStart + ":\n"); // hole die anzahl der lokalen variablen this.vars = new TreeSet(); @@ -682,6 +693,25 @@ public class GenASM implements Visitor { @Override public Void visit(FunctionCall e) { + if (e.isTailRecursive) { + + // Visit the arguments and move them into the location of the corresponding local var + for(int i = 0; i < e.arguments.length; i++) { + e.arguments[i].welcome(this); + int offset = this.env.get(this.currentFunctionParams[i].name); + + if (e.arguments[i].type.equals(Type.getFloatType())) { + this.ex.write(" movq %xmm0, %rax\n"); + } + + this.ex.write(" movq %rax, " + offset + "(%rbp)\n"); + } + + this.ex.write(" jmp .L" + this.currentFunctionStartLabel + "\n"); + return null; + } + + if (e.arguments.length > 0) { // Mapping arguments index -> xmm registers index int[] xmmIdxs = new int[this.floatRegisters.length]; diff --git a/src/test/functionCall/functionCall.c b/src/test/functionCall/functionCall.c index c438ddf..73f23bb 100644 --- a/src/test/functionCall/functionCall.c +++ b/src/test/functionCall/functionCall.c @@ -92,4 +92,28 @@ int runFunctionCallTests () { argumentTest("fgetMix8(...args)", 8, fgetMix8()); argumentTest_f("fgetMix9(...args)", 9.0, fgetMix9()); argumentTest("fgetMix10(...args)", 10, fgetMix10()); + + printf("\nTail Call Tests \n"); + // Checks that tails calls are properly invoked + argumentTest("arg1Tail(...args)", 1, arg1Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10)); + argumentTest("arg2Tail(...args)", 2, arg2Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10)); + argumentTest("arg3Tail(...args)", 3, arg3Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10)); + argumentTest("arg4Tail(...args)", 4, arg4Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10)); + argumentTest("arg5Tail(...args)", 5, arg5Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10)); + argumentTest("arg6Tail(...args)", 6, arg6Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10)); + argumentTest("arg7Tail(...args)", 7, arg7Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10)); + argumentTest("arg8Tail(...args)", 8, arg8Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10)); + argumentTest("arg9Tail(...args)", 9, arg9Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10)); + argumentTest("arg10Tail(...args)", 10, arg10Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10)); + // Checks that parameters are correctly passed from klang to functions + argumentTest("get1Tail(...args)", 1, get1Tail(10)); + argumentTest("get2Tail(...args)", 2, get2Tail(10)); + argumentTest("get3Tail(...args)", 3, get3Tail(10)); + argumentTest("get4Tail(...args)", 4, get4Tail(10)); + argumentTest("get5Tail(...args)", 5, get5Tail(10)); + argumentTest("get6Tail(...args)", 6, get6Tail(10)); + argumentTest("get7Tail(...args)", 7, get7Tail(10)); + argumentTest("get8Tail(...args)", 8, get8Tail(10)); + argumentTest("get9Tail(...args)", 9, get9Tail(10)); + argumentTest("get10Tail(...args)", 10, get10Tail(10)); } \ No newline at end of file diff --git a/src/test/functionCall/functionCall.h b/src/test/functionCall/functionCall.h index 7aac291..5238453 100644 --- a/src/test/functionCall/functionCall.h +++ b/src/test/functionCall/functionCall.h @@ -20,6 +20,28 @@ long get8(); long get9(); long get10(); +long arg1Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count); +long arg2Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count); +long arg3Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count); +long arg4Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count); +long arg5Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count); +long arg6Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count); +long arg7Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count); +long arg8Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count); +long arg9Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count); +long arg10Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count); + +long get1Tail(long count); +long get2Tail(long count); +long get3Tail(long count); +long get4Tail(long count); +long get5Tail(long count); +long get6Tail(long count); +long get7Tail(long count); +long get8Tail(long count); +long get9Tail(long count); +long get10Tail(long count); + double farg1(double a, double b, double c, double d, double e, double f, double g, double h, double i, double j); double farg2(double a, double b, double c, double d, double e, double f, double g, double h, double i, double j); double farg3(double a, double b, double c, double d, double e, double f, double g, double h, double i, double j); diff --git a/src/test/test.k b/src/test/test.k index 395afa2..5f24c1d 100644 --- a/src/test/test.k +++ b/src/test/test.k @@ -106,6 +106,127 @@ function get10(): int { return arg10(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); } +// TAIL CALL +function arg1Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int { + if (count <= 0) { + return a; + } + + return arg1Tail(a, b, c, d, e, f, g, h, i, j, count - 1); +} + +function get1Tail(count: int): int { + return arg1Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10); +} + +function arg2Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int { + if (count <= 0) { + return b; + } + + return arg2Tail(a, b, c, d, e, f, g, h, i, j, count - 1); +} + +function get2Tail(count: int): int { + return arg2Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10); +} + +function arg3Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int { + if (count <= 0) { + return c; + } + + return arg3Tail(a, b, c, d, e, f, g, h, i, j, count - 1); +} + +function get3Tail(count: int): int { + return arg3Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10); +} + +function arg4Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int { + if (count <= 0) { + return d; + } + + return arg4Tail(a, b, c, d, e, f, g, h, i, j, count - 1); +} + +function get4Tail(count: int): int { + return arg4Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10); +} + +function arg5Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int { + if (count <= 0) { + return e; + } + + return arg5Tail(a, b, c, d, e, f, g, h, i, j, count - 1); +} + +function get5Tail(count: int): int { + return arg5Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10); +} + +function arg6Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int { + if (count <= 0) { + return f; + } + + return arg6Tail(a, b, c, d, e, f, g, h, i, j, count - 1); +} + +function get6Tail(count: int): int { + return arg6Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10); +} + +function arg7Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int { + if (count <= 0) { + return g; + } + + return arg7Tail(a, b, c, d, e, f, g, h, i, j, count - 1); +} + +function get7Tail(count: int): int { + return arg7Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10); +} + +function arg8Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int { + if (count <= 0) { + return h; + } + + return arg8Tail(a, b, c, d, e, f, g, h, i, j, count - 1); +} + +function get8Tail(count: int): int { + return arg8Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10); +} + +function arg9Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int { + if (count <= 0) { + return i; + } + + return arg9Tail(a, b, c, d, e, f, g, h, i, j, count - 1); +} + +function get9Tail(count: int): int { + return arg9Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10); +} + +function arg10Tail(a: int, b: int, c: int, d: int, e: int, f: int, g: int, h: int, i: int, j: int, count: int): int { + if (count <= 0) { + return j; + } + + return arg10Tail(a, b, c, d, e, f, g, h, i, j, count - 1); +} + +function get10Tail(count: int): int { + return arg10Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10); +} + // FLOATS function farg1(a: float, b: float, c: float, d: float, e: float, f: float, g: float, h: float, i: float ,j: float): float {