1 /// Base APIs to create LSP servers quickly. Reconfigures stdin and stdout upon
2 /// importing to avoid accidental usage of the RPC channel. Changes stdin to a
3 /// null file and stdout to stderr.
4 module served.serverbase;
5 
6 import served.utils.events : EventProcessorConfig;
7 
8 import io = std.stdio;
9 
/// The process' real stdin/stdout handles, reserved for RPC communication.
__gshared io.File stdin, stdout;

/// Saves the original `std.stdio.stdin` / `std.stdio.stdout` into the
/// module-level `stdin` / `stdout` handles, then points `std.stdio.stdin` at
/// the platform's null device (when one is known) and `std.stdio.stdout` at
/// stderr.
shared static this()
{
	// Path of the null device for this platform, empty if unknown.
	version (Windows)
		enum nullDevice = "NUL";
	else version (Posix)
		enum nullDevice = "/dev/null";
	else
		enum nullDevice = "";

	// Keep the real handles around before replacing the std.stdio globals.
	stdin = io.stdin;
	stdout = io.stdout;

	static if (nullDevice.length)
		io.stdin = io.File(nullDevice, "r");
	else
		io.stderr.writeln("warning: no /dev/null implementation on this OS");

	io.stdout = io.stderr;
}
24 
/// Compile-time configuration passed to `LanguageServerRouter`.
struct LanguageServerConfig
{
	/// Default number of pages used for a fiber created through `pushFiber`
	/// when no explicit page count is given.
	int defaultPages = 20;
	/// Size of one page; `pushFiber` multiplies this by the page count and
	/// passes the product as the size argument of the new `Fiber`.
	int fiberPageSize = 4096;

	/// Configuration forwarded to the `EventProcessor` mixin.
	EventProcessorConfig eventConfig;

	/// Product name to use in error messages
	string productName = "unnamed lsp";

