diff --git a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java index 2c903ce5270..083c599d415 100644 --- a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java @@ -69,7 +69,7 @@ import org.springframework.security.messaging.util.matcher.MessageMatcher; import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher; import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher; -import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; +import org.springframework.security.messaging.web.csrf.XorCsrfChannelInterceptor; import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor; import org.springframework.util.AntPathMatcher; import org.springframework.util.Assert; @@ -364,7 +364,7 @@ else if (CSRF_HANDSHAKE_HANDLER_CLASSES.contains(beanClassName)) { ManagedList interceptors = new ManagedList(); interceptors.add(new RootBeanDefinition(SecurityContextChannelInterceptor.class)); if (!this.sameOriginDisabled) { - interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class)); + interceptors.add(new RootBeanDefinition(XorCsrfChannelInterceptor.class)); } interceptors.add(registry.getBeanDefinition(this.inboundSecurityInterceptorId)); BeanDefinition inboundChannel = registry.getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID); diff --git a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java index 6e999933a28..820d2364ae3 100644 --- a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java @@ -20,6 +20,7 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import java.util.Base64; import java.util.HashMap; import java.util.Map; import java.util.function.Supplier; @@ -97,6 +98,13 @@ public class WebSocketMessageBrokerConfigTests { private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests"; + /* + * Token format: "token" length random pad bytes + "token" (each byte UTF8 ^= 1). + */ + private static final byte[] XOR_CSRF_TOKEN_BYTES = new byte[] { 1, 1, 1, 1, 1, 117, 110, 106, 100, 111 }; + + private static final String XOR_CSRF_TOKEN_VALUE = Base64.getEncoder().encodeToString(XOR_CSRF_TOKEN_BYTES); + public final SpringTestContext spring = new SpringTestContext(this); @Autowired(required = false) @@ -125,7 +133,7 @@ public void sendWhenNoIdSpecifiedThenIntegratesWithClientInboundChannel() { public void sendWhenAnonymousMessageWithConnectMessageTypeThenPermitted() { this.spring.configLocations(xml("NoIdConfig")).autowire(); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); - headers.setNativeHeader(this.token.getHeaderName(), this.token.getToken()); + headers.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); this.clientInboundChannel.send(message("/permitAll", headers)); } @@ -197,7 +205,7 @@ public void sendWhenNoIdSpecifiedThenIntegratesWithAuthorizationManager() { public void sendWhenAnonymousMessageWithConnectMessageTypeThenAuthorizationManagerPermits() { this.spring.configLocations(xml("NoIdAuthorizationManager")).autowire(); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); - headers.setNativeHeader(this.token.getHeaderName(), this.token.getToken()); + headers.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); this.clientInboundChannel.send(message("/permitAll", headers)); }