Skip to content

Commit

Permalink
fix: fix network switch dapp
Browse files Browse the repository at this point in the history
  • Loading branch information
salimtb committed Feb 7, 2025
1 parent 63afa20 commit 6875a8f
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ export function validateAddEthereumChainParams(params) {
* @param {Function} hooks.getCaveat - The callback to get the CAIP-25 caveat for the origin.
* @param {Function} hooks.requestPermittedChainsPermissionForOrigin - The callback to request a new permittedChains-equivalent CAIP-25 permission.
* @param {Function} hooks.requestPermittedChainsPermissionIncrementalForOrigin - The callback to add a new chain to the permittedChains-equivalent CAIP-25 permission.
* @param {Function} hooks.setTokenNetworkFilter - The callback to set the token network filter.
* @returns a null response on success or an error if user rejects an approval when autoApprove is false or on unexpected errors.
*/
export async function switchChain(
Expand All @@ -183,6 +184,7 @@ export async function switchChain(
getCaveat,
requestPermittedChainsPermissionForOrigin,
requestPermittedChainsPermissionIncrementalForOrigin,
setTokenNetworkFilter,
},
) {
try {
Expand All @@ -208,6 +210,7 @@ export async function switchChain(
}

await setActiveNetwork(networkClientId);
setTokenNetworkFilter(chainId);
response.result = null;
return end();
} catch (error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ describe('Ethereum Chain Utils', () => {
getCaveat: jest.fn(),
requestPermittedChainsPermissionForOrigin: jest.fn(),
requestPermittedChainsPermissionIncrementalForOrigin: jest.fn(),
setTokenNetworkFilter: jest.fn(),
};
const response: { result?: true } = {};
const switchChain = (chainId: Hex, networkClientId: string) =>
Expand Down Expand Up @@ -55,6 +56,7 @@ describe('Ethereum Chain Utils', () => {
await switchChain('0x1', 'mainnet');

expect(mocks.setActiveNetwork).toHaveBeenCalledWith('mainnet');
expect(mocks.setTokenNetworkFilter).toHaveBeenCalledWith('0x1');
});

it('should throw an error if the switch chain approval is rejected', async () => {
Expand Down Expand Up @@ -92,6 +94,7 @@ describe('Ethereum Chain Utils', () => {
mocks.requestPermittedChainsPermissionIncrementalForOrigin,
).toHaveBeenCalledWith({ chainId: '0x1', autoApprove: true });
expect(mocks.setActiveNetwork).toHaveBeenCalledWith('mainnet');
expect(mocks.setTokenNetworkFilter).toHaveBeenCalledWith('0x1');
});

it('requests permittedChains approval without autoApprove then switches to it if autoApprove: false', async () => {
Expand All @@ -110,6 +113,7 @@ describe('Ethereum Chain Utils', () => {
mocks.requestPermittedChainsPermissionIncrementalForOrigin,
).toHaveBeenCalledWith({ chainId: '0x1', autoApprove: false });
expect(mocks.setActiveNetwork).toHaveBeenCalledWith('mainnet');
expect(mocks.setTokenNetworkFilter).toHaveBeenCalledWith('0x1');
});

it('should throw errors if the permittedChains grant fails', async () => {
Expand Down Expand Up @@ -176,6 +180,7 @@ describe('Ethereum Chain Utils', () => {
await switchChain('0x1', 'mainnet');

expect(mocks.setActiveNetwork).not.toHaveBeenCalled();
expect(mocks.setTokenNetworkFilter).not.toHaveBeenCalled();
});

it('return error about not being able to switch chain', async () => {
Expand Down Expand Up @@ -246,6 +251,7 @@ describe('Ethereum Chain Utils', () => {
await switchChain('0x1', 'mainnet');

expect(mocks.setActiveNetwork).toHaveBeenCalledWith('mainnet');
expect(mocks.setTokenNetworkFilter).toHaveBeenCalledWith('0x1');
});
},
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const switchEthereumChain = {
getCurrentChainIdForDomain: true,
requestPermittedChainsPermissionForOrigin: true,
requestPermittedChainsPermissionIncrementalForOrigin: true,
setTokenNetworkFilter: true,
},
};

Expand All @@ -32,6 +33,7 @@ async function switchEthereumChainHandler(
getCurrentChainIdForDomain,
requestPermittedChainsPermissionForOrigin,
requestPermittedChainsPermissionIncrementalForOrigin,
setTokenNetworkFilter,
},
) {
let chainId;
Expand Down Expand Up @@ -69,5 +71,6 @@ async function switchEthereumChainHandler(
getCaveat,
requestPermittedChainsPermissionForOrigin,
requestPermittedChainsPermissionIncrementalForOrigin,
setTokenNetworkFilter,
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ jest.mock('./ethereum-chain-utils', () => ({
...jest.requireActual('./ethereum-chain-utils'),
validateSwitchEthereumChainParams: jest.fn(),
switchChain: jest.fn(),
setTokenNetworkFilter: jest.fn(),
}));

const NON_INFURA_CHAIN_ID = '0x123456789';
Expand Down Expand Up @@ -46,6 +47,7 @@ const createMockedHandler = () => {
getCurrentChainIdForDomain: jest.fn().mockReturnValue(NON_INFURA_CHAIN_ID),
requestPermittedChainsPermissionForOrigin: jest.fn(),
requestPermittedChainsPermissionIncrementalForOrigin: jest.fn(),
setTokenNetworkFilter: jest.fn(),
};
const response = {};
const handler = (request) =>
Expand Down Expand Up @@ -171,6 +173,7 @@ describe('switchEthereumChainHandler', () => {
mocks.requestPermittedChainsPermissionForOrigin,
requestPermittedChainsPermissionIncrementalForOrigin:
mocks.requestPermittedChainsPermissionIncrementalForOrigin,
setTokenNetworkFilter: mocks.setTokenNetworkFilter,
},
);
});
Expand Down
9 changes: 9 additions & 0 deletions app/scripts/metamask-controller.js
Original file line number Diff line number Diff line change
Expand Up @@ -6164,6 +6164,15 @@ export default class MetamaskController extends EventEmitter {
this.networkController.getNetworkConfigurationByChainId.bind(
this.networkController,
),
setTokenNetworkFilter: (chainId) => {
const { tokenNetworkFilter } =
this.preferencesController.getPreferences();
if (chainId && Object.keys(tokenNetworkFilter).length === 1) {
this.preferencesController.setPreference('tokenNetworkFilter', {
[chainId]: true,
});
}
},
getCurrentChainIdForDomain: (domain) => {
const networkClientId =
this.selectedNetworkController.getNetworkClientIdForDomain(domain);
Expand Down

0 comments on commit 6875a8f

Please sign in to comment.