1 module samples.fibonacci;
2 
3 import std.conv : to;
4 import std.stdio : writefln, writeln;
5 import std.string : toStringz, fromStringz;
6 
7 import llvm;
8 
/// Compile-time switch selecting the MCJIT code path over the older JIT API.
immutable useMCJIT = {
    version (Windows)
    {
        // MCJIT does not work on Windows.
        return false;
    }
    else
    {
        // MCJIT is only used when LLVMGetFunctionAddress is available
        // (LLVM 3.6.0 or newer), because LLVMRunFunction does not work
        // reliably together with MCJIT.
        return LLVM_Version >= asVersion(3, 6, 0);
    }
}();
19 
/**
 * Creates a JIT execution engine for `genModule` and stores it in `engine`.
 *
 * When `useMCJIT` is true the engine is created through
 * `LLVMCreateMCJITCompilerForModule` with default-initialized options;
 * otherwise `LLVMCreateJITCompilerForModule` is used with optimization
 * level 2.
 *
 * Params:
 *     engine    = receives the newly created execution engine
 *     genModule = module the engine is created for
 *
 * Throws:
 *     Exception carrying LLVM's error message if one was reported, or a
 *     generic message if LLVM signalled failure without providing one.
 */
void initJIT(ref LLVMExecutionEngineRef engine, LLVMModuleRef genModule)
{
    char* error;
    // Non-zero status means the engine could not be created.
    LLVMBool failed;

    static if (useMCJIT) {
        LLVMMCJITCompilerOptions options;
        LLVMInitializeMCJITCompilerOptions(&options, options.sizeof);

        failed = LLVMCreateMCJITCompilerForModule(&engine, genModule, &options, options.sizeof, &error);
    } else {
        failed = LLVMCreateJITCompilerForModule(&engine, genModule, 2, &error);
    }

    if (error)
    {
        // The message is owned by LLVM; copy it before releasing it.
        scope (exit) LLVMDisposeMessage(error);
        throw new Exception(error.fromStringz().idup);
    }

    // Do not hand back a null engine silently if LLVM failed without a message.
    if (failed)
    {
        throw new Exception("LLVM failed to create the execution engine (no error message provided)");
    }
}
39 
/**
 * Builds an LLVM module containing a recursive `fib` function, runs a few
 * optimization passes over it, dumps the module, and then executes the
 * function through the JIT.
 *
 * Params:
 *     args = command line; `args[1]`, if present, is parsed as the integer
 *            argument for `fib` (defaults to 10 otherwise)
 *
 * Returns: 0 on success.
 *
 * Throws:
 *     Exception if the execution engine cannot be created or the compiled
 *     function cannot be resolved; `to!int` may throw if `args[1]` is not a
 *     valid integer.
 */
