1 module served.utils.events;
2 
/// UDA marking a handler for a client-to-server request. Notifications are
/// handled through `@protocolNotification` instead.
///
/// For a regular method, at most one function may carry this UDA with a given
/// method name. Methods returning an array (`T[]`) are the exception: several
/// functions may register for the same method. If the client supports partial
/// results, each function's result is streamed to the client as a separate
/// partial result. Otherwise the results of all functions are concatenated and
/// returned as a single response.
struct protocolMethod
{
	/// Protocol method name this handler responds to.
	string method;
}
14 
/// UDA marking a function that runs after the `@protocolMethod` handler for
/// the same method has been handled. Any number of functions may be
/// registered for one method.
///
/// If the protocol method is a partial method (multiple handlers returning an
/// array), the annotated function runs once for every chunk returned by every
/// handler. It is therefore invoked several times, on different fibers.
struct postProtocolMethod
{
	/// Protocol method name after which this function is invoked.
	string method;
}
24 
/// UDA for a request or notification parameter, used to suppress linting
/// warnings about it.
enum nonStandard;
28 
/// UDA marking a handler for a protocol notification, as opposed to a request
/// (see `@protocolMethod`).
struct protocolNotification
{
	/// Protocol notification method name this handler listens for.
	string method;
}
33 
/// Compile-time options for the `EventProcessor` mixin template.
struct EventProcessorConfig
{
	/// Member names of the extension module that may appear more than once
	/// without being reported by the duplicate-handler check.
	string[] allowedDuplicateMethods = [
		"object", "served", "std", "io", "workspaced", "fs"
	];
}
38 
/// UDAs for hooking into the initialization routine.
///
/// Annotated functions are called after the extension entry point's
/// `initialize()` method, but before the initialize response has been sent to
/// the client. `@initializeHook` functions receive the result by `ref` and may
/// therefore change what is sent back.
///
/// To avoid stalling initialization, use `@postProtocolMethod("initialized")`
/// instead. It runs on a separate fiber after the response has been sent.
/// Beware: other requests and notifications may already have been processed
/// during that switch. Code relying on whatever the initialize hook sets up
/// can break in that window.
///
/// Expected signatures of annotated functions:
/// ```d
/// @initializeHook
/// void myInitHook(InitializeParams params, ref InitializeResult result);
/// @onInitialize
/// void otherHook(InitializeParams params);
/// ```
enum initializeHook;
/// ditto
enum onInitialize;
60 
/// Implements the event processor for a given extension module exposing a
/// `members` field defining all potential methods.
///
/// Params:
///  ExtensionModule = symbol whose `members` are iterated at compile time to
///     find the handler functions annotated with the UDAs of this module.
///  config = compile-time options, see `EventProcessorConfig`.
mixin template EventProcessor(alias ExtensionModule, EventProcessorConfig config = EventProcessorConfig.init)
{
	// Prefer `core.lifetime.forward` when it is available; otherwise fall back
	// to `std.functional.forward` (presumably for older runtimes).
	static if (__traits(compiles, { import core.lifetime : forward; }))
		import core.lifetime : forward;
	else
		import std.functional : forward;

	import served.lsp.protocol;

	import std.algorithm;
	import std.meta;
	import std.traits;

	// duplicate method name check to avoid name clashes and unreadable error messages

	/// Returns every entry of `fields` that occurs more than once, listed once
	/// per occurrence. Names in `config.allowedDuplicateMethods` are ignored.
	private static string[] findDuplicates(string[] fields)
	{
		string[] dups;
		Loop: foreach (i, field; fields)
		{
			static foreach (allowed; config.allowedDuplicateMethods)
				if (field == allowed)
					continue Loop;

			if (fields[0 .. i].canFind(field) || fields[i + 1 .. $].canFind(field))
				dups ~= field;
		}
		return dups;
	}

	enum duplicates = findDuplicates([ExtensionModule.members]);
	static if (duplicates.length > 0)
	{
		pragma(msg, "duplicates: ", duplicates);
		static assert(false, "Found duplicate method handlers of same name");
	}

	// Lint results are only printed as hints; they never fail compilation.
	enum lintWarnings = ctLintEvents();
	static if (lintWarnings.length > 0)
		pragma(msg, lintWarnings);

	/// Compile-time lint of the handlers in `ExtensionModule`.
	///
	/// Considers every public function annotated with `@protocolMethod` or
	/// `@protocolNotification` that takes exactly one struct parameter, unless
	/// that parameter is annotated with `@nonStandard`.
	///
	/// If the struct type carries `AllowedMethods` UDAs and none of them lists
	/// the annotated method name, a hint with the function's source location
	/// is appended.
	///
	/// Returns: all hints concatenated, without the final trailing newline.
	///     Returns an empty string if there is nothing to report.
	private static string ctLintEvents()
	{
		import std.string : chomp;

		// An empty `allowed` list means "no restriction".
		static bool isInvalidMethodName(string methodName, AllowedMethods[] allowed)
		{
			if (!allowed.length)
				return false;

			foreach (a; allowed)
				foreach (m; a.methods)
					if (m == methodName)
						return false;
			return true;
		}

		static string formatMethodNameWarning(string methodName, AllowedMethods[] allowed,
			string codeName, string file, size_t line, size_t column)
		{
			import std.conv : to;

			string allowedStr = "";
			foreach (allow; allowed)
			{
				foreach (m; allow.methods)
				{
					if (allowedStr.length)
						allowedStr ~= ", ";
					allowedStr ~= "`" ~ m ~ "`";
				}
			}

			return "\x1B[1m" ~ file ~ "(" ~ line.to!string ~ "," ~ column.to!string ~ "): \x1B[1;34mHint: \x1B[m"
				~ "method " ~ codeName ~ " listens for event `" ~ methodName
				~ "`, but the type has set allowed methods to " ~ allowedStr
				~ ".\n\t\tNote: check back with the LSP specification, in case this is wrongly tagged or annotate parameter with @nonStandard.\n";
		}

		string lintResult;
		foreach (name; ExtensionModule.members)
		{
			static if (__traits(compiles, __traits(getMember, ExtensionModule, name)))
			{
				// AliasSeq to workaround AliasSeq members
				alias symbols = AliasSeq!(__traits(getMember, ExtensionModule, name));
				static if (symbols.length == 1 && hasUDA!(symbols[0], protocolMethod))
					enum methodName = getUDAs!(symbols[0], protocolMethod)[0].method;
				else static if (symbols.length == 1 && hasUDA!(symbols[0], protocolNotification))
					enum methodName = getUDAs!(symbols[0], protocolNotification)[0].method;
				else
					enum methodName = "";

				static if (methodName.length)
				{
					alias symbol = symbols[0];
					static if (isSomeFunction!(symbol) && __traits(getProtection, symbol) == "public")
					{
						alias P = Parameters!symbol;
						// P is the single-element parameter tuple here, so
						// getAttributes yields the UDAs of that parameter.
						static if (P.length == 1 && is(P[0] == struct)
							&& staticIndexOf!(nonStandard, __traits(getAttributes, P)) == -1)
						{
							enum allowedMethods = getUDAs!(P[0], AllowedMethods);
							static if (isInvalidMethodName(methodName, [allowedMethods]))
								lintResult ~= formatMethodNameWarning(methodName, [allowedMethods],
									name, __traits(getLocation, symbol));
						}
					}
				}
			}
		}

		return lintResult.chomp("\n");
	}

	/// Calls all protocol methods in `ExtensionModule` matching a certain method
	/// and method type.
	/// Params:
	///  UDA = The UDA to filter the methods with. This must define a string member
	///     called `method` which is compared with the runtime `method` argument.
	///  callback = The callback which is called for every matching function with
	///     the given UDA and method name. Called with arguments `(string name,
	///     void delegate() callSymbol, UDA uda)` where the `callSymbol` function is
	///     a parameterless function which automatically converts the JSON params
	///     and additional available arguments based on the method overload and
	///     calls it.
	///  returnFirst = If `true` the callback will be called at most once with any
	///     unspecified matching method. If `false` the callback will be called with
	///     all matching methods.
	///  method = the runtime method name to compare the UDA names with
	///  params = the JSON arguments for this protocol event, automatically
	///     converted to method arguments on demand.
	///  availableExtraArgs = static extra arguments available to pass to the method
	///     calls. `out`, `ref` and `lazy` are preserved given the method overloads.
	///     Overloads may consume anywhere between 0 to Args.length of these
	///     arguments.
	/// Returns: `true` if any method has been called, `false` otherwise.
	///
	/// NOTE(review): `availableExtraArgs` is taken by value and `forward` moves
	/// by-value arguments. A handler expecting a `ref`/`out` extra parameter may
	/// therefore not bind. With `returnFirst == false`, later handlers may also
	/// observe moved-from values. TODO confirm against the actual call sites.
	bool emitProtocol(alias UDA, alias callback, bool returnFirst, Args...)(string method,
			string params, Args availableExtraArgs)
	{
		ensureImpure();

		return iterateExtensionMethodsByUDA!(UDA, (name, symbol, uda) {
			if (uda.method == method)
			{
				debug (PerfTraceLog) mixin(traceStatistics(uda.method ~ ":" ~ name));

				alias symbolArgs = Parameters!symbol;

				auto callSymbol()
				{
					static if (symbolArgs.length == 0)
					{
						return symbol();
					}
					else static if (symbolArgs.length == 1)
					{
						return symbol(implParseParam!(symbolArgs[0])(params));
					}
					else static if (availableExtraArgs.length > 0
						&& symbolArgs.length <= 1 + availableExtraArgs.length)
					{
						// First parameter comes from the JSON params, the rest
						// from the leading extra arguments.
						return symbol(implParseParam!(symbolArgs[0])(params), forward!(
							availableExtraArgs[0 .. symbolArgs.length - 1]));
					}
					else
					{
						static assert(0, "Function for " ~ name ~ " can't have more than one argument");
					}
				}

				callback(name, &callSymbol, uda);
				return true;
			}
			else
				return false;
		}, returnFirst);
	}

	/// Same as emitProtocol, but for the callback instead of getting a delegate
	/// to call, you get a function pointer and a tuple with the arguments for
	/// each instantiation that can be expanded.
	///
	/// So the callback gets called like `callback(name, symbol, arguments, uda)`
	/// and the implementation can then call the symbol function using
	/// `symbol(arguments.expand)`.
	///
	/// This works around scoping issues and copies the arguments once more on
	/// invocation, causing ref/out parameters to get lost however. Allows to
	/// copy the arguments to other fibers for parallel processing.
	///
	/// Params:
	///  availableExtraArgs = optional extra arguments, as in `emitProtocol`.
	///     For methods with more than one parameter, the leading extra
	///     arguments are copied into `arguments` after the parsed JSON
	///     parameter. When omitted, only methods with 0 or 1 parameters are
	///     supported.
	/// Returns: `true` if any method has been called, `false` otherwise.
	bool emitProtocolRaw(alias UDA, alias callback, bool returnFirst, Args...)(string method,
			string params, Args availableExtraArgs)
	{
		import std.typecons : tuple;
		ensureImpure();

		return iterateExtensionMethodsByUDA!(UDA, (name, symbol, uda) {
			if (uda.method == method)
			{
				debug (PerfTraceLog) mixin(traceStatistics(uda.method ~ ":" ~ name));

				alias symbolArgs = Parameters!symbol;

				static if (symbolArgs.length == 0)
				{
					auto arguments = tuple();
				}
				else static if (symbolArgs.length == 1)
				{
					auto arguments = tuple(implParseParam!(symbolArgs[0])(params));
				}
				else static if (availableExtraArgs.length > 0
					&& symbolArgs.length <= 1 + availableExtraArgs.length)
				{
					// Extra arguments are copied, not forwarded. The same values
					// may be needed by several matching handlers, and the tuple is
					// meant to be handed to other fibers.
					auto arguments = tuple(implParseParam!(symbolArgs[0])(params),
						availableExtraArgs[0 .. symbolArgs.length - 1]);
				}
				else
				{
					static assert(0, "Function for " ~ name ~ " can't have more than one argument");
				}

				callback(name, symbol, arguments, uda);
				return true;
			}
			else
				return false;
		}, returnFirst);
	}

	/// Calls every public method in `ExtensionModule` annotated with `UDA`,
	/// passing all of `args` to each of them. No method-name filtering is done,
	/// and every matching method is called.
	///
	/// Returns: `true` if at least one method was called, `false` otherwise.
	///
	/// NOTE(review): `forward!args` is evaluated once per matching method. An
	/// rvalue argument may be moved into the first handler, so later handlers
	/// could see a moved-from value. TODO confirm this cannot happen for the
	/// types passed here.
	bool emitExtensionEvent(alias UDA, Args...)(auto ref Args args)
	{
		ensureImpure();
		return iterateExtensionMethodsByUDA!(UDA, (name, symbol, uda) {
			symbol(forward!args);
			return true;
		}, false);
	}

	/// Intentionally empty. It is not annotated `pure`, so calling it keeps the
	/// compiler from inferring `pure` for the calling template functions.
	/// `@nogc nothrow @safe` leave their other inferred attributes unaffected.
	private static void ensureImpure() @nogc nothrow @safe
	{
	}

	/// Iterates through all public methods in `ExtensionModule` annotated with the
	/// given UDA. For each matching function the callback parameter is called with
	/// the arguments being `(string name, Delegate symbol, UDA uda)`. `callback` is
	/// expected to return a boolean if the UDA values were a match.
	///
	/// Params:
	///  UDA = The UDA type to filter methods with. Methods can just have an UDA
	///     with this type and any values. See $(REF getUDAs, std.traits)
	///  callback = Called for every matching method. Expected to have 3 arguments
	///     being `(string name, Delegate symbol, UDA uda)` and returning `bool`
	///     telling if the uda values were a match or not. The Delegate is most
	///     often a function pointer to the given symbol and may differ between all
	///     calls.
	///
	///     If the UDA is a symbol and not a type (such as some enum manifest
	///     constant), then the UDA argument has no meaning and should not be used.
	///  returnFirst = if `true`, once callback returns `true` immediately return
	///     `true` for the whole function, otherwise `false`. If this is set to
	///     `false` then callback will be run on all symbols and this function
	///     returns `true` if any callback call has returned `true`.
	/// Returns: `true` if any callback returned `true`, `false` otherwise or if
	///     none were called. If `returnFirst` is set this function returns after
	///     the first successful callback call.
	bool iterateExtensionMethodsByUDA(alias UDA, alias callback, bool returnFirst)()
	{
		bool found = false;
		foreach (name; ExtensionModule.members)
		{
			static if (__traits(compiles, __traits(getMember, ExtensionModule, name)))
			{
				// AliasSeq to workaround AliasSeq members
				alias symbols = AliasSeq!(__traits(getMember, ExtensionModule, name));
				static if (symbols.length == 1 && hasUDA!(symbols[0], UDA))
				{
					static assert (__traits(getOverloads, ExtensionModule, name, true).length == 1,
						"UDA @" ~ UDA.stringof ~ " annotated method " ~ name
						~ " has more than one overload, which is not supported. Please rename.");
					alias symbol = symbols[0];
					static if (isSomeFunction!(symbol) && __traits(getProtection, symbol) == "public")
					{
						// UDA symbols (e.g. `enum initializeHook;`) have no value
						// to extract, so `null` is passed instead.
						static if (__traits(compiles, { enum uda = getUDAs!(symbol, UDA)[0]; }))
							enum uda = getUDAs!(symbol, UDA)[0];
						else
							enum uda = null;

						static if (returnFirst)
						{
							if (callback(name, &symbol, uda))
								return true;
						}
						else
						{
							if (callback(name, &symbol, uda))
								found = true;
						}
					}
				}
			}
		}

		return found;
	}

	/// Deserializes the JSON text `params` of a request or notification into `T`.
	///
	/// - A JSON array (positional parameters) is deserialized from its single
	///   element. More than one element is rejected.
	/// - A JSON object (named parameters) is deserialized as `T` directly.
	/// - Anything else, e.g. empty `params`, yields `{}` deserialized as `T`
	///   for structs, or `T.init` otherwise.
	///
	/// Throws: `MethodException` with `ErrorCode.invalidParams` if any of the
	///     above throws an `Exception`.
	private T implParseParam(T)(string params)
	{
		import served.lsp.protocol;

		try
		{
			if (params.length && params[0] == '[')
			{
				// positional parameter support
				// only supports passing a single argument
				string got;
				params.visitJsonArray!((item) {
					if (!got.length)
						got = item;
					else
						throw new Exception("Mismatched parameter count");
				});
				return got.deserializeJson!T;
			}
			else if (params.length && params[0] == '{')
			{
				// named parameter support
				// only supports passing structs (not parsing names of D method arguments)
				return params.deserializeJson!T;
			}
			else
			{
				// no parameters passed - parse empty JSON for the type or
				// use default value.
				static if (is(T == struct))
					return `{}`.deserializeJson!T;
				else
					return T.init;
			}
		}
		catch (Exception e)
		{
			ResponseError error;
			error.code = ErrorCode.invalidParams;
			error.message = "Failed converting input parameter `" ~ params ~ "` to needed type `" ~ T.stringof ~ "`: " ~ e.msg;
			error.data = JsonValue(e.toString);
			throw new MethodException(error);
		}
	}

}