1 module vayne.vm;
2 
3 
4 import std.conv;
5 import std.datetime;
6 import std.range;
7 import std.stdio;
8 import std.string;
9 import std.traits;
10 
11 
12 import vayne.op;
13 import vayne.value;
14 import vayne.source.source;
15 
16 public import vayne.value : Value;
17 
18 
/// Compile-time flag bits for `VM`; combine with `|`.
/// Member order is significant: the first member (`PrintOutput`) is `VMOptions.init`.
enum VMOptions : uint {
	PrintOutput		= 0x1,							/// echo text emitted by `Output` instructions to stdout
	PrintOpCodes	= 0x2,							/// print every opcode before it executes
	DebugMode		= PrintOpCodes | PrintOutput,	/// all diagnostics enabled
	Default			= 0x0,							/// no diagnostics
}
25 
26 
/// Exception raised by the VM for a failed instruction when no error handler is set.
/// The inherited `file`/`line` fields hold the template source name and line
/// rather than a D source position.
class VMException : Exception {
	/// Params:
	///   msg    = error text
	///   source = template source name, stored as `file`
	///   line   = line within `source`
	this(string msg, string source, size_t line) {
		super(msg, source, line, null);
	}
}
32 
33 
/// True when `T` can receive template output: it is an output range of `char`,
/// or it has a `write` method callable with a string.
private template isWriterObject(T) {
	static if (isOutputRange!(T, char))
		enum isWriterObject = true;
	else
		enum isWriterObject = __traits(compiles, T.init.write(""));
}
37 
38 
/**
 * Register-based interpreter for compiled template bytecode.
 *
 * Params:
 *   options          = bit mask of `VMOptions`; `PrintOutput` echoes emitted text to
 *                      stdout and `PrintOpCodes` prints each opcode before it runs.
 *   registerCountMax = when non-zero, the register file is a fixed array of this many
 *                      slots and `load` only checks the requested count against it;
 *                      when zero, the register file is a dynamic array resized by `load`.
 */
struct VM(uint options = VMOptions.Default, uint registerCountMax = 0) {
	/// Runtime error details passed to an installed `ErrorHandler`.
	static struct Error {
		string msg;		/// message of the error raised by the failing instruction
		string source;	/// source name recorded for the failing instruction
		size_t line;	/// line within `source`
	}
	alias ErrorHandler = void delegate(Error);
	alias Globals = Value[string];

	/**
	 * Loads a program whose constant table is supplied directly.
	 *
	 * `constants` is stored by reference, not copied.
	 *
	 * Throws: `Exception` when `registers` exceeds `registerCountMax`
	 * (fixed-register instantiations only).
	 */
	void load(size_t registers, Value[] constants, const(Instr)[] instrs, const(SourceLoc)[] locs, const(string)[] sources) {
		consts_ = constants;
		loadProgram(registers, instrs, locs, sources);
	}

	/**
	 * Loads a program whose constant table is filled afterwards via `bindConst`.
	 *
	 * The constant table is resized to `constants` slots.
	 *
	 * Throws: `Exception` when `registers` exceeds `registerCountMax`
	 * (fixed-register instantiations only).
	 */
	void load(size_t registers, size_t constants, const(Instr)[] instrs, const(SourceLoc)[] locs, const(string)[] sources) {
		consts_.length = constants;
		loadProgram(registers, instrs, locs, sources);
	}

	/// Stores `value`, wrapped in a `Value`, in constant slot `index`.
	void bindConst(T)(size_t index, ref T value) if (is(T == struct)) {
		consts_[index] = Value(value);
	}

	/// ditto
	void bindConst(T)(size_t index, in T value)  if (!is(T == struct)) {
		consts_[index] = Value(value);
	}

	/// Replaces the global table with `globals`. The associative array is shared, not copied.
	void setGlobals(Globals globals) {
		globals_ = globals;
	}

	/// Copies every entry of `globals` into the current global table, overwriting same-named entries.
	void bindGlobals(Globals globals) {
		foreach (k, v; globals)
			globals_[k] = v;
	}

	/// Sets global `name` to `value`, wrapped in a `Value`.
	void bindGlobal(T)(string name, ref T value) if (is(T == struct)) {
		globals_[name] = Value(value);
	}

	/// ditto
	void bindGlobal(T)(string name, in T value)  if (!is(T == struct)) {
		globals_[name] = Value(value);
	}

	/// Installs a callback that receives runtime errors instead of having `VMException` thrown.
	@property void errorHandler(ErrorHandler handler) {
		errorHandler_ = handler;
	}

	/// The currently installed error callback (`null` when none is set).
	@property ErrorHandler errorHandler() const {
		return errorHandler_;
	}

	/// Replaces the global table with `globals`, then runs the loaded program.
	void execute(T)(ref T output, Globals globals) if (isWriterObject!T) {
		globals_ = globals;

		execute(output);
	}

