Skip to content

Commit

Permalink
Support custom protocol between proxy and broker (#588)
Browse files Browse the repository at this point in the history
  • Loading branch information
Technoboy- authored Apr 18, 2022
1 parent 80c36a6 commit ff62db2
Show file tree
Hide file tree
Showing 39 changed files with 1,226 additions and 334 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.handler.codec.mqtt.MqttMessageBuilders;
import io.netty.handler.timeout.IdleStateHandler;
import io.streamnative.pulsar.handlers.mqtt.adapter.MqttAdapterMessage;
import io.streamnative.pulsar.handlers.mqtt.exception.restrictions.InvalidSessionExpireIntervalException;
import io.streamnative.pulsar.handlers.mqtt.messages.codes.mqtt5.Mqtt5DisConnReasonCode;
import io.streamnative.pulsar.handlers.mqtt.restrictions.ClientRestrictions;
Expand Down Expand Up @@ -74,9 +75,13 @@ public class Connection {
private final ClientRestrictions clientRestrictions;
@Getter
private final ServerRestrictions serverRestrictions;
volatile ConnectionState connectionState = DISCONNECTED;
@Getter
private volatile int serverCurrentReceiveCounter = 0;
@Getter
private final ProtocolMethodProcessor processor;
@Getter
private final boolean adapter;
private volatile ConnectionState connectionState = DISCONNECTED;
private final PulsarEventCenter eventCenter;
private final List<PulsarEventListener> listeners;

Expand All @@ -98,11 +103,13 @@ public class Connection {
this.channel = builder.channel;
this.manager = builder.connectionManager;
this.connectMessage = builder.connectMessage;
this.ackHandler = AckHandlerFactory.of(protocolVersion).getAckHandler();
this.ackHandler = AckHandlerFactory.newAckHandler(this);
this.channel.attr(ATTR_KEY_CONNECTION).set(this);
this.topicSubscriptionManager = new TopicSubscriptionManager();
this.addIdleStateHandler();
this.eventCenter = builder.eventCenter;
this.processor = builder.processor;
this.adapter = builder.adapter;
this.listeners = Collections.synchronizedList(new ArrayList<>());
this.manager.addConnection(this);
}
Expand All @@ -117,41 +124,57 @@ private void addIdleStateHandler() {
}

public ChannelFuture sendConnAck() {
return ackHandler.sendConnAck(this);
return ackHandler.sendConnAck();
}

public ChannelFuture send(MqttMessage mqttMessage) {
public ChannelFuture send(MqttAdapterMessage adapterMessage) {
adapterMessage.setAdapter(isAdapter());
if (!channel.isActive()) {
log.error("send mqttMessage : {} failed due to channel is inactive.", mqttMessage);
log.error("send mqttMessage : {} failed due to channel is inactive.", adapterMessage);
return channel.newFailedFuture(channelInactiveException);
}
return channel.writeAndFlush(mqttMessage).addListener(future -> {
return channel.writeAndFlush(adapterMessage).addListener(future -> {
if (!future.isSuccess()) {
log.error("send mqttMessage : {} failed", mqttMessage, future.cause());
log.error("send mqttMessage : {} failed", adapterMessage, future.cause());
}
});
}

public ChannelFuture sendThenClose(MqttMessage mqttMessage) {
public ChannelFuture sendThenClose(MqttAdapterMessage adapterMessage) {
adapterMessage.setAdapter(isAdapter());
if (!channel.isActive()) {
log.error("send mqttMessage : {} failed due to channel is inactive.", mqttMessage);
log.error("send mqttMessage : {} failed due to channel is inactive.", adapterMessage);
return channel.newFailedFuture(channelInactiveException);
}
channel.writeAndFlush(mqttMessage).addListener(future -> {
ChannelFuture channelFuture = channel.writeAndFlush(adapterMessage).addListener(future -> {
if (!future.isSuccess()) {
log.error("send mqttMessage : {} failed", mqttMessage, future.cause());
log.error("send mqttMessage : {} failed", adapterMessage, future.cause());
}
});
return channel.close();
if (isAdapter()) {
disconnect();
} else {
channel.close();
}
return channelFuture;
}

/**
* Broker send disconnect.
*/
public void disconnect() {
if (MqttUtils.isMqtt5(protocolVersion)) {
if (MqttUtils.isMqtt5(protocolVersion) || isAdapter()) {
MqttMessage mqttMessage = MqttMessageBuilders
.disconnect()
.reasonCode(Mqtt5DisConnReasonCode.SESSION_TAKEN_OVER.byteValue())
.build();
sendThenClose(mqttMessage);
MqttAdapterMessage adapterMsg = new MqttAdapterMessage(this.clientId, mqttMessage);
if (isAdapter()) {
send(adapterMsg);
processor.processConnectionLost();
} else {
sendThenClose(adapterMsg);
}
} else {
channel.close();
}
Expand Down Expand Up @@ -279,6 +302,8 @@ public static class ConnectionBuilder {
private ClientRestrictions clientRestrictions;
private ServerRestrictions serverRestrictions;
private PulsarEventCenter eventCenter;
private ProtocolMethodProcessor processor;
private boolean adapter;

public ConnectionBuilder protocolVersion(int protocolVersion) {
this.protocolVersion = protocolVersion;
Expand Down Expand Up @@ -320,6 +345,11 @@ public ConnectionBuilder connectionManager(MQTTConnectionManager connectionManag
return this;
}

public ConnectionBuilder processor(ProtocolMethodProcessor processor) {
this.processor = processor;
return this;
}

public ConnectionBuilder connectMessage(MqttConnectMessage connectMessage) {
this.connectMessage = connectMessage;
return this;
Expand All @@ -330,6 +360,11 @@ public ConnectionBuilder eventCenter(PulsarEventCenter eventCenter) {
return this;
}

public ConnectionBuilder adapter(boolean isAdapter) {
this.adapter = isAdapter;
return this;
}

public Connection build() {
return new Connection(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
public final class Constants {

public static final String DEFAULT_CLIENT_ID = "__MoPInternalClientId";
public static final String ATTR_CONNECTION = "Connection";
public static final String ATTR_CLIENT_ADDR = "ClientAddr";
public static final String AUTH_BASIC = "basic";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.mqtt.MqttDecoder;
import io.netty.handler.codec.mqtt.MqttEncoder;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.IdleStateHandler;
import io.streamnative.pulsar.handlers.mqtt.adapter.CombineAdapterHandler;
import io.streamnative.pulsar.handlers.mqtt.adapter.MqttAdapterDecoder;
import io.streamnative.pulsar.handlers.mqtt.adapter.MqttAdapterEncoder;
import io.streamnative.pulsar.handlers.mqtt.support.psk.PSKUtils;
import org.apache.pulsar.common.util.NettyServerSslContextBuilder;
import org.apache.pulsar.common.util.SslContextAutoRefreshBuilder;
Expand Down Expand Up @@ -94,8 +96,13 @@ public void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addLast(TLS_HANDLER,
new SslHandler(PSKUtils.createServerEngine(ch, mqttService.getPskConfiguration())));
}
ch.pipeline().addLast("decoder", new MqttDecoder(mqttConfig.getMqttMessageMaxLength()));
ch.pipeline().addLast("encoder", MqttEncoder.INSTANCE);
ch.pipeline().addLast("handler", new MQTTInboundHandler(mqttService));
// Decoder
ch.pipeline().addLast(MqttAdapterDecoder.NAME, new MqttAdapterDecoder());
ch.pipeline().addLast("mqtt-decoder", new MqttDecoder(mqttConfig.getMqttMessageMaxLength()));
// Encoder
ch.pipeline().addLast(MqttAdapterEncoder.NAME, MqttAdapterEncoder.INSTANCE);
// Handler
ch.pipeline().addLast(CombineAdapterHandler.NAME, new CombineAdapterHandler());
ch.pipeline().addLast(MQTTInboundHandler.NAME, new MQTTInboundHandler(mqttService));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,21 @@
package io.streamnative.pulsar.handlers.mqtt;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.checkState;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.handler.codec.mqtt.MqttMessageType;
import io.netty.handler.codec.mqtt.MqttPubAckMessage;
import io.netty.handler.codec.mqtt.MqttPublishMessage;
import io.netty.handler.codec.mqtt.MqttSubscribeMessage;
import io.netty.handler.codec.mqtt.MqttUnsubscribeMessage;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.ReferenceCountUtil;
import io.streamnative.pulsar.handlers.mqtt.adapter.MqttAdapterMessage;
import io.streamnative.pulsar.handlers.mqtt.support.MQTTBrokerProtocolMethodProcessor;
import io.streamnative.pulsar.handlers.mqtt.utils.NettyUtils;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;

/**
Expand All @@ -39,60 +38,69 @@
@Slf4j
public class MQTTCommonInboundHandler extends ChannelInboundHandlerAdapter {

protected ProtocolMethodProcessor processor;
public static final String NAME = "InboundHandler";

@Setter
protected MQTTService mqttService;

protected final ConcurrentHashMap<String, ProtocolMethodProcessor> processors = new ConcurrentHashMap<>();

@Override
public void channelRead(ChannelHandlerContext ctx, Object message) {
checkArgument(message instanceof MqttMessage);
checkNotNull(processor);
MqttMessage msg = (MqttMessage) message;
checkArgument(message instanceof MqttAdapterMessage);
MqttAdapterMessage adapterMsg = (MqttAdapterMessage) message;
MqttMessage mqttMessage = adapterMsg.getMqttMessage();
ProtocolMethodProcessor processor = processors.computeIfAbsent(adapterMsg.getClientId(), key -> {
MQTTBrokerProtocolMethodProcessor p = new MQTTBrokerProtocolMethodProcessor(mqttService, ctx);
CompletableFuture<Void> inactiveFuture = p.getInactiveFuture();
inactiveFuture.whenComplete((id, ex) -> {
processors.remove(adapterMsg.getClientId());
});
return p;
});
try {
checkState(msg);
MqttMessageType messageType = msg.fixedHeader().messageType();
checkState(mqttMessage);
MqttMessageType messageType = mqttMessage.fixedHeader().messageType();
if (log.isDebugEnabled()) {
log.debug("Processing MQTT Inbound handler message, type={}", messageType);
log.debug("Inbound handler read message : type={}, clientId : {} adapter : {}", messageType,
adapterMsg.getClientId(), adapterMsg.isAdapter());
}
switch (messageType) {
case CONNECT:
checkArgument(msg instanceof MqttConnectMessage);
processor.processConnect((MqttConnectMessage) msg);
processor.processConnect(adapterMsg);
break;
case SUBSCRIBE:
checkArgument(msg instanceof MqttSubscribeMessage);
processor.processSubscribe((MqttSubscribeMessage) msg);
processor.processSubscribe(adapterMsg);
break;
case UNSUBSCRIBE:
checkArgument(msg instanceof MqttUnsubscribeMessage);
processor.processUnSubscribe((MqttUnsubscribeMessage) msg);
processor.processUnSubscribe(adapterMsg);
break;
case PUBLISH:
checkArgument(msg instanceof MqttPublishMessage);
processor.processPublish((MqttPublishMessage) msg);
processor.processPublish(adapterMsg);
break;
case PUBREC:
processor.processPubRec(msg);
processor.processPubRec(adapterMsg);
break;
case PUBCOMP:
processor.processPubComp(msg);
processor.processPubComp(adapterMsg);
break;
case PUBREL:
processor.processPubRel(msg);
processor.processPubRel(adapterMsg);
break;
case DISCONNECT:
processor.processDisconnect(msg);
processor.processDisconnect(adapterMsg);
break;
case PUBACK:
checkArgument(msg instanceof MqttPubAckMessage);
processor.processPubAck((MqttPubAckMessage) msg);
processor.processPubAck(adapterMsg);
break;
case PINGREQ:
processor.processPingReq();
processor.processPingReq(adapterMsg);
break;
default:
throw new UnsupportedOperationException("Unknown MessageType: " + messageType);
}
} catch (Throwable ex) {
ReferenceCountUtil.safeRelease(msg);
ReferenceCountUtil.safeRelease(mqttMessage);
log.error("Exception was caught while processing MQTT message, ", ex);
ctx.close();
}
Expand All @@ -105,7 +113,7 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception {

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
processor.processConnectionLost();
processors.values().forEach(ProtocolMethodProcessor::processConnectionLost);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.streamnative.pulsar.handlers.mqtt.support.DefaultProtocolMethodProcessorImpl;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
/**
* MQTT in bound handler.
Expand All @@ -25,16 +23,14 @@
@Slf4j
public class MQTTInboundHandler extends MQTTCommonInboundHandler {

@Getter
private final MQTTService mqttService;
public static final String NAME = "handler";

public MQTTInboundHandler(MQTTService mqttService) {
this.mqttService = mqttService;
super.mqttService = mqttService;
}

@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
super.channelActive(ctx);
processor = new DefaultProtocolMethodProcessorImpl(mqttService, ctx);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ public boolean addSubscriptions(String clientId, List<MqttTopicSubscription> top
return duplicated;
}

public List<Pair<String, String>> findMatchTopic(String topic) {
/**
* Find the matched topic from the subscriptions.
* @param topic
* @return Pair with clientId, topicName.
*/
public List<Pair<String, String>> findMatchedTopic(String topic) {
List<Pair<String, String>> result = new ArrayList<>();
Set<Map.Entry<String, List<MqttTopicSubscription>>> entries = subscriptions.entrySet();
for (Map.Entry<String, List<MqttTopicSubscription>> entry : entries) {
Expand Down
Loading

0 comments on commit ff62db2

Please sign in to comment.