/// Base APIs to create LSP servers quickly.
///
/// When this module is initialized (see the shared static constructor below)
/// the process-wide standard streams are reconfigured so the RPC channel cannot
/// be used by accident: `std.stdio.stdin` is replaced by the OS null device
/// opened for reading, and `std.stdio.stdout` is redirected to stderr.
module served.serverbase;

import served.utils.events : EventProcessorConfig;

import io = std.stdio;

/// The original process stdin/stdout handles, captured before they are
/// redirected; these are the streams used for RPC communication.
__gshared io.File stdin, stdout;

shared static this()
{
	// Capture the real handles first - they carry the RPC traffic.
	stdout = io.stdout;
	stdin = io.stdin;

	// Pick the null device for this platform at compile time; an empty value
	// means the platform has no known null device.
	version (Windows)
		enum string nullDevice = "NUL";
	else version (Posix)
		enum string nullDevice = "/dev/null";
	else
		enum string nullDevice = null;

	static if (nullDevice.length)
		io.stdin = io.File(nullDevice, "r");
	else
		io.stderr.writeln("warning: no /dev/null implementation on this OS");

	// From here on, anything written through std.stdio.stdout lands on stderr.
	io.stdout = io.stderr;
}

/// Compile-time configuration of a `LanguageServerRouter`.
struct LanguageServerConfig
{
	/// Number of pages given to a fiber's stack when no explicit page count
	/// is passed to `pushFiber`.
	int defaultPages = 20;
	/// Size of one page in bytes; a fiber's stack size is
	/// `fiberPageSize * pages`.
	int fiberPageSize = 4096;

	/// Configuration forwarded to the `EventProcessor` mixin.
	EventProcessorConfig eventConfig;

	/// Product name to use in error messages
	string productName = "unnamed lsp";

	/// If set to non-zero, call GC.collect every n seconds and GC.minimize
	/// every gcMinimizeTimes-th call. Keeps track of cleaned up memory in
	/// trace logs.
	int gcCollectSeconds = 30;
	/// ditto
	int gcMinimizeTimes = 5;
}

// dumps a performance/GC trace log to served_trace.log
//debug = PerfTraceLog;

