diff --git a/java/google/registry/proxy/handler/ProxyProtocolHandler.java b/java/google/registry/proxy/handler/ProxyProtocolHandler.java index dfa06c3f0..0856d1665 100644 --- a/java/google/registry/proxy/handler/ProxyProtocolHandler.java +++ b/java/google/registry/proxy/handler/ProxyProtocolHandler.java @@ -17,6 +17,7 @@ package google.registry.proxy.handler; import static com.google.common.base.Preconditions.checkState; import static java.nio.charset.StandardCharsets.US_ASCII; +import google.registry.util.FormattingLogger; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; @@ -55,6 +56,8 @@ public class ProxyProtocolHandler extends ByteToMessageDecoder { public static final AttributeKey REMOTE_ADDRESS_KEY = AttributeKey.valueOf("REMOTE_ADDRESS_KEY"); + private static final FormattingLogger logger = FormattingLogger.getLoggerForCallerClass(); + // The proxy header must start with this prefix. // Sample header: "PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n". private static final byte[] HEADER_PREFIX = "PROXY".getBytes(US_ASCII); @@ -69,20 +72,40 @@ public class ProxyProtocolHandler extends ByteToMessageDecoder { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { super.channelRead(ctx, msg); if (finished) { + String remoteIP; if (proxyHeader != null) { - ctx.channel().attr(REMOTE_ADDRESS_KEY).set(proxyHeader.split(" ")[2]); - } else { - SocketAddress remoteAddress = ctx.channel().remoteAddress(); - if (remoteAddress instanceof InetSocketAddress) { - ctx.channel() - .attr(REMOTE_ADDRESS_KEY) - .set(((InetSocketAddress) remoteAddress).getAddress().getHostAddress()); + logger.finefmt("PROXIED CONNECTION: %s", ctx.channel()); + logger.finefmt("PROXY HEADER: %s", proxyHeader); + String[] headerArray = proxyHeader.split(" ", -1); + if (headerArray.length == 6) { + remoteIP = headerArray[2]; + logger.finefmt("Header parsed, using %s as remote IP.", remoteIP); + } else { + logger.finefmt("Cannot parse the header, use source IP as a last resort."); + remoteIP = getSourceIP(ctx); } + } else { + logger.finefmt("No header present, using source IP directly."); + remoteIP = getSourceIP(ctx); } + if (remoteIP != null) { + ctx.channel().attr(REMOTE_ADDRESS_KEY).set(remoteIP); + } else { + logger.warningfmt("Not able to obtain remote IP for %s", ctx.channel()); + } + // ByteToMessageDecoder automatically flushes unread bytes in the ByteBuf to the next handler + // when itself is being removed. ctx.pipeline().remove(this); } } + private static String getSourceIP(ChannelHandlerContext ctx) { + SocketAddress remoteAddress = ctx.channel().remoteAddress(); + return (remoteAddress instanceof InetSocketAddress) + ? ((InetSocketAddress) remoteAddress).getAddress().getHostAddress() + : null; + } + /** * Attempts to decode an internally accumulated buffer and find the proxy protocol header. * @@ -115,7 +138,9 @@ public class ProxyProtocolHandler extends ByteToMessageDecoder { if (eol >= 0) { // ByteBuf.readBytes is called so that the header is processed and not passed to handlers // further in the pipeline. - proxyHeader = in.readBytes(eol).toString(US_ASCII); + byte[] headerBytes = new byte[eol]; + in.readBytes(headerBytes); + proxyHeader = new String(headerBytes, US_ASCII); // Skip \r\n. in.skipBytes(2); // Proxy header processed, mark finished so that this handler is removed. diff --git a/javatests/google/registry/proxy/handler/ProxyProtocolHandlerTest.java b/javatests/google/registry/proxy/handler/ProxyProtocolHandlerTest.java index 5e39419f4..0b93da26f 100644 --- a/javatests/google/registry/proxy/handler/ProxyProtocolHandlerTest.java +++ b/javatests/google/registry/proxy/handler/ProxyProtocolHandlerTest.java @@ -49,6 +49,20 @@ public class ProxyProtocolHandlerTest { assertThat(channel.isActive()).isTrue(); } + @Test + public void testSuccess_proxyHeaderMalformed_singleFrame() { + header = String.format("PROXY UNKNOWN\r\n"); + String message = "some message"; + // Header processed, rest of the message passed along. + assertThat(channel.writeInbound(Unpooled.wrappedBuffer((header + message).getBytes(UTF_8)))) + .isTrue(); + assertThat(((ByteBuf) channel.readInbound()).toString(UTF_8)).isEqualTo(message); + // Header malformed. + assertThat(channel.attr(REMOTE_ADDRESS_KEY).get()).isNull(); + assertThat(channel.pipeline().get(ProxyProtocolHandler.class)).isNull(); + assertThat(channel.isActive()).isTrue(); + } + @Test public void testSuccess_proxyHeaderPresent_multipleFrames() { header = String.format(HEADER_TEMPLATE, 4, "172.0.0.1", "255.255.255.255", "234", "123");