31
31
import com .github .shyiko .mysql .binlog .io .ByteArrayInputStream ;
32
32
import com .github .shyiko .mysql .binlog .jmx .BinaryLogClientMXBean ;
33
33
import com .github .shyiko .mysql .binlog .network .AuthenticationException ;
34
+ import com .github .shyiko .mysql .binlog .network .ClientCapabilities ;
35
+ import com .github .shyiko .mysql .binlog .network .DefaultSSLSocketFactory ;
36
+ import com .github .shyiko .mysql .binlog .network .SSLMode ;
37
+ import com .github .shyiko .mysql .binlog .network .SSLSocketFactory ;
34
38
import com .github .shyiko .mysql .binlog .network .ServerException ;
35
39
import com .github .shyiko .mysql .binlog .network .SocketFactory ;
40
+ import com .github .shyiko .mysql .binlog .network .TLSHostnameVerifier ;
36
41
import com .github .shyiko .mysql .binlog .network .protocol .ErrorPacket ;
37
42
import com .github .shyiko .mysql .binlog .network .protocol .GreetingPacket ;
38
43
import com .github .shyiko .mysql .binlog .network .protocol .Packet ;
44
49
import com .github .shyiko .mysql .binlog .network .protocol .command .DumpBinaryLogGtidCommand ;
45
50
import com .github .shyiko .mysql .binlog .network .protocol .command .PingCommand ;
46
51
import com .github .shyiko .mysql .binlog .network .protocol .command .QueryCommand ;
52
+ import com .github .shyiko .mysql .binlog .network .protocol .command .SSLRequestCommand ;
47
53
54
+ import javax .net .ssl .SSLContext ;
55
+ import javax .net .ssl .TrustManager ;
56
+ import javax .net .ssl .X509TrustManager ;
48
57
import java .io .EOFException ;
49
58
import java .io .IOException ;
50
59
import java .net .InetSocketAddress ;
51
60
import java .net .Socket ;
52
61
import java .net .SocketException ;
62
+ import java .security .GeneralSecurityException ;
63
+ import java .security .cert .CertificateException ;
64
+ import java .security .cert .X509Certificate ;
53
65
import java .util .Arrays ;
54
66
import java .util .Collections ;
55
67
import java .util .Iterator ;
74
86
*/
75
87
public class BinaryLogClient implements BinaryLogClientMXBean {
76
88
89
+ private static final SSLSocketFactory DEFAULT_REQUIRED_SSL_MODE_SOCKET_FACTORY = new DefaultSSLSocketFactory () {
90
+
91
+ @ Override
92
+ protected void initSSLContext (SSLContext sc ) throws GeneralSecurityException {
93
+ sc .init (null , new TrustManager []{
94
+ new X509TrustManager () {
95
+
96
+ @ Override
97
+ public void checkClientTrusted (X509Certificate [] x509Certificates , String s )
98
+ throws CertificateException { }
99
+
100
+ @ Override
101
+ public void checkServerTrusted (X509Certificate [] x509Certificates , String s )
102
+ throws CertificateException { }
103
+
104
+ @ Override
105
+ public X509Certificate [] getAcceptedIssuers () {
106
+ return new X509Certificate [0 ];
107
+ }
108
+ }
109
+ }, null );
110
+ }
111
+ };
112
+ private static final SSLSocketFactory DEFAULT_VERIFY_CA_SSL_MODE_SOCKET_FACTORY = new DefaultSSLSocketFactory ();
113
+
77
114
// https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
78
115
private static final int MAX_PACKET_LENGTH = 16777215 ;
79
116
@@ -90,6 +127,7 @@ public class BinaryLogClient implements BinaryLogClientMXBean {
90
127
private volatile String binlogFilename ;
91
128
private volatile long binlogPosition = 4 ;
92
129
private volatile long connectionId ;
130
+ private SSLMode sslMode = SSLMode .DISABLED ;
93
131
94
132
private GtidSet gtidSet ;
95
133
private final Object gtidSetAccessLock = new Object ();
@@ -100,6 +138,7 @@ public class BinaryLogClient implements BinaryLogClientMXBean {
100
138
private final List <LifecycleListener > lifecycleListeners = new LinkedList <LifecycleListener >();
101
139
102
140
private SocketFactory socketFactory ;
141
+ private SSLSocketFactory sslSocketFactory ;
103
142
104
143
private PacketChannel channel ;
105
144
private volatile boolean connected ;
@@ -166,6 +205,17 @@ public void setBlocking(boolean blocking) {
166
205
this .blocking = blocking ;
167
206
}
168
207
208
+ public SSLMode getSSLMode () {
209
+ return sslMode ;
210
+ }
211
+
212
+ public void setSSLMode (SSLMode sslMode ) {
213
+ if (sslMode == null ) {
214
+ throw new IllegalArgumentException ("SSL mode cannot be NULL" );
215
+ }
216
+ this .sslMode = sslMode ;
217
+ }
218
+
169
219
/**
170
220
* @return server id (65535 by default)
171
221
* @see #setServerId(long)
@@ -326,6 +376,13 @@ public void setSocketFactory(SocketFactory socketFactory) {
326
376
this .socketFactory = socketFactory ;
327
377
}
328
378
379
+ /**
380
+ * @param sslSocketFactory custom ssl socket factory
381
+ */
382
+ public void setSslSocketFactory (SSLSocketFactory sslSocketFactory ) {
383
+ this .sslSocketFactory = sslSocketFactory ;
384
+ }
385
+
329
386
/**
330
387
* @param threadFactory custom thread factory. If not provided, threads will be created using simple "new Thread()".
331
388
*/
@@ -357,7 +414,7 @@ public void connect() throws IOException {
357
414
". Please make sure it's running." , e );
358
415
}
359
416
greetingPacket = receiveGreeting ();
360
- authenticate (greetingPacket . getScramble (), greetingPacket . getServerCollation () );
417
+ authenticate (greetingPacket );
361
418
if (binlogFilename == null ) {
362
419
fetchBinlogFilenameAndPosition ();
363
420
}
@@ -446,10 +503,30 @@ private void ensureEventDataDeserializer(EventType eventType,
446
503
}
447
504
}
448
505
449
- private void authenticate (String salt , int collation ) throws IOException {
450
- AuthenticateCommand authenticateCommand = new AuthenticateCommand (schema , username , password , salt );
506
+ private void authenticate (GreetingPacket greetingPacket ) throws IOException {
507
+ int collation = greetingPacket .getServerCollation ();
508
+ int packetNumber = 1 ;
509
+ if (sslMode != SSLMode .DISABLED ) {
510
+ boolean serverSupportsSSL = (greetingPacket .getServerCapabilities () & ClientCapabilities .SSL ) != 0 ;
511
+ if (!serverSupportsSSL && (sslMode == SSLMode .REQUIRED || sslMode == SSLMode .VERIFY_CA ||
512
+ sslMode == SSLMode .VERIFY_IDENTITY )) {
513
+ throw new IOException ("MySQL server does not support SSL" );
514
+ }
515
+ if (serverSupportsSSL ) {
516
+ SSLRequestCommand sslRequestCommand = new SSLRequestCommand ();
517
+ sslRequestCommand .setCollation (collation );
518
+ channel .write (sslRequestCommand , packetNumber ++);
519
+ SSLSocketFactory sslSocketFactory = this .sslSocketFactory != null ? this .sslSocketFactory :
520
+ sslMode == SSLMode .REQUIRED ? DEFAULT_REQUIRED_SSL_MODE_SOCKET_FACTORY :
521
+ DEFAULT_VERIFY_CA_SSL_MODE_SOCKET_FACTORY ;
522
+ channel .upgradeToSSL (sslSocketFactory ,
523
+ sslMode == SSLMode .VERIFY_IDENTITY ? new TLSHostnameVerifier () : null );
524
+ }
525
+ }
526
+ AuthenticateCommand authenticateCommand = new AuthenticateCommand (schema , username , password ,
527
+ greetingPacket .getScramble ());
451
528
authenticateCommand .setCollation (collation );
452
- channel .write (authenticateCommand );
529
+ channel .write (authenticateCommand , packetNumber );
453
530
byte [] authenticationResult = channel .read ();
454
531
if (authenticationResult [0 ] != (byte ) 0x00 /* ok */ ) {
455
532
if (authenticationResult [0 ] == (byte ) 0xFF /* error */ ) {
0 commit comments