	/// If set to non-zero, call GC.collect every n seconds and GC.minimize
	/// every gcMinimizeTimes-th call. Keeps track of cleaned up memory in
	/// trace logs.
	int gcCollectSeconds = 30;
	/// ditto
	int gcMinimizeTimes = 5;
}
42 
43 // dumps a performance/GC trace log to served_trace.log
44 //debug = PerfTraceLog;
45 
/// Utility to set up an RPC connection via stdin/stdout and route all requests
47 /// to methods defined in the given extension module.
48 ///
49 /// Params:
50 ///   ExtensionModule = a module defining the following members:
51 ///   - `members`: a compile time list of all members in all modules that should
52 ///     be introspected to be called automatically on matching RPC commands.
53 ///   - `InitializeResult initialize(InitializeParams)`: initialization method.
54 ///
55 ///   Optional:
56 ///   - `bool shutdownRequested`: a boolean that is set to true before the
57 ///     `shutdown` method handler or earlier which will terminate the RPC loop
58 ///     gracefully and wait for an `exit` notification to actually exit.
59 ///   - `@protocolMethod("shutdown") JsonValue shutdown()`: the method called
60 ///     when the client wants to shutdown the server. Can return anything,
61 ///     recommended return value is `JsonValue(null)`.
62 ///   - `parallelMain`: an optional method which is run alongside everything
63 ///     else in parallel using fibers. Should yield as much as possible when
64 ///     there is nothing to do.
65 mixin template LanguageServerRouter(alias ExtensionModule, LanguageServerConfig serverConfig = LanguageServerConfig.init)
66 {
67 	static assert(is(typeof(ExtensionModule.members)), "Missing members field in ExtensionModule " ~ ExtensionModule.stringof);
68 	static assert(is(typeof(ExtensionModule.initialize)), "Missing initialize function in ExtensionModule " ~ ExtensionModule.stringof);
69 
70 	import core.sync.mutex;
71 	import core.thread;
72 
73 	import served.lsp.filereader;
74 	import served.lsp.jsonops;
75 	import served.lsp.jsonrpc;
76 	import served.lsp.protocol;
77 	import served.lsp.textdocumentmanager;
78 	import served.utils.async;
79 	import served.utils.events;
80 	import served.utils.fibermanager;
81 
82 	import std.datetime.stopwatch;
83 	import std.experimental.logger;
84 	import std.functional;
85 	import std.json;
86 
87 	import io = std.stdio;
88 
	/// Compile-time member list supplied by the extension module.
	alias members = ExtensionModule.members;

	// Use the extension module's `shutdownRequested` flag if it declares one,
	// otherwise declare our own.
	static if (is(typeof(ExtensionModule.shutdownRequested)))
		alias shutdownRequested = ExtensionModule.shutdownRequested;
	else
		bool shutdownRequested;

	/// Set to true by `processRequest` once the `initialize` request has been
	/// handled; requests and notifications are rejected before that.
	__gshared bool serverInitializeCalled = false;

	// Event dispatching over the extension module, accessible as
	// `eventProcessor`.
	mixin EventProcessor!(ExtensionModule, serverConfig.eventConfig) eventProcessor;
99 
100 	/// Calls a method associated with the given request type in the 
101 	ResponseMessageRaw processRequest(RequestMessageRaw msg)
102 	{
103 		debug(PerfTraceLog) mixin(traceStatistics(__FUNCTION__));
104 		scope (failure)
105 			error("failure in message ", msg);
106 
107 		ResponseMessageRaw res;
108 		if (msg.id.isNone)
109 			throw new Exception("Called processRequest on a notification");
110 		res.id = msg.id.deref;
111 		if (msg.method == "initialize" && !serverInitializeCalled)
112 		{
113 			trace("Initializing");
114 			auto initParams = msg.paramsJson.deserializeJson!InitializeParams;
115 			auto initResult = ExtensionModule.initialize(initParams);
116 			eventProcessor.emitExtensionEvent!initializeHook(initParams, initResult);
117 			eventProcessor.emitExtensionEvent!onInitialize(initParams);
118 			res.resultJson = initResult.serializeJson;
119 			trace("Initialized");
120 			serverInitializeCalled = true;
121 			pushFiber({
122 				Fiber.yield();
123 				processRequestObservers(msg, initResult);
124 			});
125 			return res;
126 		}
127 
128 		static if (!is(typeof(ExtensionModule.shutdown)))
129 		{
130 			if (msg.method == "shutdown" && !shutdownRequested)
131 			{
132 				shutdownRequested = true;
133 				res.resultJson = `null`;
134 				return res;
135 			}
136 		}
137 
138 		if (!serverInitializeCalled && msg.method != "shutdown")
139 		{
140 			trace("Tried to call command without initializing");
141 			res.error = ResponseError(ErrorCode.serverNotInitialized);
142 			return res;
143 		}
144 
145 		size_t numHandlers;
146 		eventProcessor.emitProtocol!(protocolMethod, (name, callSymbol, uda) {
147 			numHandlers++;
148 		}, false)(msg.method, msg.paramsJson);
149 
150 		// trace("Function ", msg.method, " has ", numHandlers, " handlers");
151 		if (numHandlers == 0)
152 		{
153 			io.stderr.writeln(msg);
154 			res.error = ResponseError(ErrorCode.methodNotFound, "Request method " ~ msg.method ~ " not found");
155 			return res;
156 		}
157 
158 		string workDoneToken, partialResultToken;
159 		if (msg.paramsJson.looksLikeJsonObject)
160 		{
161 			auto v = msg.paramsJson.parseKeySlices!("workDoneToken", "partialResultToken");
162 			workDoneToken = v.workDoneToken;
163 			partialResultToken = v.partialResultToken;
164 		}
165 
166 		int working = 0;
167 		string[] partialResults;
168 		void handlePartialWork(Symbol, Arguments)(Symbol fn, Arguments args)
169 		{
170 			working++;
171 			pushFiber({
172 				scope (exit)
173 					working--;
174 				auto thisId = working;
175 				trace("Partial ", thisId, " / ", numHandlers, "...");
176 				auto result = fn(args.expand);
177 				trace("Partial ", thisId, " = ", result);
178 				auto json = result.serializeJson;
179 				if (!partialResultToken.length)
180 					partialResults ~= json;
181 				else
182 					rpc.notifyProgressRaw(partialResultToken, json);
183 				processRequestObservers(msg, result);
184 			});
185 		}
186 
187 		bool done, found;
188 		try
189 		{
190 			found = eventProcessor.emitProtocolRaw!(protocolMethod, (name, symbol, arguments, uda) {
191 				if (done)
192 					return;
193 
194 				trace("Calling request method ", name);
195 				alias RequestResultT = typeof(symbol(arguments.expand));
196 
197 				static if (is(RequestResultT : JsonValue))
198 				{
199 					auto requestResult = symbol(arguments.expand);
200 					res.resultJson = requestResult.serializeJson;
201 					done = true;
202 					processRequestObservers(msg, requestResult);
203 				}
204 				else
205 				{
206 					static if (is(RequestResultT : T[], T))
207 					{
208 						if (numHandlers > 1)
209 						{
210 							handlePartialWork(symbol, arguments);
211 							return;
212 						}
213 					}
214 					else assert(numHandlers == 1, "Registered more than one "
215 						~ msg.method ~ " handler on non-partial method returning "
216 						~ RequestResultT.stringof);
217 					auto requestResult = symbol(arguments.expand);
218 					res.resultJson = requestResult.serializeJson;
219 					done = true;
220 					processRequestObservers(msg, requestResult);
221 				}
222 			}, false)(msg.method, msg.paramsJson);
223 		}
224 		catch (MethodException e)
225 		{
226 			res.resultJson = null;
227 			res.error = e.error;
228 			return res;
229 		}
230 
231 		assert(found);
232 
233 		if (!done)
234 		{
235 			while (working > 0)
236 				Fiber.yield();
237 
238 			if (!partialResultToken.length)
239 			{
240 				size_t reservedLength = 1 + partialResults.length;
241 				foreach (partial; partialResults)
242 				{
243 					assert(partial.looksLikeJsonArray);
244 					reservedLength += partial.length - 2;
245 				}
246 				char[] resJson = new char[reservedLength];
247 				size_t i = 0;
248 				resJson.ptr[i++] = '[';
249 				foreach (partial; partialResults)
250 				{
251 					assert(i + partial.length - 2 < resJson.length);
252 					resJson.ptr[i .. i += (partial.length - 2)] = partial[1 .. $ - 1];
253 					resJson.ptr[i++] = ',';
254 				}
255 				assert(i == resJson.length);
256 				resJson.ptr[reservedLength - 1] = ']';
257 				res.resultJson = cast(string)resJson;
258 			}
259 			else
260 			{
261 				res.resultJson = `[]`;
262 			}
263 		}
264 
265 		return res;
266 	}
267 
	/// Calls the `@postProtocolMethod` methods for the given request.
	///
	/// A `MethodException` thrown by an observer is logged and swallowed; any
	/// other exception is not caught here.
	///
	/// Params:
	///   msg = the request that was handled.
	///   result = the value the request handler returned; forwarded to
	///     `emitProtocol` along with the method name and parameters.
	private void processRequestObservers(T)(RequestMessageRaw msg, T result)
	{
		eventProcessor.emitProtocol!(postProtocolMethod, (name, callSymbol, uda) {
			trace("Calling post-request method ", name);
			try
			{
				callSymbol();
			}
			catch (MethodException e)
			{
				error("Error in post-protocolMethod: ", e);
			}
		}, false)(msg.method, msg.paramsJson, result);
	}
283 
284 	void processNotify(RequestMessageRaw msg)
285 	{
286 		debug(PerfTraceLog) mixin(traceStatistics(__FUNCTION__));
287 
288 		// even though the spec says the process should not stop before exit, vscode-languageserver doesn't call exit after shutdown so we just shutdown on the next request.
289 		// this also makes sure we don't operate on invalid states and segfault.
290 		if (msg.method == "exit" || shutdownRequested)
291 		{
292 			rpc.stop();
293 			if (!shutdownRequested)
294 			{
295 				shutdownRequested = true;
296 				static if (is(typeof(ExtensionModule.shutdown)))
297 					ExtensionModule.shutdown();
298 			}
299 			return;
300 		}
301 
302 		if (!serverInitializeCalled)
303 		{
304 			trace("Tried to call notification without initializing");
305 			return;
306 		}
307 		documents.process(msg);
308 
309 		bool gotAny = eventProcessor.emitProtocol!(protocolNotification, (name, callSymbol, uda) {
310 			trace("Calling notification method ", name);
311 			try
312 			{
313 				callSymbol();
314 			}
315 			catch (MethodException e)
316 			{
317 				error("Failed notify: ", e);
318 			}
319 		}, false)(msg.method, msg.paramsJson);
320 
321 		if (!gotAny)
322 			trace("No handlers for notification: ", msg);
323 	}
324 
325 	void delegate() gotRequest(RequestMessageRaw msg)
326 	{
327 		return {
328 			ResponseMessageRaw res;
329 			try
330 			{
331 				res = processRequest(msg);
332 			}
333 			catch (Exception e)
334 			{
335 				if (!msg.id.isNone)
336 					res.id = msg.id.deref;
337 				auto err = ResponseError(e);
338 				err.code = ErrorCode.internalError;
339 				res.error = err;
340 			}
341 			catch (Throwable e)
342 			{
343 				if (!msg.id.isNone)
344 					res.id = msg.id.deref;
345 				auto err = ResponseError(e);
346 				err.code = ErrorCode.internalError;
347 				res.error = err;
348 				rpc.window.showMessage(MessageType.error,
349 						"A fatal internal error occured in "
350 						~ serverConfig.productName
351 						~ " handling this request but it will attempt to keep running: "
352 						~ e.msg);
353 			}
354 			rpc.send(res);
355 		};
356 	}
357 
358 	void delegate() gotNotify(RequestMessageRaw msg)
359 	{
360 		return {
361 			try
362 			{
363 				processNotify(msg);
364 			}
365 			catch (Exception e)
366 			{
367 				error("Failed processing notification: ", e);
368 			}
369 			catch (Throwable e)
370 			{
371 				error("Attempting to recover from fatal issue: ", e);
372 				rpc.window.showMessage(MessageType.error,
373 						"A fatal internal error has occured in "
374 						~ serverConfig.productName
375 						~ ", but it will attempt to keep running: "
376 						~ e.msg);
377 			}
378 		};
379 	}
380 
381 	__gshared FiberManager fibers;
382 	__gshared Mutex fibersMutex;
383 
384 	void pushFiber(T)(T callback, int pages = serverConfig.defaultPages, string file = __FILE__, int line = __LINE__)
385 	{
386 		synchronized (fibersMutex)
387 			fibers.put(new Fiber(callback, serverConfig.fiberPageSize * pages), file, line);
388 	}
389 
390 	RPCProcessor rpc;
391 	TextDocumentManager documents;
392 
393 	/// Runs the language server and returns true once it exited gracefully or
394 	/// false if it didn't exit gracefully.
395 	bool run()
396 	{
397 		auto input = newStdinReader();
398 		input.start();
399 		scope (exit)
400 			input.stop();
401 		for (int timeout = 10; timeout >= 0 && !input.isRunning; timeout--)
402 			Thread.sleep(1.msecs);
403 		trace("Started reading from stdin");
404 
405 		rpc = new RPCProcessor(input, stdout);
406 		rpc.call();
407 		trace("RPC started");
408 		return runImpl(); 
409 	}
410 
411 	/// Same as `run`, assumes `rpc` is initialized and ready
412 	bool runImpl()
413 	{
414 		fibersMutex = new Mutex();
415 
416 		static if (serverConfig.gcCollectSeconds > 0)
417 		{
418 			int gcCollects, totalGcCollects;
419 			StopWatch gcInterval;
420 			gcInterval.start();
421 
422 			void collectGC()
423 			{
424 				import core.memory : GC;
425 
426 				auto before = GC.stats();
427 				StopWatch gcSpeed;
428 				gcSpeed.start();
429 
430 				GC.collect();
431 
432 				totalGcCollects++;
433 				static if (serverConfig.gcMinimizeTimes > 0)
434 				{
435 					gcCollects++;
436 					if (gcCollects >= serverConfig.gcMinimizeTimes)
437 					{
438 						GC.minimize();
439 						gcCollects = 0;
440 					}
441 				}
442 
443 				gcSpeed.stop();
444 				auto after = GC.stats();
445 
446 				if (before != after)
447 					tracef("GC run in %s. Freed %s bytes (%s bytes allocated, %s bytes available)", gcSpeed.peek,
448 							cast(long) before.usedSize - cast(long) after.usedSize, after.usedSize, after.freeSize);
449 				else
450 					trace("GC run in ", gcSpeed.peek);
451 
452 				gcInterval.reset();
453 			}
454 		}
455 
456 		scope (exit)
457 		{
458 			debug(PerfTraceLog)
459 			{
460 				import core.memory : GC;
461 				import std.stdio : File;
462 
463 				auto traceLog = File("served_trace.log", "w");
464 
465 				auto totalAllocated = GC.stats().allocatedInCurrentThread;
466 				auto profileStats = GC.profileStats();
467 
468 				traceLog.writeln("manually collected GC ", totalGcCollects, " times");
469 				traceLog.writeln("total ", profileStats.numCollections, " collections");
470 				traceLog.writeln("total collection time: ", profileStats.totalCollectionTime);
471 				traceLog.writeln("total pause time: ", profileStats.totalPauseTime);
472 				traceLog.writeln("max collection time: ", profileStats.maxCollectionTime);
473 				traceLog.writeln("max pause time: ", profileStats.maxPauseTime);
474 				traceLog.writeln("total allocated in main thread: ", totalAllocated);
475 				traceLog.writeln();
476 
477 				dumpTraceInfos(traceLog);
478 			}
479 		}
480 
481 		fibers ~= rpc;
482 
483 		spawnFiberImpl = (&pushFiber!(void delegate())).toDelegate;
484 		defaultFiberPages = serverConfig.defaultPages;
485 
486 		static if (is(typeof(ExtensionModule.parallelMain)))
487 			pushFiber(&ExtensionModule.parallelMain);
488 
489 		while (rpc.state != Fiber.State.TERM)
490 		{
491 			while (rpc.hasData)
492 			{
493 				auto msg = rpc.poll;
494 				// Log on client side instead! (vscode setting: "serve-d.trace.server": "verbose")
495 				//trace("Message: ", msg);
496 				if (!msg.id.isNone)
497 					pushFiber(gotRequest(msg));
498 				else
499 					pushFiber(gotNotify(msg));
500 			}
501 			Thread.sleep(10.msecs);
502 			synchronized (fibersMutex)
503 				fibers.call();
504 
505 			static if (serverConfig.gcCollectSeconds > 0)
506 			{
507 				if (gcInterval.peek > serverConfig.gcCollectSeconds.seconds)
508 				{
509 					collectGC();
510 				}
511 			}
512 		}
513 
514 		return shutdownRequested;
515 	}
516 }
517 
// End-to-end test of LanguageServerRouter: mixes the router over a small
// in-test server, runs `runImpl` in a background thread on top of a MockRPC
// and drives it through initialize, a request, a notification, shutdown and
// exit.
unittest
{
	import core.thread;
	import core.time;

	import std.conv;
	import std.experimental.logger;
	import std.stdio;

	import served.lsp.jsonrpc;
	import served.lsp.protocol;
	import served.utils.events;

	// Result type returned from `initialize`; each flag records which part of
	// the initialization sequence has touched it. `sanityFalse` is never set.
	static struct CustomInitializeResult
	{
		bool calledA;
		bool calledB;
		bool calledC;
		bool sanityFalse;
	}

	// Written by the custom/notify handler, checked by the test below.
	__gshared static int calledCustomNotify;

	// The extension module under test.
	static struct UTServer
	{
	static:
		alias members = __traits(derivedMembers, UTServer);

		// Sets calledA so the hooks can verify they run after initialize.
		CustomInitializeResult initialize(InitializeParams params)
		{
			CustomInitializeResult res;
			res.calledA = true;
			return res;
		}

		// Hook receiving the initialize result by reference; sets calledB.
		@initializeHook
		void myInitHook1(InitializeParams params, ref CustomInitializeResult result)
		{
			assert(result.calledA);
			assert(!result.calledB);
			assert(!result.sanityFalse);

			result.calledB = true;
		}

		// Second hook; sets calledC.
		@initializeHook
		void myInitHook2(InitializeParams params, ref CustomInitializeResult result)
		{
			assert(result.calledA);
			assert(!result.calledC);
			assert(!result.sanityFalse);

			result.calledC = true;
		}

		// Request handler: returns 4 plus the length of the document URI.
		@protocolMethod("textDocument/documentColor")
		int myMethod1(DocumentColorParams c)
		{
			return 4 + cast(int)c.textDocument.uri.length;
		}

		static struct NotifyParams
		{
			int i;
		}

		// Notification handler: stores 4 plus the passed value.
		@protocolNotification("custom/notify")
		void myMethod2(NotifyParams params)
		{
			calledCustomNotify = 4 + params.i;
			trace("myMethod2 -> ", calledCustomNotify, " - ptr: ", &calledCustomNotify);
		}
	}

	// we get a bunch of deprecations because of dual-context, but I don't think we can do anything about these.
	mixin LanguageServerRouter!(UTServer) server;

	// log everything to stderr; the assignment differs by compiler version
	globalLogLevel = LogLevel.trace;
	static if (__VERSION__ < 2101)
		sharedLog = new FileLogger(io.stderr);
	else
		sharedLog = (() @trusted => cast(shared) new FileLogger(io.stderr))();

	MockRPC mockRPC;
	mockRPC.testRPC((rpc) {
		server.rpc = rpc;
		bool started;
		bool exitSuccess;
		// run the server loop in a background thread
		auto t = new Thread({
			started = true;
			try
			{
				exitSuccess = server.runImpl();
			}
			catch (Throwable t)
			{
				import std.stdio;

				stderr.writeln("Fatal: mockRPC crashed: ", t);
			}
		});
		t.start();
		// wait until the thread has begun executing
		do {
			Thread.sleep(10.msecs);
		} while (!started);
		// give it a little more time
		Thread.sleep(200.msecs);

		trace("Started mock RPC");
		mockRPC.writePacket(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"processId":null,"rootUri":"file:///","capabilities":{}}}`);
		Thread.sleep(200.msecs);
		assert(server.serverInitializeCalled);
		trace("Initialized");

		auto resObj = ResponseMessageRaw.deserialize(mockRPC.readPacket());
		assert(resObj.error.isNone);

		auto initResult = resObj.resultJson.deserializeJson!CustomInitializeResult;

		// initialize and both hooks must have contributed to the result
		assert(initResult.calledA);
		assert(initResult.calledB);
		assert(initResult.calledC);
		assert(!initResult.sanityFalse);
		trace("Initialize OK");

		// uri "a" has length 1, so myMethod1 returns 4 + 1
		mockRPC.writePacket(`{"jsonrpc":"2.0","id":1,"method":"textDocument/documentColor","params":{"textDocument":{"uri":"a"}}}`);
		resObj = ResponseMessageRaw.deserialize(mockRPC.readPacket());
		assert(resObj.resultJson == `5`);

		// i = 4, so myMethod2 stores 4 + 4
		assert(!calledCustomNotify);
		mockRPC.writePacket(`{"jsonrpc":"2.0","method":"custom/notify","params":{"i":4}}`);
		Thread.sleep(200.msecs);
		assert(calledCustomNotify == 8,
			text("calledCustomNotify = ", calledCustomNotify, " - ptr: ", &calledCustomNotify));

		// shutdown followed by exit must end runImpl with a true result
		mockRPC.writePacket(`{"jsonrpc":"2.0","id":1,"method":"shutdown","params":{}}`);
		mockRPC.readPacket();
		mockRPC.writePacket(`{"jsonrpc":"2.0","method":"exit","params":{}}`);

		t.join();
		assert(exitSuccess);
	});
}