/// Utility to setup an RPC connection via stdin/stdout and route all requests
/// to methods defined in the given extension module.
///
/// Params:
///   ExtensionModule = a module defining the following members:
///   - `members`: a compile time list of all members in all modules that should
///     be introspected to be called automatically on matching RPC commands.
///   - `InitializeResult initialize(InitializeParams)`: initialization method.
///
///   Optional:
///   - `bool shutdownRequested`: a boolean that is set to true before the
///     `shutdown` method handler or earlier which will terminate the RPC loop
///     gracefully and wait for an `exit` notification to actually exit.
///   - `@protocolMethod("shutdown") JsonValue shutdown()`: the method called
///     when the client wants to shutdown the server. Can return anything,
///     recommended return value is `JsonValue(null)`.
///   - `parallelMain`: an optional method which is run alongside everything
///     else in parallel using fibers. Should yield as much as possible when
///     there is nothing to do.
mixin template LanguageServerRouter(alias ExtensionModule, LanguageServerConfig serverConfig = LanguageServerConfig.init)
{
	static assert(is(typeof(ExtensionModule.members)), "Missing members field in ExtensionModule " ~ ExtensionModule.stringof);
	static assert(is(typeof(ExtensionModule.initialize)), "Missing initialize function in ExtensionModule " ~ ExtensionModule.stringof);

	import core.sync.mutex;
	import core.thread;

	import served.lsp.filereader;
	import served.lsp.jsonops;
	import served.lsp.jsonrpc;
	import served.lsp.protocol;
	import served.lsp.textdocumentmanager;
	import served.utils.async;
	import served.utils.events;
	import served.utils.fibermanager;

	import std.datetime.stopwatch;
	import std.experimental.logger;
	import std.functional;
	import std.json;

	import io = std.stdio;

	alias members = ExtensionModule.members;

	static if (is(typeof(ExtensionModule.shutdownRequested)))
		alias shutdownRequested = ExtensionModule.shutdownRequested;
	else
		bool shutdownRequested;

	/// Set once an `initialize` request has been processed; requests (other
	/// than `shutdown`) and notifications are rejected before that.
	__gshared bool serverInitializeCalled = false;

	mixin EventProcessor!(ExtensionModule, serverConfig.eventConfig) eventProcessor;

	/// Calls the `@protocolMethod` handlers in the extension module that are
	/// registered for the method of the given request, and builds the response.
	///
	/// `initialize` is handled here directly: it calls
	/// `ExtensionModule.initialize`, emits the `initializeHook` and
	/// `onInitialize` extension events and marks the server as initialized.
	/// Unless the extension module defines `shutdown` itself, the first
	/// `shutdown` request is answered here with `null`.
	///
	/// If several handlers returning arrays are registered for one method, each
	/// runs in its own fiber. Their results are either sent through
	/// `rpc.notifyProgressRaw` (when the client passed a `partialResultToken`,
	/// the response itself is then `[]`) or concatenated into one JSON array.
	///
	/// A `MethodException` thrown by a handler is turned into an error response.
	///
	/// Throws: `Exception` if `msg` has no id (i.e. it is a notification).
	ResponseMessageRaw processRequest(RequestMessageRaw msg)
	{
		debug(PerfTraceLog) mixin(traceStatistics(__FUNCTION__));
		scope (failure)
			error("failure in message ", msg);

		ResponseMessageRaw res;
		if (msg.id.isNone)
			throw new Exception("Called processRequest on a notification");
		res.id = msg.id.deref;
		if (msg.method == "initialize" && !serverInitializeCalled)
		{
			trace("Initializing");
			auto initParams = msg.paramsJson.deserializeJson!InitializeParams;
			auto initResult = ExtensionModule.initialize(initParams);
			eventProcessor.emitExtensionEvent!initializeHook(initParams, initResult);
			eventProcessor.emitExtensionEvent!onInitialize(initParams);
			res.resultJson = initResult.serializeJson;
			trace("Initialized");
			serverInitializeCalled = true;
			// run the post-request observers only after this fiber yielded once
			pushFiber({
				Fiber.yield();
				processRequestObservers(msg, initResult);
			});
			return res;
		}

		static if (!is(typeof(ExtensionModule.shutdown)))
		{
			if (msg.method == "shutdown" && !shutdownRequested)
			{
				shutdownRequested = true;
				res.resultJson = `null`;
				return res;
			}
		}

		if (!serverInitializeCalled && msg.method != "shutdown")
		{
			trace("Tried to call command without initializing");
			res.error = ResponseError(ErrorCode.serverNotInitialized);
			return res;
		}

		// First pass: only count the matching handlers, nothing is called yet.
		size_t numHandlers;
		eventProcessor.emitProtocol!(protocolMethod, (name, callSymbol, uda) {
			numHandlers++;
		}, false)(msg.method, msg.paramsJson);

		// trace("Function ", msg.method, " has ", numHandlers, " handlers");
		if (numHandlers == 0)
		{
			io.stderr.writeln(msg);
			res.error = ResponseError(ErrorCode.methodNotFound, "Request method " ~ msg.method ~ " not found");
			return res;
		}

		string workDoneToken, partialResultToken;
		if (msg.paramsJson.looksLikeJsonObject)
		{
			auto v = msg.paramsJson.parseKeySlices!("workDoneToken", "partialResultToken");
			workDoneToken = v.workDoneToken;
			partialResultToken = v.partialResultToken;
		}

		// Number of partial-result fibers that have not finished yet.
		int working = 0;
		string[] partialResults;

		// Runs one array-returning handler in its own fiber and either streams
		// its serialized result or collects it into `partialResults`.
		void handlePartialWork(Symbol, Arguments)(Symbol fn, Arguments args)
		{
			working++;
			pushFiber({
				scope (exit)
					working--;
				auto thisId = working;
				trace("Partial ", thisId, " / ", numHandlers, "...");
				auto result = fn(args.expand);
				trace("Partial ", thisId, " = ", result);
				auto json = result.serializeJson;
				if (!partialResultToken.length)
					partialResults ~= json;
				else
					rpc.notifyProgressRaw(partialResultToken, json);
				processRequestObservers(msg, result);
			});
		}

		bool done, found;
		try
		{
			found = eventProcessor.emitProtocolRaw!(protocolMethod, (name, symbol, arguments, uda) {
				if (done)
					return;

				trace("Calling request method ", name);
				alias RequestResultT = typeof(symbol(arguments.expand));

				static if (is(RequestResultT : JsonValue))
				{
					auto requestResult = symbol(arguments.expand);
					res.resultJson = requestResult.serializeJson;
					done = true;
					processRequestObservers(msg, requestResult);
				}
				else
				{
					static if (is(RequestResultT : T[], T))
					{
						if (numHandlers > 1)
						{
							handlePartialWork(symbol, arguments);
							return;
						}
					}
					else assert(numHandlers == 1, "Registered more than one "
						~ msg.method ~ " handler on non-partial method returning "
						~ RequestResultT.stringof);
					auto requestResult = symbol(arguments.expand);
					res.resultJson = requestResult.serializeJson;
					done = true;
					processRequestObservers(msg, requestResult);
				}
			}, false)(msg.method, msg.paramsJson);
		}
		catch (MethodException e)
		{
			res.resultJson = null;
			res.error = e.error;
			return res;
		}

		assert(found);

		if (!done)
		{
			while (working > 0)
				Fiber.yield();

			if (!partialResultToken.length)
			{
				import std.array : appender;
				import std.string : strip;

				// Concatenate the elements of all partial JSON arrays into one
				// array. Partials without elements (e.g. `[]`) are skipped and
				// separators are only written between non-empty chunks, so the
				// result never contains `[,x]`, `[x,]` or `[,]`, and no partials
				// at all yield `[]`.
				auto joined = appender!string;
				joined.put('[');
				bool needsComma;
				foreach (partial; partialResults)
				{
					assert(partial.looksLikeJsonArray);
					auto elements = partial[1 .. $ - 1].strip;
					if (!elements.length)
						continue;
					if (needsComma)
						joined.put(',');
					joined.put(elements);
					needsComma = true;
				}
				joined.put(']');
				res.resultJson = joined.data;
			}
			else
			{
				res.resultJson = `[]`;
			}
		}

		return res;
	}

	// calls @postProtocolMethod methods for the given request
	private void processRequestObservers(T)(RequestMessageRaw msg, T result)
	{
		eventProcessor.emitProtocol!(postProtocolMethod, (name, callSymbol, uda) {
			trace("Calling post-request method ", name);
			try
			{
				callSymbol();
			}
			catch (MethodException e)
			{
				error("Error in post-protocolMethod: ", e);
			}
		}, false)(msg.method, msg.paramsJson, result);
	}

	/// Handles a notification message.
	///
	/// An `exit` notification, or any notification arriving after a shutdown
	/// was requested, stops `rpc`; if no shutdown was requested yet, it is
	/// flagged now and `ExtensionModule.shutdown` is called when it exists.
	/// Otherwise (once initialized) the message is passed to
	/// `documents.process` and then to every matching `@protocolNotification`
	/// handler. `MethodException`s from handlers are logged, not rethrown.
	void processNotify(RequestMessageRaw msg)
	{
		debug(PerfTraceLog) mixin(traceStatistics(__FUNCTION__));

		// even though the spec says the process should not stop before exit, vscode-languageserver doesn't call exit after shutdown so we just shutdown on the next request.
		// this also makes sure we don't operate on invalid states and segfault.
		if (msg.method == "exit" || shutdownRequested)
		{
			rpc.stop();
			if (!shutdownRequested)
			{
				shutdownRequested = true;
				static if (is(typeof(ExtensionModule.shutdown)))
					ExtensionModule.shutdown();
			}
			return;
		}

		if (!serverInitializeCalled)
		{
			trace("Tried to call notification without initializing");
			return;
		}
		documents.process(msg);

		bool gotAny = eventProcessor.emitProtocol!(protocolNotification, (name, callSymbol, uda) {
			trace("Calling notification method ", name);
			try
			{
				callSymbol();
			}
			catch (MethodException e)
			{
				error("Failed notify: ", e);
			}
		}, false)(msg.method, msg.paramsJson);

		if (!gotAny)
			trace("No handlers for notification: ", msg);
	}

	/// Returns a delegate that processes `msg` as a request and sends the
	/// response with `rpc.send`. An `Exception` becomes an `internalError`
	/// response; any other `Throwable` additionally shows an error message to
	/// the user through `rpc.window.showMessage`.
	void delegate() gotRequest(RequestMessageRaw msg)
	{
		return {
			ResponseMessageRaw res;
			try
			{
				res = processRequest(msg);
			}
			catch (Exception e)
			{
				if (!msg.id.isNone)
					res.id = msg.id.deref;
				auto err = ResponseError(e);
				err.code = ErrorCode.internalError;
				res.error = err;
			}
			catch (Throwable e)
			{
				if (!msg.id.isNone)
					res.id = msg.id.deref;
				auto err = ResponseError(e);
				err.code = ErrorCode.internalError;
				res.error = err;
				rpc.window.showMessage(MessageType.error,
						"A fatal internal error occured in "
						~ serverConfig.productName
						~ " handling this request but it will attempt to keep running: "
						~ e.msg);
			}
			rpc.send(res);
		};
	}

	/// Returns a delegate that runs `processNotify(msg)`, logging any
	/// exception; for a non-`Exception` `Throwable` it also shows an error
	/// message to the user.
	void delegate() gotNotify(RequestMessageRaw msg)
	{
		return {
			try
			{
				processNotify(msg);
			}
			catch (Exception e)
			{
				error("Failed processing notification: ", e);
			}
			catch (Throwable e)
			{
				error("Attempting to recover from fatal issue: ", e);
				rpc.window.showMessage(MessageType.error,
						"A fatal internal error has occured in "
						~ serverConfig.productName
						~ ", but it will attempt to keep running: "
						~ e.msg);
			}
		};
	}

	/// Fibers driven by the `runImpl` loop; access is guarded by `fibersMutex`.
	__gshared FiberManager fibers;
	/// ditto
	__gshared Mutex fibersMutex;

	/// Queues `callback` to run in a new fiber whose stack size is
	/// `serverConfig.fiberPageSize * pages`. `file` and `line` are passed on
	/// to the fiber manager together with the fiber.
	void pushFiber(T)(T callback, int pages = serverConfig.defaultPages, string file = __FILE__, int line = __LINE__)
	{
		synchronized (fibersMutex)
			fibers.put(new Fiber(callback, serverConfig.fiberPageSize * pages), file, line);
	}

	/// RPC processor connected to the client.
	RPCProcessor rpc;
	/// Text documents, updated from every notification via `process`.
	TextDocumentManager documents;

	/// Runs the language server and returns true once it exited gracefully or
	/// false if it didn't exit gracefully.
	bool run()
	{
		auto input = newStdinReader();
		input.start();
		scope (exit)
			input.stop();
		// wait up to ~10ms for the reader to report that it is running
		for (int timeout = 10; timeout >= 0 && !input.isRunning; timeout--)
			Thread.sleep(1.msecs);
		trace("Started reading from stdin");

		rpc = new RPCProcessor(input, stdout);
		rpc.call();
		trace("RPC started");
		return runImpl();
	}

	/// Same as `run`, assumes `rpc` is initialized and ready.
	///
	/// Polls `rpc` for messages, dispatches each one into its own fiber
	/// (requests via `gotRequest`, notifications via `gotNotify`), steps all
	/// fibers roughly every 10ms and, if configured, runs the GC periodically.
	/// Returns `shutdownRequested` once the RPC fiber terminated.
	bool runImpl()
	{
		fibersMutex = new Mutex();

		static if (serverConfig.gcCollectSeconds > 0)
		{
			int gcCollects, totalGcCollects;
			StopWatch gcInterval;
			gcInterval.start();

			// Collects garbage (minimizing every gcMinimizeTimes-th run) and
			// traces how much memory was freed.
			void collectGC()
			{
				import core.memory : GC;

				auto before = GC.stats();
				StopWatch gcSpeed;
				gcSpeed.start();

				GC.collect();

				totalGcCollects++;
				static if (serverConfig.gcMinimizeTimes > 0)
				{
					gcCollects++;
					if (gcCollects >= serverConfig.gcMinimizeTimes)
					{
						GC.minimize();
						gcCollects = 0;
					}
				}

				gcSpeed.stop();
				auto after = GC.stats();

				if (before != after)
					tracef("GC run in %s. Freed %s bytes (%s bytes allocated, %s bytes available)", gcSpeed.peek,
							cast(long) before.usedSize - cast(long) after.usedSize, after.usedSize, after.freeSize);
				else
					trace("GC run in ", gcSpeed.peek);

				gcInterval.reset();
			}
		}

		scope (exit)
		{
			debug(PerfTraceLog)
			{
				import core.memory : GC;
				import std.stdio : File;

				auto traceLog = File("served_trace.log", "w");

				auto totalAllocated = GC.stats().allocatedInCurrentThread;
				auto profileStats = GC.profileStats();

				// totalGcCollects only exists when periodic collection is enabled
				static if (serverConfig.gcCollectSeconds > 0)
					traceLog.writeln("manually collected GC ", totalGcCollects, " times");
				traceLog.writeln("total ", profileStats.numCollections, " collections");
				traceLog.writeln("total collection time: ", profileStats.totalCollectionTime);
				traceLog.writeln("total pause time: ", profileStats.totalPauseTime);
				traceLog.writeln("max collection time: ", profileStats.maxCollectionTime);
				traceLog.writeln("max pause time: ", profileStats.maxPauseTime);
				traceLog.writeln("total allocated in main thread: ", totalAllocated);
				traceLog.writeln();

				dumpTraceInfos(traceLog);
			}
		}

		fibers ~= rpc;

		spawnFiberImpl = (&pushFiber!(void delegate())).toDelegate;
		defaultFiberPages = serverConfig.defaultPages;

		static if (is(typeof(ExtensionModule.parallelMain)))
			pushFiber(&ExtensionModule.parallelMain);

		while (rpc.state != Fiber.State.TERM)
		{
			while (rpc.hasData)
			{
				auto msg = rpc.poll;
				// Log on client side instead! (vscode setting: "serve-d.trace.server": "verbose")
				//trace("Message: ", msg);
				if (!msg.id.isNone)
					pushFiber(gotRequest(msg));
				else
					pushFiber(gotNotify(msg));
			}
			Thread.sleep(10.msecs);
			synchronized (fibersMutex)
				fibers.call();

			static if (serverConfig.gcCollectSeconds > 0)
			{
				if (gcInterval.peek > serverConfig.gcCollectSeconds.seconds)
				{
					collectGC();
				}
			}
		}

		return shutdownRequested;
	}
}

