1 /// Base APIs to create LSP servers quickly. Reconfigures stdin and stdout upon
2 /// importing to avoid accidental usage of the RPC channel. Changes stdin to a
3 /// null file and stdout to stderr.
4 module served.serverbase;
5 
6 import served.utils.events : EventProcessorConfig;
7 
8 import io = std.stdio;
9 
/// Actual stdin/stdout as used for RPC communication.
__gshared io.File stdin, stdout;

/// Module constructor: saves the real standard streams into `stdin` and
/// `stdout` of this module, then points `std.stdio.stdin` at the null device
/// (where one is known for the OS) and `std.stdio.stdout` at stderr.
shared static this()
{
	// Keep handles to the real streams before they get replaced below.
	stdin = io.stdin;
	stdout = io.stdout;

	// Path of the null device on the target OS, empty if there is none known.
	version (Windows)
		enum nullDevice = "NUL";
	else version (Posix)
		enum nullDevice = "/dev/null";
	else
		enum nullDevice = "";

	static if (nullDevice.length != 0)
		io.stdin = io.File(nullDevice, "r");
	else
		io.stderr.writeln("warning: no /dev/null implementation on this OS");

	io.stdout = io.stderr;
}
24 
/// Compile-time settings passed to `LanguageServerRouter`.
struct LanguageServerConfig
{
	/// Default `pages` argument of `pushFiber`: the stack size handed to each
	/// new `Fiber` is `fiberPageSize * pages`.
	int defaultPages = 20;

	/// Factor the page count is multiplied with to get the stack size of a
	/// fiber created by `pushFiber`.
	int fiberPageSize = 4 * 1024;

	/// Configuration forwarded to the `EventProcessor` mixin.
	EventProcessorConfig eventConfig;

	/// Product name to use in error messages
	string productName = "serve-d";

