Skip to content

Commit

Permalink
Merge pull request #270 from turbofish-org/grpc-client
Browse files Browse the repository at this point in the history
Accept Fn() instead of function pointers in `ibc::start_grpc`
  • Loading branch information
keppel authored Oct 11, 2024
2 parents fe48101 + 33953f9 commit 73a3a56
Showing 1 changed file with 45 additions and 46 deletions.
91 changes: 45 additions & 46 deletions src/ibc/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl From<crate::Error> for tonic::Status {
}

pub struct IbcClientService<C> {
pub ibc: fn() -> C,
pub ibc: C,
}

#[tonic::async_trait]
Expand All @@ -144,9 +144,8 @@ impl<C: Client<IbcContext> + 'static> ClientQuery for IbcClientService<C> {
&self,
_request: Request<QueryClientStatesRequest>,
) -> Result<Response<QueryClientStatesResponse>, Status> {
let ibc = (self.ibc)();
let res = QueryClientStatesResponse {
client_states: ibc.query(|ibc| ibc.query_client_states()).await?,
client_states: self.ibc.query(|ibc| ibc.query_client_states()).await?,
..Default::default()
};
Ok(Response::new(res))
Expand All @@ -162,9 +161,8 @@ impl<C: Client<IbcContext> + 'static> ClientQuery for IbcClientService<C> {
let revision_number = request.revision_number;
let revision_height = request.revision_height;

let ibc = (self.ibc)();

let consensus_state = ibc
let consensus_state = self
.ibc
.query(|ibc| {
ibc.query_consensus_state(
client_id.clone().into(),
Expand All @@ -187,15 +185,15 @@ impl<C: Client<IbcContext> + 'static> ClientQuery for IbcClientService<C> {
&self,
request: Request<QueryConsensusStatesRequest>,
) -> Result<Response<QueryConsensusStatesResponse>, Status> {
let ibc = (self.ibc)();
let client_id: ClientId = request
.into_inner()
.client_id
.parse()
.map_err(|_| Status::invalid_argument("Invalid client ID".to_string()))?;

let res = QueryConsensusStatesResponse {
consensus_states: ibc
consensus_states: self
.ibc
.query(|ibc| ibc.query_consensus_states(client_id.clone().into()))
.await?,
..Default::default()
Expand All @@ -217,9 +215,9 @@ impl<C: Client<IbcContext> + 'static> ClientQuery for IbcClientService<C> {
let request = request.into_inner();
let client_id = ClientId::from_str(&request.client_id)
.map_err(|_| Status::invalid_argument("Invalid client ID".to_string()))?;
let ibc = (self.ibc)();

let client_status = ibc
let client_status = self
.ibc
.query(|ibc| ibc.query_client_status(client_id.clone().into()))
.await?;

Expand Down Expand Up @@ -253,7 +251,7 @@ impl<C: Client<IbcContext> + 'static> ClientQuery for IbcClientService<C> {
}

pub struct IbcConnectionService<C> {
ibc: fn() -> C,
ibc: C,
}

#[tonic::async_trait]
Expand All @@ -262,12 +260,12 @@ impl<C: Client<IbcContext> + 'static> ConnectionQuery for IbcConnectionService<C
&self,
request: Request<QueryConnectionRequest>,
) -> Result<Response<QueryConnectionResponse>, Status> {
let ibc = (self.ibc)();
let conn_id = ConnectionId::from_str(&request.into_inner().connection_id)
.map_err(|_| Status::invalid_argument("Invalid connection ID".to_string()))?;

Ok(Response::new(QueryConnectionResponse {
connection: ibc
connection: self
.ibc
.query(|ibc| ibc.query_connection(conn_id.clone().into()))
.await?
.map(Into::into),
Expand All @@ -279,9 +277,8 @@ impl<C: Client<IbcContext> + 'static> ConnectionQuery for IbcConnectionService<C
&self,
_request: Request<QueryConnectionsRequest>,
) -> Result<Response<QueryConnectionsResponse>, Status> {
let ibc = (self.ibc)();
Ok(Response::new(QueryConnectionsResponse {
connections: ibc.query(|ibc| ibc.query_all_connections()).await?,
connections: self.ibc.query(|ibc| ibc.query_all_connections()).await?,
..Default::default()
}))
}
Expand All @@ -290,13 +287,13 @@ impl<C: Client<IbcContext> + 'static> ConnectionQuery for IbcConnectionService<C
&self,
request: Request<QueryClientConnectionsRequest>,
) -> Result<Response<QueryClientConnectionsResponse>, Status> {
let ibc = (self.ibc)();
let client_id: ClientId = request
.into_inner()
.client_id
.parse()
.map_err(|_| Status::invalid_argument("Invalid client ID".to_string()))?;
let connection_ids = ibc
let connection_ids = self
.ibc
.query(|ibc| ibc.query_client_connections(client_id.clone().into()))
.await?;

Expand Down Expand Up @@ -329,7 +326,7 @@ impl<C: Client<IbcContext> + 'static> ConnectionQuery for IbcConnectionService<C
}

pub struct IbcChannelService<C> {
ibc: fn() -> C,
ibc: C,
revision_number: u64,
}

Expand All @@ -339,7 +336,6 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
&self,
request: Request<QueryChannelRequest>,
) -> Result<Response<QueryChannelResponse>, Status> {
let ibc = (self.ibc)();
let request = request.into_inner();
let port_id = PortId::from_str(&request.port_id)
.map_err(|_| Status::invalid_argument("invalid port id"))?;
Expand All @@ -349,7 +345,8 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
let path = ChannelEndPath(port_id, channel_id);

Ok(Response::new(QueryChannelResponse {
channel: ibc
channel: self
.ibc
.query(|ibc| ibc.query_channel(path.clone().into()))
.await?,
..Default::default()
Expand All @@ -360,9 +357,9 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
&self,
_request: Request<QueryChannelsRequest>,
) -> Result<Response<QueryChannelsResponse>, Status> {
let ibc = (self.ibc)();
let revision_number = self.revision_number;
let (channels, height) = ibc
let (channels, height) = self
.ibc
.query(|ibc| Ok((ibc.query_all_channels()?, ibc.height)))
.await?;

Expand All @@ -380,12 +377,12 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
&self,
request: Request<QueryConnectionChannelsRequest>,
) -> Result<Response<QueryConnectionChannelsResponse>, Status> {
let ibc = (self.ibc)();
let revision_number = self.revision_number;
let conn_id = ConnectionId::from_str(&request.get_ref().connection)
.map_err(|_| Status::invalid_argument("invalid connection id"))?;

let (channels, height) = ibc
let (channels, height) = self
.ibc
.query(|ibc| {
Ok((
ibc.query_connection_channels(conn_id.clone().into())?,
Expand All @@ -407,30 +404,32 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
&self,
request: Request<QueryChannelClientStateRequest>,
) -> Result<Response<QueryChannelClientStateResponse>, Status> {
let ibc = (self.ibc)();
let request = request.into_inner();
let port_id = PortId::from_str(&request.port_id)
.map_err(|_| Status::invalid_argument("invalid port id"))?;
let channel_id = ChannelId::from_str(&request.channel_id)
.map_err(|_| Status::invalid_argument("invalid channel id"))?;

let path = ChannelEndPath(port_id, channel_id);
let channel: Channel = ibc
let channel: Channel = self
.ibc
.query(|ibc| Ok(ibc.query_channel(path.clone().into())))
.await??
.ok_or_else(|| Status::not_found("channel not found"))?;
let connection_id = channel
.connection_hops
.first()
.ok_or_else(|| Status::not_found("channel does not have a connection hop"))?;
let connection_end: ConnectionEnd = ibc
let connection_end: ConnectionEnd = self
.ibc
.query(|ibc| {
Ok(ibc.query_connection(ConnectionId::from_str(connection_id).unwrap().into()))
})
.await??
.ok_or_else(|| Status::not_found("connection not found"))?;
let client_id = connection_end.client_id();
let client_state = ibc
let client_state = self
.ibc
.query(|ibc| {
Ok(ibc
.clients
Expand Down Expand Up @@ -473,7 +472,6 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
&self,
request: Request<QueryPacketCommitmentsRequest>,
) -> Result<Response<QueryPacketCommitmentsResponse>, Status> {
let ibc = (self.ibc)();
let revision_number = self.revision_number;
let request = request.into_inner();
let port_id = PortId::from_str(&request.port_id)
Expand All @@ -483,7 +481,8 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {

let path = PortChannel::new(port_id, channel_id);

let (commitments, height) = ibc
let (commitments, height) = self
.ibc
.query(|ibc| Ok((ibc.query_packet_commitments(path.clone())?, ibc.height)))
.await?;

Expand All @@ -508,10 +507,10 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
.map_err(|_| Status::invalid_argument("invalid channel id"))?;
let sequence = Sequence::from(request.sequence);

let ibc = (self.ibc)();
let receipt_path = ReceiptPath::new(&port_id, &channel_id, sequence);

let receipt = ibc
let receipt = self
.ibc
.query(|ibc| Ok(ibc.get_packet_receipt(&receipt_path.clone())?))
.await;

Expand All @@ -535,7 +534,6 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
&self,
request: Request<QueryPacketAcknowledgementsRequest>,
) -> Result<Response<QueryPacketAcknowledgementsResponse>, Status> {
let ibc = (self.ibc)();
let revision_number = self.revision_number;
let request = request.into_inner();
let port_id = PortId::from_str(&request.port_id)
Expand All @@ -545,7 +543,8 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
let sequences = request.packet_commitment_sequences;

let path = PortChannel::new(port_id, channel_id);
let (acknowledgements, height) = ibc
let (acknowledgements, height) = self
.ibc
.query(|ibc| {
Ok((
ibc.query_packet_acks(sequences.clone().try_into().unwrap(), path.clone())?,
Expand All @@ -568,7 +567,6 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
&self,
request: Request<QueryUnreceivedPacketsRequest>,
) -> Result<Response<QueryUnreceivedPacketsResponse>, Status> {
let ibc = (self.ibc)();
let revision_number = self.revision_number;
let request = request.into_inner();
let port_id = PortId::from_str(&request.port_id)
Expand All @@ -578,7 +576,8 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
let sequences_to_check: Vec<u64> = request.packet_commitment_sequences;
let path = PortChannel::new(port_id, channel_id);

let (sequences, height) = ibc
let (sequences, height) = self
.ibc
.query(|ibc| {
Ok((
ibc.query_unreceived_packets(
Expand All @@ -603,7 +602,6 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
&self,
request: Request<QueryUnreceivedAcksRequest>,
) -> Result<Response<QueryUnreceivedAcksResponse>, Status> {
let ibc = (self.ibc)();
let revision_number = self.revision_number;
let request = request.into_inner();
let port_id = PortId::from_str(&request.port_id)
Expand All @@ -613,7 +611,8 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
let sequences_to_check: Vec<u64> = request.packet_ack_sequences;
let path = PortChannel::new(port_id, channel_id);

let (sequences, height) = ibc
let (sequences, height) = self
.ibc
.query(|ibc| {
Ok((
ibc.query_unreceived_acks(
Expand All @@ -638,14 +637,14 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
&self,
request: Request<QueryNextSequenceReceiveRequest>,
) -> Result<Response<QueryNextSequenceReceiveResponse>, Status> {
let ibc_client = (self.ibc)();
let request_inner = request.into_inner();
let port_id = PortId::from_str(&request_inner.port_id)
.map_err(|_| Status::invalid_argument("invalid port id"))?;
let channel_id = ChannelId::from_str(&request_inner.channel_id)
.map_err(|_| Status::invalid_argument("invalid channel id"))?;
let res = QueryNextSequenceReceiveResponse {
next_sequence_receive: ibc_client
next_sequence_receive: self
.ibc
.query(|ibc| {
ibc.query_next_sequence_receive(PortChannel::new(
port_id.clone(),
Expand All @@ -662,14 +661,14 @@ impl<C: Client<IbcContext> + 'static> ChannelQuery for IbcChannelService<C> {
&self,
request: Request<QueryNextSequenceSendRequest>,
) -> Result<Response<QueryNextSequenceSendResponse>, Status> {
let ibc_client = (self.ibc)();
let request_inner = request.into_inner();
let port_id = PortId::from_str(&request_inner.port_id)
.map_err(|_| Status::invalid_argument("invalid port id"))?;
let channel_id = ChannelId::from_str(&request_inner.channel_id)
.map_err(|_| Status::invalid_argument("invalid channel id"))?;
let res = QueryNextSequenceSendResponse {
next_sequence_send: ibc_client
next_sequence_send: self
.ibc
.query(|ibc| {
ibc.query_next_sequence_send(PortChannel::new(
port_id.clone(),
Expand Down Expand Up @@ -1157,20 +1156,20 @@ pub struct GrpcOpts {
}

/// Start the gRPC server.
pub async fn start_grpc<C: Client<IbcContext> + 'static>(client: fn() -> C, opts: &GrpcOpts) {
pub async fn start_grpc<C: Client<IbcContext> + 'static, F: Fn() -> C>(client: F, opts: &GrpcOpts) {
use tonic::transport::Server;
let auth_service = AuthQueryServer::new(AuthService {});
let bank_service = BankQueryServer::new(BankService {});
let staking_service = StakingQueryServer::new(StakingService {});
let ibc_client_service = ClientQueryServer::new(IbcClientService { ibc: client });
let ibc_connection_service = ConnectionQueryServer::new(IbcConnectionService { ibc: client });
let ibc_client_service = ClientQueryServer::new(IbcClientService { ibc: client() });
let ibc_connection_service = ConnectionQueryServer::new(IbcConnectionService { ibc: client() });
let revision_number = opts
.chain_id
.rsplit_once('-')
.map(|(_, n)| n.parse::<u64>().unwrap_or(0))
.unwrap_or(0);
let ibc_channel_service = ChannelQueryServer::new(IbcChannelService {
ibc: client,
ibc: client(),
revision_number,
});
let health_service = HealthServer::new(AppHealthService {});
Expand Down

0 comments on commit 73a3a56

Please sign in to comment.