	/**
	 * Runs the loaded program from its first instruction. Text emitted by `Output`
	 * instructions goes to `output`: via `put` for output ranges of `char`,
	 * otherwise via `write`.
	 *
	 * Execution ends at `Halt`, when control moves past the last instruction, or
	 * at the first error. On error, the installed error handler is called and
	 * execution stops. Without a handler, a `VMException` is thrown carrying the
	 * error message and the source/line of the failing instruction.
	 */
	void execute(T)(ref T output) if (isWriterObject!T) {
		// Drop state that an earlier, aborted run may have left behind. An unpopped
		// `with` scope would shadow lookups in this run, and a pending dispatch
		// flag would trip the assertion in `Dispatch`.
		scopes_.length = 0;
		dispatchArg_ = false;

		// An empty program has nothing to run (and instrs_[0] would be out of range).
		if (instrs_.length == 0)
			return;

		Instr instr = instrs_[0];
		size_t ip;

		// Operand `Arg` of the current instruction: a constant slot or a register,
		// selected by the instruction's per-argument constant flag.
		ref auto getArgV(size_t Arg)() if (Arg <= 3) {
			if (instr.argConst!Arg) {
				return consts_.ptr[instr.arg!Arg];
			} else {
				return regs_.ptr[instr.arg!Arg];
			}
		}

		ref auto argBinaryOp(size_t A, string op, size_t B)() {
			return getArgV!A.binaryOp!op(getArgV!B);
		}

		ref auto argCompareOp(size_t A, string op, size_t B)() {
			return Value(getArgV!A.compareOp!op(getArgV!B));
		}

		while (true) {
			const op = instr.op;

			static if (options & VMOptions.PrintOpCodes) {
				writeln(op);
			}

			try {
				Lswitch: final switch (op) with (OpCode) {
				case Output:
					auto chunk = getArgV!0.get!string;
					static if (isOutputRange!(T, char)) {
						output.put(chunk);
					} else {
						output.write(chunk);
					}

					static if (options & VMOptions.PrintOutput) {
						write(chunk);
					}
					break;
				case Move:
					regs_.ptr[instr.arg!0] = getArgV!1;
					break;
				case Test:
					regs_.ptr[instr.arg!0] = Value(getArgV!1.get!bool);
					break;
				case Increment:
					regs_.ptr[instr.arg!0].unaryOp!"++";
					break;
				case Jump:
					ip = instr.arg!0;
					// A jump past the last instruction ends the program, like falling off the end,
					// instead of reading beyond instrs_ through the raw pointer.
					if (ip >= instrs_.length)
						return;
					instr = instrs_.ptr[ip];
					continue;
				case JumpIfZero:
					if (getArgV!1.get!long == 0)
						goto case Jump;
					break;
				case JumpIfNotZero:
					if (getArgV!1.get!long != 0)
						goto case Jump;
					break;
				case Element:
					regs_.ptr[instr.arg!0] = getArgV!1[getArgV!2];
					break;
				case LookUp:
					// Resolution order: innermost `with` scope outwards, then globals;
					// an unresolved name yields Value.init.
					auto name = getArgV!1;
					auto pout = &regs_.ptr[instr.arg!0];

					foreach_reverse (ref s; scopes_) {
						if (s.has(name, pout))
							break Lswitch;
					}

					if (name.type == Value.Type.String) {
						if (auto pvalue = name.get!string in globals_) {
							*pout = *pvalue;
							break Lswitch;
						}
					}

					*pout = Value.init;
					break;
				case Call:
					auto func = getArgV!1;
					func.call(regs_.ptr[instr.arg!0], regs_.ptr[instr.arg!2..instr.arg!2 + instr.arg!3]);
					break;
				case DispatchCall:
					auto func = getArgV!1;
					if (dispatchArg_) {
						// The callee came from the globals table: pass every argument.
						func.call(regs_.ptr[instr.arg!0], regs_.ptr[instr.arg!2..instr.arg!2 + instr.arg!3]);
						dispatchArg_ = false;
					} else {
						// The callee was found on the object itself: skip the first argument slot.
						func.call(regs_.ptr[instr.arg!0], regs_.ptr[1 + instr.arg!2..instr.arg!2 + instr.arg!3]);
					}
					break;
				case Dispatch:
					assert(!dispatchArg_);

					auto name = getArgV!2;
					auto pout = &regs_.ptr[instr.arg!0];

					auto obj = getArgV!1;
					switch (obj.type) with (Value.Type) {
					case Object:
					case AssocArray:
						if (getArgV!1.has(name, pout))
							break Lswitch;
						break;
					default:
						break;
					}

					// Fall back to a global function; DispatchCall then passes the object as an argument.
					if (name.type == Value.Type.String) {
						if (auto pvalue = name.get!string in globals_) {
							if (pvalue.type == Value.Type.Function) {
								dispatchArg_ = true;
								*pout = *pvalue;
								break Lswitch;
							}
						}
					}

					throw new Exception(format("dispatch failed for identifier '%s'", name.get!string));
				case Decrement:
					regs_.ptr[instr.arg!0].unaryOp!"--";
					break;
				case Concat:
					regs_.ptr[instr.arg!0] = getArgV!1.concatOp(getArgV!2);
					break;
				case PushScope:
					auto scope_ = getArgV!0;
					switch (scope_.type) with (Value.Type) {
					case Object:
					case AssocArray:
						scopes_ ~= getArgV!0;
						break;
					default:
						throw new Exception(format("with statement expressions must be of type %s or %s, not %s", Object, AssocArray, scope_.type));
					}
					break;
				case PopScope:
					scopes_.length = scopes_.length - instr.arg!0;
					break;
				case Equal:
					regs_.ptr[instr.arg!0] = argCompareOp!(1, "==", 2);
					break;
				case NotEqual:
					regs_.ptr[instr.arg!0] = argCompareOp!(1, "!=", 2);
					break;
				case Less:
					regs_.ptr[instr.arg!0] = argCompareOp!(1, "<", 2);
					break;
				case LessOrEqual:
					regs_.ptr[instr.arg!0] = argCompareOp!(1, "<=", 2);
					break;
				case Greater:
					regs_.ptr[instr.arg!0] = argCompareOp!(1, ">", 2);
					break;
				case GreaterOrEqual:
					regs_.ptr[instr.arg!0] = argCompareOp!(1, ">=", 2);
					break;
				case Not:
					regs_.ptr[instr.arg!0] = Value(!regs_.ptr[instr.arg!0].get!bool);
					break;
				case And:
					regs_.ptr[instr.arg!0] = argCompareOp!(1, "&&", 2);
					break;
				case Or:
					regs_.ptr[instr.arg!0] = argCompareOp!(1, "||", 2);
					break;
				case Minus:
					regs_.ptr[instr.arg!0].unaryOp!"-";
					break;
				case Add:
					regs_.ptr[instr.arg!0] = argBinaryOp!(1, "+", 2);
					break;
				case Subtract:
					regs_.ptr[instr.arg!0] = argBinaryOp!(1, "-", 2);
					break;
				case Multiply:
					regs_.ptr[instr.arg!0] = argBinaryOp!(1, "*", 2);
					break;
				case Divide:
					regs_.ptr[instr.arg!0] = argBinaryOp!(1, "/", 2);
					break;
				case Remainder:
					regs_.ptr[instr.arg!0] = argBinaryOp!(1, "%", 2);
					break;
				case Power:
					regs_.ptr[instr.arg!0] = argBinaryOp!(1, "^^", 2);
					break;
				case Length:
					regs_.ptr[instr.arg!0] = Value(getArgV!1.length);
					break;
				case Keys:
					regs_.ptr[instr.arg!0] = getArgV!1.keys();
					break;
				case TestKey:
					regs_.ptr[instr.arg!0] = Value(getArgV!1.has(getArgV!2));
					break;
				case Slice:
					regs_.ptr[instr.arg!0] = getArgV!1[getArgV!2..getArgV!3];
					break;
				case Nop:
					break;
				case Halt:
					return;
				case Throw:
					throw new Exception(getArgV!0.toString);
				}

				if (++ip >= instrs_.length)
					break;
				instr = instrs_.ptr[ip];
			} catch (Throwable e) {
				// Attribute the error to the source position of the failing instruction.
				// If no location was loaded for it, fall back to an empty source and
				// line 0 instead of raising a RangeError that would hide the real error.
				string source;
				size_t line;
				if (ip < locs_.length) {
					auto loc = locs_[ip];
					if (cast(size_t) loc.id < sources_.length)
						source = sources_[loc.id];
					line = loc.line;
				}
				auto error = e.msg;

				if (errorHandler_) {
					errorHandler_(Error(error, source, line));
					break;
				} else {
					auto rethrow = new VMException(error, source, line);
					rethrow.info = e.info;

					throw rethrow;
				}
			}
		}
	}

private:
	/// Stores the program tables and sizes (or, for fixed registers, checks) the
	/// register file. Shared by both `load` overloads.
	void loadProgram(size_t registers, const(Instr)[] instrs, const(SourceLoc)[] locs, const(string)[] sources) {
		instrs_ = instrs;
		locs_ = locs;
		sources_ = sources;

		static if (registerCountMax > 0) {
			if (registers > registerCountMax)
				throw new Exception(format("not enough pre-allocated registers %d > %d", registers, registerCountMax));
		} else {
			regs_.length = registers;
		}
	}

	Value[] consts_;

	static if (registerCountMax > 0) {
		Value[registerCountMax] regs_;
	} else {
		Value[] regs_;
	}

	Globals globals_;
	Value[] scopes_;		// active `with` scopes, innermost last
	bool dispatchArg_;		// set by Dispatch when the callee was resolved from globals

	const(Instr)[] instrs_;
	const(SourceLoc)[] locs_;	// indexed by instruction pointer
	const(string)[] sources_;	// indexed by SourceLoc.id

	ErrorHandler errorHandler_;
}