Merge branch '16-tail-recursion-optimization' into 'master'

Resolve "Tail Recursion Optimization"

Closes #16

See merge request mkais001/klang!18
This commit is contained in:
Dennis Kaiser
2020-03-09 22:59:38 +01:00
7 changed files with 214 additions and 0 deletions

View File

@@ -112,6 +112,9 @@ function forExample(end: int): int {
```
## Tail Call Optimized
Recursive tail calls are optimized at compile time.
## Statically typed
KLang statically verifies the integrity of your code. These checks include:
- Type checking
@@ -123,6 +126,7 @@ KLang statically verifies the integrity of your code. These checks include:
### Data Types
- Integer "int"
- Boolean "bool"
- Floats "float"
### Examples
You can declare types for parameters, return values and variables

View File

@@ -18,6 +18,7 @@ public class ContextAnalysis extends KlangBaseVisitor<Node> {
Map<String, FunctionInformation> funcs;
Map<String, StructDefinition> structs;
Type currentDeclaredReturnType;
String currentFunctionDefinitionName;
private void checkNumeric(Node lhs, Node rhs, int line, int col) {
if (!lhs.type.isNumericType() || !rhs.type.isNumericType()) {
@@ -246,6 +247,16 @@ public class ContextAnalysis extends KlangBaseVisitor<Node> {
public Node visitReturn_statement(KlangParser.Return_statementContext ctx) {
Expression expression = (Expression) this.visit(ctx.expression());
ReturnStatement result = new ReturnStatement(expression);
// Check if this expression is a tail recursion
if (expression instanceof FunctionCall) {
var funCall = (FunctionCall) expression;
if (funCall.name.equals(this.currentFunctionDefinitionName)) {
// Flag this function call
funCall.isTailRecursive = true;
}
}
result.type = expression.type;
result.line = ctx.start.getLine();
result.col = ctx.start.getCharPositionInLine();
@@ -720,6 +731,7 @@ public class ContextAnalysis extends KlangBaseVisitor<Node> {
int col = ctx.start.getCharPositionInLine();
Type returnType = Type.getByName(ctx.returnType.type().getText());
this.currentDeclaredReturnType = returnType;
this.currentFunctionDefinitionName = name;
if (!returnType.isPrimitiveType() && this.structs.get(returnType.getName()) == null) {
String error = "Type " + returnType.getName() + " not defined.";

View File

@@ -6,6 +6,7 @@ public class FunctionCall extends Expression {
public String name;
public Expression[] arguments;
public boolean isTailRecursive = false;
public FunctionCall(String name, Expression[] arguments) {
this.name = name;

View File

@@ -113,6 +113,8 @@ public class GenASM implements Visitor<Void> {
String[] registers = { "%rdi", "%rsi", "%rdx", "%rcx", "%r8", "%r9" };
String[] floatRegisters = { "%xmm0", "%xmm1", "%xmm2", "%xmm3", "%xmm4", "%xmm5", "%xmm6", "%xmm7" };
private int lCount = 0; // Invariante: lCount ist benutzt
private int currentFunctionStartLabel = 0;
private Parameter[] currentFunctionParams;
private void intToFloat(String src, String dst) {
this.ex.write(" cvtsi2sd " + src + ", " + dst + "\n");
@@ -566,6 +568,11 @@ public class GenASM implements Visitor<Void> {
if (e.expression != null) {
e.expression.welcome(this);
int offset = this.env.get(e.name);
if (e.expression.type.equals(Type.getFloatType())) {
this.ex.write(" movq %xmm0, %rax\n");
}
this.ex.write(" movq %rax, " + offset + "(%rbp)\n");
}
return null;
@@ -606,11 +613,15 @@ public class GenASM implements Visitor<Void> {
@Override
public Void visit(FunctionDefinition e) {
int lblStart = ++lCount;
this.currentFunctionStartLabel = lblStart;
this.currentFunctionParams = e.parameters;
this.ex.write(".globl " + e.name + "\n");
this.ex.write(".type " + e.name + ", @function\n");
this.ex.write(e.name + ":\n");
this.ex.write(" pushq %rbp\n");
this.ex.write(" movq %rsp, %rbp\n");
this.ex.write(".L" + lblStart + ":\n");
// hole die anzahl der lokalen variablen
this.vars = new TreeSet<String>();
@@ -682,6 +693,25 @@ public class GenASM implements Visitor<Void> {
@Override
public Void visit(FunctionCall e) {
if (e.isTailRecursive) {
// Visit the arguments and move them into the location of the corresponding local var
for(int i = 0; i < e.arguments.length; i++) {
e.arguments[i].welcome(this);
int offset = this.env.get(this.currentFunctionParams[i].name);
if (e.arguments[i].type.equals(Type.getFloatType())) {
this.ex.write(" movq %xmm0, %rax\n");
}
this.ex.write(" movq %rax, " + offset + "(%rbp)\n");
}
this.ex.write(" jmp .L" + this.currentFunctionStartLabel + "\n");
return null;
}
if (e.arguments.length > 0) {
// Mapping arguments index -> xmm registers index
int[] xmmIdxs = new int[this.floatRegisters.length];

View File

@@ -92,4 +92,28 @@ int runFunctionCallTests () {
argumentTest("fgetMix8(...args)", 8, fgetMix8());
argumentTest_f("fgetMix9(...args)", 9.0, fgetMix9());
argumentTest("fgetMix10(...args)", 10, fgetMix10());
printf("\nTail Call Tests \n");
// Checks that tails calls are properly invoked
argumentTest("arg1Tail(...args)", 1, arg1Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10));
argumentTest("arg2Tail(...args)", 2, arg2Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10));
argumentTest("arg3Tail(...args)", 3, arg3Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10));
argumentTest("arg4Tail(...args)", 4, arg4Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10));
argumentTest("arg5Tail(...args)", 5, arg5Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10));
argumentTest("arg6Tail(...args)", 6, arg6Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10));
argumentTest("arg7Tail(...args)", 7, arg7Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10));
argumentTest("arg8Tail(...args)", 8, arg8Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10));
argumentTest("arg9Tail(...args)", 9, arg9Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10));
argumentTest("arg10Tail(...args)", 10, arg10Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10));
// Checks that parameters are correctly passed from klang to functions
argumentTest("get1Tail(...args)", 1, get1Tail(10));
argumentTest("get2Tail(...args)", 2, get2Tail(10));
argumentTest("get3Tail(...args)", 3, get3Tail(10));
argumentTest("get4Tail(...args)", 4, get4Tail(10));
argumentTest("get5Tail(...args)", 5, get5Tail(10));
argumentTest("get6Tail(...args)", 6, get6Tail(10));
argumentTest("get7Tail(...args)", 7, get7Tail(10));
argumentTest("get8Tail(...args)", 8, get8Tail(10));
argumentTest("get9Tail(...args)", 9, get9Tail(10));
argumentTest("get10Tail(...args)", 10, get10Tail(10));
}

View File

@@ -20,6 +20,28 @@ long get8();
long get9();
long get10();
long arg1Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count);
long arg2Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count);
long arg3Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count);
long arg4Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count);
long arg5Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count);
long arg6Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count);
long arg7Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count);
long arg8Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count);
long arg9Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count);
long arg10Tail(long a, long b, long c, long d, long e, long f, long g, long h, long i, long j, long count);
long get1Tail(long count);
long get2Tail(long count);
long get3Tail(long count);
long get4Tail(long count);
long get5Tail(long count);
long get6Tail(long count);
long get7Tail(long count);
long get8Tail(long count);
long get9Tail(long count);
long get10Tail(long count);
double farg1(double a, double b, double c, double d, double e, double f, double g, double h, double i, double j);
double farg2(double a, double b, double c, double d, double e, double f, double g, double h, double i, double j);
double farg3(double a, double b, double c, double d, double e, double f, double g, double h, double i, double j);

View File

@@ -106,6 +106,127 @@ function get10(): int {
return arg10(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
}
// TAIL CALL
function arg1Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int {
if (count <= 0) {
return a;
}
return arg1Tail(a, b, c, d, e, f, g, h, i, j, count - 1);
}
function get1Tail(count: int): int {
return arg1Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10);
}
function arg2Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int {
if (count <= 0) {
return b;
}
return arg2Tail(a, b, c, d, e, f, g, h, i, j, count - 1);
}
function get2Tail(count: int): int {
return arg2Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10);
}
function arg3Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int {
if (count <= 0) {
return c;
}
return arg3Tail(a, b, c, d, e, f, g, h, i, j, count - 1);
}
function get3Tail(count: int): int {
return arg3Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10);
}
function arg4Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int {
if (count <= 0) {
return d;
}
return arg4Tail(a, b, c, d, e, f, g, h, i, j, count - 1);
}
function get4Tail(count: int): int {
return arg4Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10);
}
function arg5Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int {
if (count <= 0) {
return e;
}
return arg5Tail(a, b, c, d, e, f, g, h, i, j, count - 1);
}
function get5Tail(count: int): int {
return arg5Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10);
}
function arg6Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int {
if (count <= 0) {
return f;
}
return arg6Tail(a, b, c, d, e, f, g, h, i, j, count - 1);
}
function get6Tail(count: int): int {
return arg6Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10);
}
function arg7Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int {
if (count <= 0) {
return g;
}
return arg7Tail(a, b, c, d, e, f, g, h, i, j, count - 1);
}
function get7Tail(count: int): int {
return arg7Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10);
}
function arg8Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int {
if (count <= 0) {
return h;
}
return arg8Tail(a, b, c, d, e, f, g, h, i, j, count - 1);
}
function get8Tail(count: int): int {
return arg8Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10);
}
function arg9Tail(a: int, b: int, c: int,d: int,e: int,f: int,g: int, h: int,i: int,j: int, count: int): int {
if (count <= 0) {
return i;
}
return arg9Tail(a, b, c, d, e, f, g, h, i, j, count - 1);
}
function get9Tail(count: int): int {
return arg9Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10);
}
function arg10Tail(a: int, b: int, c: int, d: int, e: int, f: int, g: int, h: int, i: int, j: int, count: int): int {
if (count <= 0) {
return j;
}
return arg10Tail(a, b, c, d, e, f, g, h, i, j, count - 1);
}
function get10Tail(count: int): int {
return arg10Tail(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10);
}
// FLOATS
function farg1(a: float, b: float, c: float, d: float, e: float, f: float, g: float, h: float, i: float ,j: float): float {