Skip to content

Commit ec5b489

Browse files
committed
feat: 优化上下文传递
1 parent 5a3060b commit ec5b489

File tree

8 files changed

+122
-21
lines changed

8 files changed

+122
-21
lines changed

hsweb-authorization/hsweb-authorization-api/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,11 @@
5252
<artifactId>jakarta.servlet-api</artifactId>
5353
<optional>true</optional>
5454
</dependency>
55+
56+
<dependency>
57+
<groupId>io.micrometer</groupId>
58+
<artifactId>context-propagation</artifactId>
59+
<optional>true</optional>
60+
</dependency>
5561
</dependencies>
5662
</project>

hsweb-authorization/hsweb-authorization-api/src/main/java/org/hswebframework/web/authorization/AuthenticationHolder.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,18 @@ public static void addSupplier(AuthenticationSupplier supplier) {
109109
}
110110
}
111111

112+
public static void resetCurrent() {
113+
CURRENT.remove();
114+
}
115+
116+
public static void makeCurrent(Authentication authentication) {
117+
if (authentication == null) {
118+
resetCurrent();
119+
} else {
120+
CURRENT.set(authentication);
121+
}
122+
}
123+
112124
/**
113125
* 指定用户权限,执行一个任务。任务执行过程中可通过 {@link Authentication#current()}获取到当前权限.
114126
*
@@ -119,7 +131,7 @@ public static void addSupplier(AuthenticationSupplier supplier) {
119131
*/
120132
@SneakyThrows
121133
public static <T> T executeWith(Authentication current, Callable<T> callable) {
122-
Authentication previous = CURRENT.get();
134+
Authentication previous = CURRENT.getIfExists();
123135
try {
124136
CURRENT.set(current);
125137
return callable.call();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package org.hswebframework.web.authorization.context;
2+
3+
import io.micrometer.context.ThreadLocalAccessor;
4+
import org.hswebframework.web.authorization.Authentication;
5+
import org.hswebframework.web.authorization.AuthenticationHolder;
6+
7+
import javax.annotation.Nonnull;
8+
9+
public class AuthenticationThreadLocalAccessor implements ThreadLocalAccessor<Authentication> {
10+
11+
static final String KEY = "cp.hs.auth";
12+
13+
@Override
14+
@Nonnull
15+
public Object key() {
16+
return KEY;
17+
}
18+
19+
@Override
20+
public Authentication getValue() {
21+
return AuthenticationHolder.get().orElse(null);
22+
}
23+
24+
@Override
25+
public void setValue() {
26+
AuthenticationHolder.resetCurrent();
27+
}
28+
29+
@Override
30+
public void setValue(@Nonnull Authentication value) {
31+
AuthenticationHolder.makeCurrent(value);
32+
}
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
org.hswebframework.web.authorization.context.AuthenticationThreadLocalAccessor
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package org.hswebframework.web.authorization.context;
2+
3+
import org.hswebframework.web.authorization.Authentication;
4+
import org.hswebframework.web.authorization.AuthenticationHolder;
5+
import org.hswebframework.web.authorization.simple.SimpleAuthentication;
6+
import org.junit.jupiter.api.Test;
7+
import reactor.core.publisher.Hooks;
8+
import reactor.core.publisher.Mono;
9+
import reactor.core.scheduler.Schedulers;
10+
11+
import static org.junit.jupiter.api.Assertions.*;
12+
13+
class AuthenticationThreadLocalAccessorTest {
14+
15+
16+
@Test
17+
void testReadFromReactive() {
18+
19+
Hooks.enableAutomaticContextPropagation();
20+
21+
Authentication auth = new SimpleAuthentication();
22+
23+
Authentication auth2 = AuthenticationHolder.executeWith(
24+
auth,
25+
() -> Mono
26+
.fromCallable(() -> {
27+
// cross context
28+
return Authentication.current().orElse(null);
29+
})
30+
.subscribeOn(Schedulers.boundedElastic())
31+
.contextWrite(c->c)
32+
.block());
33+
34+
assertEquals(auth, auth2);
35+
36+
37+
}
38+
}

hsweb-authorization/hsweb-authorization-basic/src/main/java/org/hswebframework/web/authorization/basic/web/WebUserTokenInterceptor.java

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,21 @@ public WebUserTokenInterceptor(UserTokenManager userTokenManager,
3939
this.parser = definitionParser;
4040

4141
enableBasicAuthorization = userTokenParser
42-
.stream()
43-
.filter(UserTokenForTypeParser.class::isInstance)
44-
.anyMatch(parser -> "basic".equalsIgnoreCase(((UserTokenForTypeParser) parser).getTokenType()));
42+
.stream()
43+
.filter(UserTokenForTypeParser.class::isInstance)
44+
.anyMatch(parser -> "basic".equalsIgnoreCase(((UserTokenForTypeParser) parser).getTokenType()));
4545
}
4646

4747
@Override
4848
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
4949
List<ParsedToken> tokens = userTokenParser
50-
.stream()
51-
.map(parser -> parser.parseToken(request))
52-
.filter(Objects::nonNull)
53-
.collect(Collectors.toList());
50+
.stream()
51+
.map(parser -> parser.parseToken(request))
52+
.filter(Objects::nonNull)
53+
.toList();
5454

5555
if (tokens.isEmpty()) {
56-
if (enableBasicAuthorization && handler instanceof HandlerMethod) {
57-
HandlerMethod method = ((HandlerMethod) handler);
56+
if (enableBasicAuthorization && handler instanceof HandlerMethod method) {
5857
AuthorizeDefinition definition = parser.parse(method.getBeanType(), method.getMethod());
5958
if (null != definition) {
6059
response.addHeader("WWW-Authenticate", " Basic realm=\"\"");
@@ -69,12 +68,18 @@ public boolean preHandle(HttpServletRequest request, HttpServletResponse respons
6968
userToken = userTokenManager.getByToken(token).blockOptional().orElse(null);
7069
}
7170
if ((userToken == null || userToken.isExpired()) && parsedToken instanceof AuthorizedToken) {
72-
//先踢出旧token
73-
userTokenManager.signOutByToken(token).subscribe();
71+
userToken =
72+
userTokenManager
73+
.signOutByToken(token)
74+
.then(
75+
userTokenManager
76+
.signIn(parsedToken.getToken(),
77+
parsedToken.getType(),
78+
((AuthorizedToken) parsedToken).getUserId(),
79+
((AuthorizedToken) parsedToken)
80+
.getMaxInactiveInterval())
81+
)
7482

75-
userToken = userTokenManager
76-
.signIn(parsedToken.getToken(), parsedToken.getType(), ((AuthorizedToken) parsedToken).getUserId(), ((AuthorizedToken) parsedToken)
77-
.getMaxInactiveInterval())
7883
.block();
7984
}
8085
if (null != userToken) {
@@ -85,4 +90,8 @@ public boolean preHandle(HttpServletRequest request, HttpServletResponse respons
8590
return true;
8691
}
8792

93+
@Override
94+
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
95+
UserTokenHolder.setCurrent(null);
96+
}
8897
}

hsweb-core/src/main/java/org/hswebframework/web/context/ContextHolder.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,11 @@ public static <T> T doInContext(Context context, Callable<T> call) {
4848
}
4949
}
5050

51-
public static <T> Mono<T> currentReactive(Function<ContextView, Mono<T>> handler) {
52-
return Mono.deferContextual(ctx -> handler.apply(current().putAll(ctx)));
51+
public static <T> Mono<T> wrap(Function<ContextView, Mono<T>> handler) {
52+
return Mono.deferContextual(ctx -> {
53+
Context context = current().putAll(ctx);
54+
return handler.apply(context);
55+
});
5356
}
5457

5558
public static Context current() {

hsweb-core/src/main/java/org/hswebframework/web/i18n/LocaleUtils.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import org.reactivestreams.Publisher;
88
import org.reactivestreams.Subscription;
99
import org.springframework.context.MessageSource;
10+
import org.springframework.context.i18n.LocaleContextHolder;
1011
import reactor.core.CoreSubscriber;
1112
import reactor.core.publisher.*;
1213
import reactor.util.context.Context;
@@ -93,10 +94,8 @@ public static String getMessage(Function<String, String> messageSource,
9394
*/
9495
public static Locale current() {
9596
Locale locale = CONTEXT_THREAD_LOCAL.get();
96-
if (locale == null) {
97-
locale = DEFAULT_LOCALE;
98-
}
99-
return locale;
97+
// fallback to spring
98+
return Objects.requireNonNullElseGet(locale, LocaleContextHolder::getLocale);
10099
}
101100

102101
/**

0 commit comments

Comments
 (0)