	/// If greater than zero, call GC.collect every n seconds and GC.minimize
	/// every 5th call. Keeps track of cleaned up memory in trace logs.
	int gcCollectSeconds = 30;
}
39 
40 // dumps a performance/GC trace log to served_trace.log
41 //debug = PerfTraceLog;
42 
43 /// Utility to setup an RPC connection via stdin/stdout and route all requests
44 /// to methods defined in the given extension module.
45 ///
46 /// Params:
47 ///   ExtensionModule = a module defining the following members:
48 ///   - `members`: a compile time list of all members in all modules that should
49 ///     be introspected to be called automatically on matching RPC commands.
50 ///   - `InitializeResult initialize(InitializeParams)`: initialization method.
51 ///
52 ///   Optional:
53 ///   - `bool shutdownRequested`: a boolean that is set to true before the
54 ///     `shutdown` method handler or earlier which will terminate the RPC loop
55 ///     gracefully and wait for an `exit` notification to actually exit.
56 ///   - `@protocolMethod("shutdown") JSONValue shutdown()`: the method called
57 ///     when the client wants to shutdown the server. Can return anything,
58 ///     recommended return value is `JSONValue(null)`.
59 ///   - `parallelMain`: an optional method which is run alongside everything
60 ///     else in parallel using fibers. Should yield as much as possible when
61 ///     there is nothing to do.
mixin template LanguageServerRouter(alias ExtensionModule, LanguageServerConfig serverConfig = LanguageServerConfig.init)
{
	static assert(is(typeof(ExtensionModule.members)), "Missing members field in ExtensionModule " ~ ExtensionModule.stringof);
	static assert(is(typeof(ExtensionModule.initialize)), "Missing initialize function in ExtensionModule " ~ ExtensionModule.stringof);

	import core.sync.mutex;
	import core.thread;

	import served.lsp.filereader;
	import served.lsp.jsonrpc;
	import served.lsp.protocol;
	import served.lsp.textdocumentmanager;
	import served.utils.async;
	import served.utils.events;
	import served.utils.fibermanager;

	import painlessjson;

	import std.datetime.stopwatch;
	import std.experimental.logger;
	import std.functional;
	import std.json;

	import io = std.stdio;

	alias members = ExtensionModule.members;

	// Reuse the shutdown flag of the extension module if it declares one,
	// otherwise declare a flag as part of this mixin.
	static if (is(typeof(ExtensionModule.shutdownRequested)))
		alias shutdownRequested = ExtensionModule.shutdownRequested;
	else
		bool shutdownRequested;

	/// Set to true after `ExtensionModule.initialize` has returned for an
	/// `initialize` request. Other requests and notifications are rejected
	/// while this is false.
	__gshared bool serverInitializeCalled = false;

	mixin EventProcessor!(ExtensionModule, serverConfig.eventConfig) eventProcessor;

	/// Handles a single RPC request and builds the response for it.
	///
	/// The first `initialize` request is forwarded to
	/// `ExtensionModule.initialize`. Any other request received before that
	/// is answered with `ErrorCode.serverNotInitialized`. Afterwards requests
	/// are dispatched through `eventProcessor` using the `protocolMethod` UDA.
	///
	/// If more than one handler is registered for a method and a handler
	/// returns an array, that handler is run in its own fiber. Once all such
	/// fibers have finished, the arrays are concatenated into the result, or,
	/// if the request contains a `partialResultToken`, every array is sent as
	/// `$/progress` notification and the result is an empty array.
	///
	/// Params:
	///   msg = the request to process.
	/// Returns: the response for `msg`, with `id` copied from it.
	/// Throws: anything a handler throws, except for `MethodException`, whose
	/// `error` is put into the returned response instead.
	ResponseMessage processRequest(RequestMessage msg)
	{
		debug(PerfTraceLog) mixin(traceStatistics(__FUNCTION__));

		ResponseMessage res;
		res.id = msg.id;
		if (msg.method == "initialize" && !serverInitializeCalled)
		{
			trace("Initializing");
			res.result = ExtensionModule.initialize(msg.params.fromJSON!InitializeParams).toJSON;
			trace("Initialized");
			serverInitializeCalled = true;
			return res;
		}
		if (!serverInitializeCalled)
		{
			trace("Tried to call command without initializing");
			res.error = ResponseError(ErrorCode.serverNotInitialized);
			return res;
		}

		// First pass: the callback only counts the matching handlers, it does
		// not invoke `callSymbol`.
		size_t numHandlers;
		eventProcessor.emitProtocol!(protocolMethod, (name, callSymbol, uda) {
			numHandlers++;
		}, false)(msg.method, msg.params);

		// trace("Function ", msg.method, " has ", numHandlers, " handlers");
		if (numHandlers == 0)
		{
			io.stderr.writeln(msg);
			res.error = ResponseError(ErrorCode.methodNotFound, "Request method " ~ msg.method ~ " not found");
			return res;
		}

		JSONValue workDoneToken, partialResultToken;
		if (msg.params.type == JSONType.object)
		{
			if (auto doneToken = "workDoneToken" in msg.params)
				workDoneToken = *doneToken;
			if (auto partialToken = "partialResultToken" in msg.params)
				partialResultToken = *partialToken;
		}

		// Number of partial handler fibers which have been scheduled but have
		// not finished yet.
		int working = 0;
		JSONValue[] partialResults;
		// First failure thrown inside a partial handler fiber. These fibers run
		// outside of the try block below, so the failure is kept here and
		// reported once this function has finished waiting for all of them.
		Throwable partialFailure;
		void handlePartialWork(Symbol, Arguments)(Symbol fn, Arguments args)
		{
			import painlessjson : toJSON;

			working++;
			// Take the ID now: `working` changes again before the fiber runs.
			immutable thisId = working;
			pushFiber({
				scope (exit)
					working--;
				try
				{
					trace("Partial ", thisId, " / ", numHandlers, "...");
					auto result = fn(args.expand);
					trace("Partial ", thisId, " = ", result);
					JSONValue json = toJSON(result);
					if (partialResultToken == JSONValue.init)
						partialResults ~= json;
					else
						rpc.notifyMethod("$/progress", JSONValue([
							"token": partialResultToken,
							"value": json
						]));
					processRequestObservers(msg, json);
				}
				catch (Throwable failure)
				{
					// Keep the failure inside this fiber and hand it to the
					// request fiber instead of letting it escape.
					error("Partial handler ", thisId, " for ", msg.method, " failed: ", failure);
					if (partialFailure is null)
						partialFailure = failure;
				}
			});
		}

		bool done, found;
		try
		{
			found = eventProcessor.emitProtocolRaw!(protocolMethod, (name, symbol, arguments, uda) {
				// a previous handler already produced the final result
				if (done)
					return;

				trace("Calling ", name);
				alias RequestResultT = typeof(symbol(arguments.expand));

				static if (is(RequestResultT : JSONValue))
				{
					auto requestResult = symbol(arguments.expand);
					res.result = requestResult;
					done = true;
					processRequestObservers(msg, requestResult);
				}
				else
				{
					static if (is(RequestResultT : T[], T))
					{
						// array results of multiple handlers can be merged, so
						// run them as partial work in separate fibers
						if (numHandlers > 1)
						{
							handlePartialWork(symbol, arguments);
							return;
						}
					}
					else assert(numHandlers == 1, "Registered more than one "
						~ msg.method ~ " handler on non-partial method returning "
						~ RequestResultT.stringof);
					auto requestResult = symbol(arguments.expand);
					res.result = toJSON(requestResult);
					done = true;
					processRequestObservers(msg, requestResult);
				}
			}, false)(msg.method, msg.params);
		}
		catch (MethodException e)
		{
			res.result.nullify();
			res.error = e.error;
			return res;
		}

		assert(found);

		if (!done)
		{
			// no handler produced a final result, wait for the partial ones
			while (working > 0)
				Fiber.yield();

			if (partialFailure !is null)
			{
				// treat failures the same way as for directly called handlers:
				// MethodException becomes the response error, everything else
				// propagates to the caller.
				if (auto methodFailure = cast(MethodException) partialFailure)
				{
					res.result.nullify();
					res.error = methodFailure.error;
					return res;
				}
				throw partialFailure;
			}

			if (partialResultToken == JSONValue.init)
			{
				JSONValue[] combined;
				foreach (partial; partialResults)
				{
					assert(partial.type == JSONType.array);
					combined ~= partial.array;
				}
				res.result = JSONValue(combined);
			}
			else
			{
				// results have already been sent using $/progress
				JSONValue[] emptyArr;
				res.result = JSONValue(emptyArr);
			}
		}

		return res;
	}

	/// Calls the handlers registered with the `postProtocolMethod` UDA for the
	/// given request, passing the result of the request along with its
	/// parameters. A `MethodException` thrown by a handler is logged and does
	/// not propagate.
	///
	/// Params:
	///   msg = the request that has been processed.
	///   result = the result computed for the request.
	private void processRequestObservers(T)(RequestMessage msg, T result)
	{
		eventProcessor.emitProtocol!(postProtocolMethod, (name, callSymbol, uda) {
			try
			{
				callSymbol();
			}
			catch (MethodException e)
			{
				error("Error in post-protocolMethod: ", e);
			}
		}, false)(msg.method, msg.params, result);
	}

	/// Handles a single RPC notification.
	///
	/// An `exit` notification, or any notification after shutdown has been
	/// requested, stops the RPC processor. Notifications received before
	/// initialization are dropped. Everything else is first passed to
	/// `documents` and then dispatched to the handlers registered with the
	/// `protocolNotification` UDA. A `MethodException` thrown by a handler is
	/// logged and does not propagate.
	///
	/// Params:
	///   msg = the notification to process.
	void processNotify(RequestMessage msg)
	{
		debug(PerfTraceLog) mixin(traceStatistics(__FUNCTION__));

		// even though the spec says the process should not stop before exit, vscode-languageserver doesn't call exit after shutdown so we just shutdown on the next request.
		// this also makes sure we don't operate on invalid states and segfault.
		if (msg.method == "exit" || shutdownRequested)
		{
			rpc.stop();
			if (!shutdownRequested)
			{
				// exit without prior shutdown: run the shutdown handler now
				shutdownRequested = true;
				static if (is(typeof(ExtensionModule.shutdown)))
					ExtensionModule.shutdown();
			}
			return;
		}

		// Without a shutdown handler in the extension module, a `shutdown`
		// notification only sets the flag.
		static if (!is(typeof(ExtensionModule.shutdown)))
		{
			if (msg.method == "shutdown" && !shutdownRequested)
			{
				shutdownRequested = true;
				return;
			}
		}

		if (!serverInitializeCalled)
		{
			trace("Tried to call notification without initializing");
			return;
		}
		documents.process(msg);

		eventProcessor.emitProtocol!(protocolNotification, (name, callSymbol, uda) {
			try
			{
				callSymbol();
			}
			catch (MethodException e)
			{
				error("Failed notify: ", e);
			}
		}, false)(msg.method, msg.params);
	}

	/// Creates the task for a request: it runs `processRequest` and sends the
	/// response using `rpc`. Anything thrown is turned into a response with
	/// `ErrorCode.internalError`; for throwables which are not an `Exception`
	/// an error message is shown to the client as well.
	///
	/// Params:
	///   msg = the request to process.
	/// Returns: a delegate performing the work, as passed to `pushFiber` in `run`.
	void delegate() gotRequest(RequestMessage msg)
	{
		return {
			ResponseMessage res;
			try
			{
				res = processRequest(msg);
			}
			catch (Exception e)
			{
				res.id = msg.id;
				res.error = ResponseError(e);
				res.error.code = ErrorCode.internalError;
			}
			catch (Throwable e)
			{
				res.id = msg.id;
				res.error = ResponseError(e);
				res.error.code = ErrorCode.internalError;
				rpc.window.showMessage(MessageType.error,
						"A fatal internal error occured in "
						~ serverConfig.productName
						~ " handling this request but it will attempt to keep running: "
						~ e.msg);
			}
			rpc.send(res);
		};
	}

	/// Creates the task for a notification: it runs `processNotify` and logs
	/// anything thrown. For throwables which are not an `Exception` an error
	/// message is shown to the client as well.
	///
	/// Params:
	///   msg = the notification to process.
	/// Returns: a delegate performing the work, as passed to `pushFiber` in `run`.
	void delegate() gotNotify(RequestMessage msg)
	{
		return {
			try
			{
				processNotify(msg);
			}
			catch (Exception e)
			{
				error("Failed processing notification: ", e);
			}
			catch (Throwable e)
			{
				error("Attempting to recover from fatal issue: ", e);
				rpc.window.showMessage(MessageType.error,
						"A fatal internal error has occured in "
						~ serverConfig.productName
						~ ", but it will attempt to keep running: "
						~ e.msg);
			}
		};
	}

	/// Fibers called in every iteration of the main loop in `run`.
	__gshared FiberManager fibers;
	/// Guards `fibers`; created at the start of `run`.
	__gshared Mutex fibersMutex;

	/// Wraps `callback` in a new `Fiber` and adds it to `fibers`.
	///
	/// Params:
	///   callback = the function or delegate the fiber runs.
	///   pages = multiplied with `serverConfig.fiberPageSize` to get the stack
	///     size of the fiber.
	///   file = forwarded to `fibers.put`, defaults to the file of the caller.
	///   line = forwarded to `fibers.put`, defaults to the line of the caller.
	void pushFiber(T)(T callback, int pages = serverConfig.defaultPages, string file = __FILE__, int line = __LINE__)
	{
		synchronized (fibersMutex)
			fibers.put(new Fiber(callback, serverConfig.fiberPageSize * pages), file, line);
	}

	/// RPC connection to the client, created in `run`.
	RPCProcessor rpc;
	/// Receives every notification that is processed after initialization.
	TextDocumentManager documents;

	/// Runs the language server and returns true once it exited gracefully or
	/// false if it didn't exit gracefully.
	///
	/// Starts the stdin reader and the RPC processor, then loops until the RPC
	/// fiber has terminated: every received message is scheduled as a fiber,
	/// all fibers are called, and, if `serverConfig.gcCollectSeconds` is
	/// greater than zero, the GC is run in that interval.
	///
	/// Returns: the value of `shutdownRequested` after the loop has ended.
	bool run()
	{
		auto input = newStdinReader();
		input.start();
		scope (exit)
			input.stop();
		// wait a few milliseconds for the reader to report that it is running
		for (int timeout = 10; timeout >= 0 && !input.isRunning; timeout--)
			Thread.sleep(1.msecs);
		trace("Started reading from stdin");

		fibersMutex = new Mutex();

		rpc = new RPCProcessor(input, stdout);
		rpc.call();
		trace("RPC started");

		static if (serverConfig.gcCollectSeconds > 0)
		{
			// gcCollects counts up to the next GC.minimize, totalGcCollects is
			// only used for the PerfTraceLog statistics.
			int gcCollects, totalGcCollects;
			StopWatch gcInterval;
			gcInterval.start();

			void collectGC()
			{
				import core.memory : GC;

				auto before = GC.stats();
				StopWatch gcSpeed;
				gcSpeed.start();

				GC.collect();

				gcCollects++;
				totalGcCollects++;
				// minimize on every 5th collection, as documented on
				// LanguageServerConfig.gcCollectSeconds
				if (gcCollects >= 5)
				{
					GC.minimize();
					gcCollects = 0;
				}

				gcSpeed.stop();
				auto after = GC.stats();

				if (before != after)
					tracef("GC run in %s. Freed %s bytes (%s bytes allocated, %s bytes available)", gcSpeed.peek,
							cast(long) before.usedSize - cast(long) after.usedSize, after.usedSize, after.freeSize);
				else
					trace("GC run in ", gcSpeed.peek);

				gcInterval.reset();
			}
		}

		scope (exit)
		{
			debug(PerfTraceLog)
			{
				import core.memory : GC;
				import std.stdio : File;

				auto traceLog = File("served_trace.log", "w");

				auto totalAllocated = GC.stats().allocatedInCurrentThread;
				auto profileStats = GC.profileStats();

				// the counter only exists if manual GC collection is enabled
				static if (serverConfig.gcCollectSeconds > 0)
					traceLog.writeln("manually collected GC ", totalGcCollects, " times");
				traceLog.writeln("total ", profileStats.numCollections, " collections");
				traceLog.writeln("total collection time: ", profileStats.totalCollectionTime);
				traceLog.writeln("total pause time: ", profileStats.totalPauseTime);
				traceLog.writeln("max collection time: ", profileStats.maxCollectionTime);
				traceLog.writeln("max pause time: ", profileStats.maxPauseTime);
				traceLog.writeln("total allocated in main thread: ", totalAllocated);
				traceLog.writeln();

				dumpTraceInfos(traceLog);
			}
		}

		fibers ~= rpc;

		spawnFiberImpl = (&pushFiber!(void delegate())).toDelegate;

		static if (is(typeof(ExtensionModule.parallelMain)))
			pushFiber(&ExtensionModule.parallelMain);

		while (rpc.state != Fiber.State.TERM)
		{
			while (rpc.hasData)
			{
				auto msg = rpc.poll;
				// Log on client side instead! (vscode setting: "serve-d.trace.server": "verbose")
				//trace("Message: ", msg);
				// messages with an ID are requests, all others are notifications
				if (msg.id.hasData)
					pushFiber(gotRequest(msg));
				else
					pushFiber(gotNotify(msg));
			}
			Thread.sleep(10.msecs);
			synchronized (fibersMutex)
				fibers.call();

			static if (serverConfig.gcCollectSeconds > 0)
			{
				if (gcInterval.peek > serverConfig.gcCollectSeconds.seconds)
				{
					collectGC();
				}
			}
		}

		return shutdownRequested;
	}
}