Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use bitcoin::consensus::encode::{deserialize, serialize};
use bitcoin::{block, Script, Transaction, Txid};

use crate::batch::Batch;
use crate::raw_client::{CLIENT_NAME, PROTOCOL_VERSION_MAX, PROTOCOL_VERSION_MIN};
use crate::types::*;

impl<E: Deref> ElectrumApi for E
Expand Down Expand Up @@ -181,6 +182,10 @@ where
(**self).server_features()
}

fn protocol_version(&self) -> Result<String, Error> {
(**self).protocol_version()
}

fn mempool_get_info(&self) -> Result<MempoolInfoRes, Error> {
(**self).mempool_get_info()
}
Expand Down Expand Up @@ -444,6 +449,29 @@ pub trait ElectrumApi {
/// Returns the capabilities of the server.
fn server_features(&self) -> Result<ServerFeaturesRes, Error>;

/// Returns the negotiated Electrum protocol version.
///
/// Clients that already negotiated a protocol version during connection setup
/// should return that cached value. Implementors that do not cache it can use
/// this default implementation, which retrieves the version with
/// `server.version`.
fn protocol_version(&self) -> Result<String, Error> {
let version_range = vec![
PROTOCOL_VERSION_MIN.to_string(),
PROTOCOL_VERSION_MAX.to_string(),
];
let result = self.raw_call(
"server.version",
vec![
Param::String(CLIENT_NAME.to_string()),
Param::StringVec(version_range),
],
)?;
let response: ServerVersionRes = serde_json::from_value(result)?;

Ok(response.protocol_version)
}

/// Returns information about the current state of the mempool.
///
/// This method was added in protocol v1.6 and replaces `relay_fee` by providing
Expand Down
5 changes: 5 additions & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,11 @@ impl ElectrumApi for Client {
impl_inner_call!(self, server_features)
}

#[inline]
fn protocol_version(&self) -> Result<String, Error> {
impl_inner_call!(self, protocol_version)
}

#[inline]
fn mempool_get_info(&self) -> Result<MempoolInfoRes, Error> {
impl_inner_call!(self, mempool_get_info)
Expand Down
118 changes: 114 additions & 4 deletions src/raw_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,13 @@ impl<S: Read + Write> RawClient<S> {
///
/// [`ClientType`]: crate::ClientType
fn negotiate_protocol_version(self) -> Result<Self, Error> {
let response = self.request_server_version()?;

self.cache_protocol_version(response.protocol_version)?;
Ok(self)
}

fn request_server_version(&self) -> Result<ServerVersionRes, Error> {
let version_range = vec![
PROTOCOL_VERSION_MIN.to_string(),
PROTOCOL_VERSION_MAX.to_string(),
Expand All @@ -653,10 +660,14 @@ impl<S: Read + Write> RawClient<S> {
],
);
let result = self.call(req)?;
let response: ServerVersionRes = serde_json::from_value(result)?;

*self.protocol_version.lock()? = Some(response.protocol_version);
Ok(self)
Ok(serde_json::from_value(result)?)
}

fn cache_protocol_version(&self, protocol_version: String) -> Result<String, Error> {
*self.protocol_version.lock()? = Some(protocol_version.clone());

Ok(protocol_version)
}

fn _reader_thread(&self, until_message: Option<usize>) -> Result<serde_json::Value, Error> {
Expand Down Expand Up @@ -1378,6 +1389,16 @@ impl<T: Read + Write> ElectrumApi for RawClient<T> {
Ok(serde_json::from_value(result)?)
}

fn protocol_version(&self) -> Result<String, Error> {
if let Some(protocol_version) = self.protocol_version.lock()?.clone() {
return Ok(protocol_version);
}

let response = self.request_server_version()?;

self.cache_protocol_version(response.protocol_version)
}

fn mempool_get_info(&self) -> Result<MempoolInfoRes, Error> {
let req = Request::new_id(
self.last_id.fetch_add(1, Ordering::SeqCst),
Expand Down Expand Up @@ -1407,7 +1428,11 @@ impl<T: Read + Write> ElectrumApi for RawClient<T> {

#[cfg(test)]
mod test {
use std::str::FromStr;
use std::{
io::{self, Cursor, Read, Write},
str::FromStr,
sync::{Arc, Mutex},
};

use crate::utils;

Expand All @@ -1421,6 +1446,60 @@ mod test {
// here's an useful list of live servers: https://1209k.com/bitcoin-eye/ele.php.
const DEFAULT_TEST_ELECTRUM_SERVER: &str = "fortress.qtornado.com:443";

#[derive(Clone)]
struct MockStream {
responses: Arc<Mutex<Cursor<Vec<u8>>>>,
requests: Arc<Mutex<Vec<u8>>>,
}

impl MockStream {
fn new(responses: impl Into<Vec<u8>>) -> Self {
Self {
responses: Arc::new(Mutex::new(Cursor::new(responses.into()))),
requests: Arc::new(Mutex::new(Vec::new())),
}
}

fn written_requests(&self) -> Vec<serde_json::Value> {
let requests = self.requests.lock().unwrap().clone();
let requests = String::from_utf8(requests).unwrap();

requests
.lines()
.map(|line| serde_json::from_str(line).unwrap())
.collect()
}
}

impl Read for MockStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.responses.lock().unwrap().read(buf)
}
}

impl Write for MockStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.requests.lock().unwrap().extend_from_slice(buf);

Ok(buf.len())
}

fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}

fn server_version_response(id: usize, protocol_version: &str) -> String {
format!(
r#"{{"jsonrpc":"2.0","id":{id},"result":["ElectrumX 1.18.0","{protocol_version}"]}}"#
) + "\n"
}

fn assert_server_version_request(request: &serde_json::Value) {
assert_eq!(request["method"], "server.version");
assert_eq!(request["params"], serde_json::json!(["", ["1.4", "1.6"]]));
}

fn get_test_auth_client(
authorization_provider: Option<AuthProvider>,
) -> RawClient<ElectrumSslStream> {
Expand All @@ -1439,6 +1518,37 @@ mod test {
.expect("should build the `RawClient` successfully!")
}

#[test]
fn test_protocol_version_returns_negotiated_version_without_new_request() {
let stream = MockStream::new(server_version_response(0, "1.6"));
let stream_handle = stream.clone();
let client = RawClient::from(stream)
.negotiate_protocol_version()
.unwrap();

assert_eq!(client.protocol_version().unwrap(), "1.6");
assert_eq!(client.calls_made().unwrap(), 1);

let requests = stream_handle.written_requests();
assert_eq!(requests.len(), 1);
assert_server_version_request(&requests[0]);
}

#[test]
fn test_protocol_version_fetches_and_caches_missing_version() {
let stream = MockStream::new(server_version_response(0, "1.6"));
let stream_handle = stream.clone();
let client = RawClient::from(stream);

assert_eq!(client.protocol_version().unwrap(), "1.6");
assert_eq!(client.protocol_version().unwrap(), "1.6");
assert_eq!(client.calls_made().unwrap(), 1);

let requests = stream_handle.written_requests();
assert_eq!(requests.len(), 1);
assert_server_version_request(&requests[0]);
}

#[test]
fn test_server_features_simple() {
let client = get_test_client();
Expand Down
Loading