unittest
{
	import core.thread;
	import core.time;

	import std.conv;
	import std.experimental.logger;
	import std.stdio;

	import served.lsp.jsonrpc;
	import served.lsp.protocol;
	import served.utils.events;

	static struct CustomInitializeResult
	{
		bool calledA;
		bool calledB;
		bool calledC;
		bool sanityFalse;
	}

	__gshared static int calledCustomNotify;

	static struct UTServer
	{
	static:
		alias members = __traits(derivedMembers, UTServer);

		CustomInitializeResult initialize(InitializeParams params)
		{
			CustomInitializeResult res;
			res.calledA = true;
			return res;
		}

		@initializeHook
		void myInitHook1(InitializeParams params, ref CustomInitializeResult result)
		{
			assert(result.calledA);
			assert(!result.calledB);
			assert(!result.sanityFalse);

			result.calledB = true;
		}

		@initializeHook
		void myInitHook2(InitializeParams params, ref CustomInitializeResult result)
		{
			assert(result.calledA);
			assert(!result.calledC);
			assert(!result.sanityFalse);

			result.calledC = true;
		}

		@protocolMethod("textDocument/documentColor")
		int myMethod1(DocumentColorParams c)
		{
			return 4 + cast(int)c.textDocument.uri.length;
		}

		static struct NotifyParams
		{
			int i;
		}

		@protocolNotification("custom/notify")
		void myMethod2(NotifyParams params)
		{
			calledCustomNotify = 4 + params.i;
			trace("myMethod2 -> ", calledCustomNotify, " - ptr: ", &calledCustomNotify);
		}
	}

	// we get a bunch of deprecations because of dual-context, but I don't think we can do anything about these.
	mixin LanguageServerRouter!(UTServer) server;

	globalLogLevel = LogLevel.trace;
	static if (__VERSION__ < 2101)
		sharedLog = new FileLogger(io.stderr);
	else
		sharedLog = (() @trusted => cast(shared) new FileLogger(io.stderr))();

	MockRPC mockRPC;
	mockRPC.testRPC((rpc) {
		server.rpc = rpc;
		bool started;
		bool exitSuccess;
		auto t = new Thread({
			started = true;
			try
			{
				exitSuccess = server.runImpl();
			}
			catch (Throwable t)
			{
				import std.stdio;

				stderr.writeln("Fatal: mockRPC crashed: ", t);
			}
		});
		t.start();
		do {
			Thread.sleep(10.msecs);
		} while (!started);
		// give it a little more time
		Thread.sleep(200.msecs);

		trace("Started mock RPC");
		mockRPC.writePacket(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"processId":null,"rootUri":"file:///","capabilities":{}}}`);
		Thread.sleep(200.msecs);
		assert(server.serverInitializeCalled);
		trace("Initialized");

		auto resObj = ResponseMessageRaw.deserialize(mockRPC.readPacket());
		assert(resObj.error.isNone);

		auto initResult = resObj.resultJson.deserializeJson!CustomInitializeResult;

		assert(initResult.calledA);
		assert(initResult.calledB);
		assert(initResult.calledC);
		assert(!initResult.sanityFalse);
		trace("Initialize OK");

		mockRPC.writePacket(`{"jsonrpc":"2.0","id":1,"method":"textDocument/documentColor","params":{"textDocument":{"uri":"a"}}}`);
		resObj = ResponseMessageRaw.deserialize(mockRPC.readPacket());
		assert(resObj.resultJson == `5`);

		assert(!calledCustomNotify);
		mockRPC.writePacket(`{"jsonrpc":"2.0","method":"custom/notify","params":{"i":4}}`);
		Thread.sleep(200.msecs);
		assert(calledCustomNotify == 8,
			text("calledCustomNotify = ", calledCustomNotify, " - ptr: ", &calledCustomNotify));

		mockRPC.writePacket(`{"jsonrpc":"2.0","id":1,"method":"shutdown","params":{}}`);
		mockRPC.readPacket();
		mockRPC.writePacket(`{"jsonrpc":"2.0","method":"exit","params":{}}`);

		t.join();
		assert(exitSuccess);
	});
}