int main(string[] args)
{
    LLVMInitializeNativeTarget();
    LLVMInitializeNativeAsmPrinter();
    LLVMInitializeNativeAsmParser();

    auto genModule = LLVMModuleCreateWithName("fibonacci".toStringz());
    auto genFibParams = [ LLVMInt32Type() ];
    auto genFib = LLVMAddFunction(
        genModule,
        "fib",
        LLVMFunctionType(LLVMInt32Type(), genFibParams.ptr, 1, cast(LLVMBool) false));
    LLVMSetFunctionCallConv(genFib, LLVMCCallConv);

    auto genN = LLVMGetParam(genFib, 0);

    auto genEntryBlk = LLVMAppendBasicBlock(genFib, "entry".toStringz());
    auto genAnchor0Blk = LLVMAppendBasicBlock(genFib, "anchor0".toStringz());
    auto genAnchor1Blk = LLVMAppendBasicBlock(genFib, "anchor1".toStringz());
    auto genRecurseBlk = LLVMAppendBasicBlock(genFib, "recurse".toStringz());
    auto end = LLVMAppendBasicBlock(genFib, "end".toStringz());

    auto builder = LLVMCreateBuilder();
    LLVMExecutionEngineRef engine;
    LLVMPassManagerRef pass;

    // Release LLVM handles on every exit path (including exceptions), in the
    // order pass manager -> builder -> engine. The module is not disposed
    // separately: once the engine exists it is released together with it.
    scope (exit)
    {
        if (pass !is null)
        {
            LLVMDisposePassManager(pass);
        }
        LLVMDisposeBuilder(builder);
        if (engine !is null)
        {
            LLVMDisposeExecutionEngine(engine);
        }
    }

    /+ Entry block: dispatch on n, defaulting to the recursive case +/
    LLVMPositionBuilderAtEnd(builder, genEntryBlk);
    auto fibSwitch = LLVMBuildSwitch(
        builder,
        genN,
        genRecurseBlk,
        2);
    LLVMAddCase(fibSwitch, LLVMConstInt(LLVMInt32Type(), 0, cast(LLVMBool) false), genAnchor0Blk);
    LLVMAddCase(fibSwitch, LLVMConstInt(LLVMInt32Type(), 1, cast(LLVMBool) false), genAnchor1Blk);

    /+ Block for n = 0: fib(n) = 0 +/
    LLVMPositionBuilderAtEnd(builder, genAnchor0Blk);
    auto genAnchor0Result = LLVMConstInt(LLVMInt32Type(), 0, cast(LLVMBool) false);
    LLVMBuildBr(builder, end);

    /+ Block for n = 1: fib(n) = 1 +/
    LLVMPositionBuilderAtEnd(builder, genAnchor1Blk);
    auto genAnchor1Result = LLVMConstInt(LLVMInt32Type(), 1, cast(LLVMBool) false);
    LLVMBuildBr(builder, end);

    /+ Block for n > 1: fib(n) = fib(n - 1) + fib(n - 2) +/
    LLVMPositionBuilderAtEnd(builder, genRecurseBlk);

    auto genNMinus1 = LLVMBuildSub(
        builder,
        genN,
        LLVMConstInt(LLVMInt32Type(), 1, cast(LLVMBool) false),
        "n - 1".toStringz());
    auto genCallFibNMinus1 = LLVMBuildCall(builder, genFib, [genNMinus1].ptr, 1, "fib(n - 1)".toStringz());

    auto genNMinus2 = LLVMBuildSub(
        builder,
        genN,
        LLVMConstInt(LLVMInt32Type(), 2, cast(LLVMBool) false),
        "n - 2".toStringz());
    auto genCallFibNMinus2 = LLVMBuildCall(builder, genFib, [genNMinus2].ptr, 1, "fib(n - 2)".toStringz());

    auto genRecurseResult = LLVMBuildAdd(builder, genCallFibNMinus1, genCallFibNMinus2, "fib(n - 1) + fib(n - 2)".toStringz());
    LLVMBuildBr(builder, end);

    /+ Block for collecting the final result from the three predecessors +/
    LLVMPositionBuilderAtEnd(builder, end);
    auto genFinalResult = LLVMBuildPhi(builder, LLVMInt32Type(), "result".toStringz());
    auto phiValues = [ genAnchor0Result, genAnchor1Result, genRecurseResult ];
    auto phiBlocks = [ genAnchor0Blk, genAnchor1Blk, genRecurseBlk ];
    LLVMAddIncoming(genFinalResult, phiValues.ptr, phiBlocks.ptr, 3);
    LLVMBuildRet(builder, genFinalResult);

    // Verification is requested with LLVMAbortProcessAction; the message
    // buffer handed back by LLVM is released afterwards.
    char* verifyMessage;
    LLVMVerifyModule(genModule, LLVMAbortProcessAction, &verifyMessage);
    LLVMDisposeMessage(verifyMessage);

    initJIT(engine, genModule);

    pass = LLVMCreatePassManager();
    static if (LLVM_Version < asVersion(3,9,0))
    {
        LLVMAddTargetData(LLVMGetExecutionEngineTargetData(engine), pass);
    }
    LLVMAddConstantPropagationPass(pass);
    LLVMAddInstructionCombiningPass(pass);
    LLVMAddPromoteMemoryToRegisterPass(pass);
    LLVMAddGVNPass(pass);
    LLVMAddCFGSimplificationPass(pass);
    LLVMRunPassManager(pass, genModule);

    writefln("The following module has been generated for the fibonacci series:\n");
    LLVMDumpModule(genModule);

    writeln();

    int n = 10;
    if (args.length > 1)
    {
        n = to!int(args[1]);
    }
    else
    {
        writefln("; Argument for fib missing on command line, using default:  \"%d\"", n);
    }

    // Runs the generated `fib` through the execution engine.
    int fib(int n)
    {
        static if (useMCJIT) {
            alias Fib = extern (C) int function(int);
            auto address = LLVMGetFunctionAddress(engine, "fib".toStringz());
            // Calling through a zero address would crash the process.
            if (address == 0)
            {
                throw new Exception("The JIT could not resolve the address of 'fib'");
            }
            auto compiledFib = cast(Fib) address;
            return compiledFib(n);
        } else {
            auto argument = LLVMCreateGenericValueOfInt(LLVMInt32Type(), n, cast(LLVMBool) 0);
            scope (exit) LLVMDisposeGenericValue(argument);

            auto callArgs = [ argument ];
            auto result = LLVMRunFunction(engine, genFib, 1, callArgs.ptr);
            scope (exit) LLVMDisposeGenericValue(result);

            // LLVMGenericValueToInt yields a wider unsigned integer; narrow it explicitly.
            return cast(int) LLVMGenericValueToInt(result, 0);
        }
    }

    writefln("; Running (jit-compiled) fib(%d)...", n);
    writefln("; fib(%d) = %d", n, fib(n));

    return 0;
}