1 module samples.fibonacci;
2 
3 import std.conv : to;
4 import std.stdio : writefln, writeln;
5 import std.string : toStringz, fromStringz;
6 
7 import llvm;
8 
/+ Builds a recursive `fib(i32) -> i32` function with the LLVM C API,
 + verifies and optimizes the module, dumps its IR, then JIT-compiles it and
 + runs fib(n) for the n given as the first command-line argument
 + (default: 10).
 +
 + Returns: 0 on success; 1 if the execution engine cannot be created or if
 + the command-line argument is not a non-negative 32-bit integer.
 +/
int main(string[] args)
{
	import std.conv : ConvException;
	import std.stdio : stderr;

	char* error;

	LLVMInitializeNativeTarget();
	LLVMInitializeNativeAsmPrinter();
	LLVMInitializeNativeAsmParser();

	auto _module = LLVMModuleCreateWithName("fibonacci".toStringz());
	auto f_args = [ LLVMInt32Type() ];
	auto f = LLVMAddFunction(
		_module,
		"fib".toStringz(),
		LLVMFunctionType(LLVMInt32Type(), f_args.ptr, 1, cast(LLVMBool) false));
	LLVMSetFunctionCallConv(f, LLVMCCallConv);

	auto n = LLVMGetParam(f, 0);

	auto entry = LLVMAppendBasicBlock(f, "entry".toStringz());
	auto case_base0 = LLVMAppendBasicBlock(f, "case_base0".toStringz());
	auto case_base1 = LLVMAppendBasicBlock(f, "case_base1".toStringz());
	auto case_default = LLVMAppendBasicBlock(f, "case_default".toStringz());
	auto end = LLVMAppendBasicBlock(f, "end".toStringz());
	auto builder = LLVMCreateBuilder();

	/+ Entry basic block: dispatch on n. Only 0 and 1 have explicit cases;
	 + every other value, including negative ones, goes to case_default. +/
	LLVMPositionBuilderAtEnd(builder, entry);
	auto Switch = LLVMBuildSwitch(
		builder,
		n,
		case_default,
		2);
	LLVMAddCase(Switch, LLVMConstInt(LLVMInt32Type(), 0, cast(LLVMBool) false), case_base0);
	LLVMAddCase(Switch, LLVMConstInt(LLVMInt32Type(), 1, cast(LLVMBool) false), case_base1);

	/+ Basic block for n = 0: fib(n) = 0 +/
	LLVMPositionBuilderAtEnd(builder, case_base0);
	auto res_base0 = LLVMConstInt(LLVMInt32Type(), 0, cast(LLVMBool) false);
	LLVMBuildBr(builder, end);

	/+ Basic block for n = 1: fib(n) = 1 +/
	LLVMPositionBuilderAtEnd(builder, case_base1);
	auto res_base1 = LLVMConstInt(LLVMInt32Type(), 1, cast(LLVMBool) false);
	LLVMBuildBr(builder, end);

	/+ Basic block for n > 1: fib(n) = fib(n - 1) + fib(n - 2).
	 + A negative n also lands here and would recurse without ever reaching
	 + a base case, so the host code below rejects negative arguments. +/
	LLVMPositionBuilderAtEnd(builder, case_default);

	auto n_minus_1 = LLVMBuildSub(
		builder,
		n,
		LLVMConstInt(LLVMInt32Type(), 1, cast(LLVMBool) false),
		"n - 1".toStringz());
	auto call_f_1_args = [ n_minus_1 ];
	auto call_f_1 = LLVMBuildCall(builder, f, call_f_1_args.ptr, 1, "fib(n - 1)".toStringz());

	auto n_minus_2 = LLVMBuildSub(
		builder,
		n,
		LLVMConstInt(LLVMInt32Type(), 2, cast(LLVMBool) false),
		"n - 2".toStringz());
	auto call_f_2_args = [ n_minus_2 ];
	auto call_f_2 = LLVMBuildCall(builder, f, call_f_2_args.ptr, 1, "fib(n - 2)".toStringz());

	auto res_default = LLVMBuildAdd(builder, call_f_1, call_f_2, "fib(n - 1) + fib(n - 2)".toStringz());
	LLVMBuildBr(builder, end);

	/+ Basic block for collecting the result: the phi selects the value
	 + produced by whichever predecessor block branched here. +/
	LLVMPositionBuilderAtEnd(builder, end);
	auto res = LLVMBuildPhi(builder, LLVMInt32Type(), "result".toStringz());
	auto phi_vals = [ res_base0, res_base1, res_default ];
	auto phi_blocks = [ case_base0, case_base1, case_default ];
	LLVMAddIncoming(res, phi_vals.ptr, phi_blocks.ptr, 3);
	LLVMBuildRet(builder, res);

	/+ The builder is not needed once the IR is complete. Releasing it here
	 + also keeps it from leaking on the early-return error paths below. +/
	LLVMDisposeBuilder(builder);

	LLVMVerifyModule(_module, LLVMAbortProcessAction, &error);
	LLVMDisposeMessage(error);

	LLVMExecutionEngineRef engine;
	error = null;

	version(Windows)
	{
		/+ On Windows, we can only use the old JIT for now +/
		LLVMCreateJITCompilerForModule(&engine, _module, 2, &error);
	}
	else
	{
		static if (LLVM_Version >= asVersion(3,3,0))
		{
			/+ On other systems we should be able to use the newer
			 + MCJIT instead - if we have a high enough LLVM version +/
			LLVMMCJITCompilerOptions options;
			LLVMInitializeMCJITCompilerOptions(&options, options.sizeof);

			LLVMCreateMCJITCompilerForModule(&engine, _module, &options, options.sizeof, &error);
		}
		else
		{
			LLVMCreateJITCompilerForModule(&engine, _module, 2, &error);
		}
	}

	if (error)
	{
		scope (exit) LLVMDisposeMessage(error);
		writefln("%s", error.fromStringz());
		return 1;
	}
	// The engine was created successfully; release it on every later exit path.
	scope (exit) LLVMDisposeExecutionEngine(engine);

	auto pass = LLVMCreatePassManager();
	// Registered after the engine's guard, so the pass manager is disposed first.
	scope (exit) LLVMDisposePassManager(pass);
	static if (LLVM_Version < asVersion(3,9,0))
	{
		LLVMAddTargetData(LLVMGetExecutionEngineTargetData(engine), pass);
	}
	LLVMAddConstantPropagationPass(pass);
	LLVMAddInstructionCombiningPass(pass);
	LLVMAddPromoteMemoryToRegisterPass(pass);
	LLVMAddGVNPass(pass);
	LLVMAddCFGSimplificationPass(pass);
	LLVMRunPassManager(pass, _module);

	writefln("The following module has been generated for the fibonacci series:\n");
	LLVMDumpModule(_module);

	writeln();

	int n_exec= 10;
	if (args.length > 1)
	{
		try
		{
			n_exec = to!int(args[1]);
		}
		catch (ConvException e)
		{
			// Non-numeric or out-of-range input: report it instead of crashing.
			stderr.writefln("; Invalid argument for fib: \"%s\" (%s)", args[1], e.msg);
			return 1;
		}
	}
	else
	{
		writefln("; Argument for fib missing on command line, using default:  \"%d\"", n_exec);
	}

	// A negative argument would make the generated fib recurse without bound.
	if (n_exec < 0)
	{
		stderr.writefln("; fib(%d) is not defined: argument must be non-negative", n_exec);
		return 1;
	}

	auto exec_arg = LLVMCreateGenericValueOfInt(LLVMInt32Type(), n_exec, cast(LLVMBool) 0);
	scope (exit) LLVMDisposeGenericValue(exec_arg);
	auto exec_args = [ exec_arg ];
	writefln("; Running (jit-compiled) fib(%d)...", n_exec);
	auto exec_res = LLVMRunFunction(engine, f, 1, exec_args.ptr);
	scope (exit) LLVMDisposeGenericValue(exec_res);
	writefln("; fib(%d) = %d", n_exec, LLVMGenericValueToInt(exec_res, 0));

	return 0;
}