diff --git a/BililiveRecorder.Web/BasicAuthMiddleware.cs b/BililiveRecorder.Web/BasicAuthMiddleware.cs index fef69c9..fe723fa 100644 --- a/BililiveRecorder.Web/BasicAuthMiddleware.cs +++ b/BililiveRecorder.Web/BasicAuthMiddleware.cs @@ -2,6 +2,7 @@ using System; using System.IO; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.FileProviders; using Microsoft.Net.Http.Headers; @@ -11,21 +12,19 @@ namespace BililiveRecorder.Web { private readonly RequestDelegate next; private readonly ManifestEmbeddedFileProvider fileProvider; - private readonly BasicAuthCredential? credential; private const string BasicAndSpace = "Basic "; private static string? Html401Page; - public BasicAuthMiddleware(RequestDelegate next, ManifestEmbeddedFileProvider fileProvider, BasicAuthCredential? credential) + public BasicAuthMiddleware(RequestDelegate next, ManifestEmbeddedFileProvider fileProvider) { this.next = next ?? throw new ArgumentNullException(nameof(next)); this.fileProvider = fileProvider ?? throw new ArgumentNullException(nameof(fileProvider)); - this.credential = credential; } public Task InvokeAsync(HttpContext context) { - if (this.credential is null) + if (context.RequestServices.GetService() is not { } credential) { // 没有启用身份验证 return this.next(context); @@ -45,7 +44,7 @@ namespace BililiveRecorder.Web return this.ResponseWith401Async(context); } - if (this.credential.EncoededValue.Equals(requestCredential, StringComparison.Ordinal)) + if (credential.EncoededValue.Equals(requestCredential, StringComparison.Ordinal)) { return this.next(context); }