Skip to content

Commit

Permalink
Hybrid of a11y tree & DOM for input to observe (#459)
Browse files Browse the repository at this point in the history
* include backendDOMNodeId

* skip AX nodes whose nodeId is negative

* replace role with DOM tag name if the role is "none" or "generic"

* add xpath to AXNode type

* revert unnecessary changed lines

* revert more unnecessary changed lines

* changeset

* speedup

* prettier

* prune before updating roles

* take xpath out of AXnode type

* rm commented code

---------

Co-authored-by: Miguel <[email protected]>
  • Loading branch information
seanmcguire12 and miguelg719 authored Feb 6, 2025
1 parent 00da6dd commit 62a29ee
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 42 deletions.
5 changes: 5 additions & 0 deletions .changeset/chilled-apes-sneeze.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": patch
---

Create a hybrid accessibility-tree (a11y) + DOM input for observe
182 changes: 143 additions & 39 deletions lib/a11y/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ export function formatSimplifiedTree(
level = 0,
): string {
const indent = " ".repeat(level);
let result = `${indent}[${node.nodeId}] ${node.role}${node.name ? `: ${node.name}` : ""}\n`;
let result = `${indent}[${node.nodeId}] ${node.role}${
node.name ? `: ${node.name}` : ""
}\n`;

if (node.children?.length) {
result += node.children
Expand All @@ -29,39 +31,113 @@ export function formatSimplifiedTree(
* 1. Removes generic/none nodes with no children
* 2. Collapses generic/none nodes with single child
* 3. Keeps generic/none nodes with multiple children but cleans their subtrees
* and attempts to resolve their role to a DOM tag name
*/
function cleanStructuralNodes(
async function cleanStructuralNodes(
node: AccessibilityNode,
): AccessibilityNode | null {
// Filter out nodes with negative IDs
page?: StagehandPage,
logger?: (logLine: LogLine) => void,
): Promise<AccessibilityNode | null> {
// 1) Filter out nodes with negative IDs
if (node.nodeId && parseInt(node.nodeId) < 0) {
return null;
}

// Base case: leaf node
if (!node.children) {
// 2) Base case: if no children exist, this is effectively a leaf.
// If it's "generic" or "none", we remove it; otherwise, keep it.
if (!node.children || node.children.length === 0) {
return node.role === "generic" || node.role === "none" ? null : node;
}

// Recursively clean children
const cleanedChildren = node.children
.map((child) => cleanStructuralNodes(child))
.filter(Boolean) as AccessibilityNode[];

// Handle generic/none nodes specially
// 3) Recursively clean children
const cleanedChildrenPromises = node.children.map((child) =>
cleanStructuralNodes(child, page, logger),
);
const resolvedChildren = await Promise.all(cleanedChildrenPromises);
const cleanedChildren = resolvedChildren.filter(
(child): child is AccessibilityNode => child !== null,
);

// 4) **Prune** "generic" or "none" nodes first,
// before resolving them to their tag names.
if (node.role === "generic" || node.role === "none") {
if (cleanedChildren.length === 1) {
// Collapse single-child generic nodes
// Collapse single-child structural node
return cleanedChildren[0];
} else if (cleanedChildren.length > 1) {
// Keep generic nodes with multiple children
return { ...node, children: cleanedChildren };
} else if (cleanedChildren.length === 0) {
// Remove empty structural node
return null;
}
// If we have multiple children, we keep this node as a container.
// We'll update role below if needed.
}

// 5) If we still have a "generic"/"none" node after pruning
// (i.e., because it had multiple children), now we try
// to resolve and replace its role with the DOM tag name.
if (
page &&
logger &&
node.backendDOMNodeId !== undefined &&
(node.role === "generic" || node.role === "none")
) {
try {
const { object } = await page.sendCDP<{
object: { objectId?: string };
}>("DOM.resolveNode", {
backendNodeId: node.backendDOMNodeId,
});

if (object && object.objectId) {
try {
// Get the tagName for the node
const { result } = await page.sendCDP<{
result: { type: string; value?: string };
}>("Runtime.callFunctionOn", {
objectId: object.objectId,
functionDeclaration: `
function() {
return this.tagName ? this.tagName.toLowerCase() : "";
}
`,
returnByValue: true,
});

// If we got a tagName, update the node's role
if (result?.value) {
node.role = result.value;
}
} catch (tagNameError) {
logger({
category: "observation",
message: `Could not fetch tagName for node ${node.backendDOMNodeId}`,
level: 2,
auxiliary: {
error: {
value: tagNameError.message,
type: "string",
},
},
});
}
}
} catch (resolveError) {
logger({
category: "observation",
message: `Could not resolve DOM node ID ${node.backendDOMNodeId}`,
level: 2,
auxiliary: {
error: {
value: resolveError.message,
type: "string",
},
},
});
}
// Remove generic nodes with no children
return null;
}

// For non-generic nodes, keep them if they have children after cleaning
// 6) Return the updated node.
// If it has children, update them; otherwise keep it as-is.
return cleanedChildren.length > 0
? { ...node, children: cleanedChildren }
: node;
Expand All @@ -73,13 +149,23 @@ function cleanStructuralNodes(
* @param nodes - Flat array of accessibility nodes from the CDP
* @returns Object containing both the tree structure and a simplified string representation
*/
export function buildHierarchicalTree(nodes: AccessibilityNode[]): TreeResult {
export async function buildHierarchicalTree(
nodes: AccessibilityNode[],
page?: StagehandPage,
logger?: (logLine: LogLine) => void,
): Promise<TreeResult> {
// Map to store processed nodes for quick lookup
const nodeMap = new Map<string, AccessibilityNode>();

// First pass: Create nodes that are meaningful
// We only keep nodes that either have a name or children to avoid cluttering the tree
nodes.forEach((node) => {
// Skip node if its ID is negative (e.g., "-1000002014")
const nodeIdValue = parseInt(node.nodeId, 10);
if (nodeIdValue < 0) {
return;
}

const hasChildren = node.childIds && node.childIds.length > 0;
const hasValidName = node.name && node.name.trim() !== "";
const isInteractive =
Expand All @@ -99,6 +185,9 @@ export function buildHierarchicalTree(nodes: AccessibilityNode[]): TreeResult {
...(hasValidName && { name: node.name }), // Only include name if it exists and isn't empty
...(node.description && { description: node.description }),
...(node.value && { value: node.value }),
...(node.backendDOMNodeId !== undefined && {
backendDOMNodeId: node.backendDOMNodeId,
}),
});
});

Expand All @@ -119,13 +208,18 @@ export function buildHierarchicalTree(nodes: AccessibilityNode[]): TreeResult {
});

// Final pass: Build the root-level tree and clean up structural nodes
const finalTree = nodes
const rootNodes = nodes
.filter((node) => !node.parentId && nodeMap.has(node.nodeId)) // Get root nodes
.map((node) => nodeMap.get(node.nodeId))
.filter(Boolean)
.map((node) => cleanStructuralNodes(node))
.filter(Boolean) as AccessibilityNode[];

const cleanedTreePromises = rootNodes.map((node) =>
cleanStructuralNodes(node, page, logger),
);
const finalTree = (await Promise.all(cleanedTreePromises)).filter(
Boolean,
) as AccessibilityNode[];

// Generate a simplified string representation of the tree
const simplifiedFormat = finalTree
.map((node) => formatSimplifiedTree(node))
Expand All @@ -137,29 +231,43 @@ export function buildHierarchicalTree(nodes: AccessibilityNode[]): TreeResult {
};
}

/**
* Retrieves the full accessibility tree via CDP and transforms it into a hierarchical structure.
*/
export async function getAccessibilityTree(
page: StagehandPage,
logger: (logLine: LogLine) => void,
) {
): Promise<TreeResult> {
await page.enableCDP("Accessibility");

try {
// Fetch the full accessibility tree from Chrome DevTools Protocol
const { nodes } = await page.sendCDP<{ nodes: AXNode[] }>(
"Accessibility.getFullAXTree",
);
const startTime = Date.now();

// Extract specific sources
const sources = nodes.map((node) => ({
role: node.role?.value,
name: node.name?.value,
description: node.description?.value,
value: node.value?.value,
nodeId: node.nodeId,
parentId: node.parentId,
childIds: node.childIds,
}));
// Transform into hierarchical structure
const hierarchicalTree = buildHierarchicalTree(sources);
const hierarchicalTree = await buildHierarchicalTree(
nodes.map((node) => ({
role: node.role?.value,
name: node.name?.value,
description: node.description?.value,
value: node.value?.value,
nodeId: node.nodeId,
backendDOMNodeId: node.backendDOMNodeId,
parentId: node.parentId,
childIds: node.childIds,
})),
page,
logger,
);

logger({
category: "observation",
message: `got accessibility tree in ${Date.now() - startTime}ms`,
level: 1,
});

return hierarchicalTree;
} catch (error) {
Expand Down Expand Up @@ -258,7 +366,6 @@ export async function performPlaywrightMethod(
method: string,
args: unknown[],
xpath: string,
// domSettleTimeoutMs?: number,
) {
const locator = stagehandPage.locator(`xpath=${xpath}`).first();
const initialUrl = stagehandPage.url();
Expand Down Expand Up @@ -503,7 +610,6 @@ export async function performPlaywrightMethod(
await newOpenedTab.close();
await stagehandPage.goto(newOpenedTab.url());
await stagehandPage.waitForLoadState("domcontentloaded");
// await stagehandPage._waitForSettledDom(domSettleTimeoutMs);
}

await Promise.race([
Expand Down Expand Up @@ -564,6 +670,4 @@ export async function performPlaywrightMethod(
`Method ${method} not supported`,
);
}

// await stagehandPage._waitForSettledDom(domSettleTimeoutMs);
}
3 changes: 1 addition & 2 deletions lib/handlers/observeHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ export class StagehandObserveHandler {
isUsingAccessibilityTree: useAccessibilityTree,
returnAction,
});

const elementsWithSelectors = await Promise.all(
observationResponse.elements.map(async (element) => {
const { elementId, ...rest } = element;
Expand Down Expand Up @@ -137,7 +138,6 @@ export class StagehandObserveHandler {
message: `Invalid object ID returned for element: ${elementId}`,
level: 1,
});
return null;
}

const xpath = await getXPathByResolvedObjectId(
Expand All @@ -151,7 +151,6 @@ export class StagehandObserveHandler {
message: `Empty xpath returned for element: ${elementId}`,
level: 1,
});
return null;
}

return {
Expand Down
2 changes: 1 addition & 1 deletion lib/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ You will be given:
1. a instruction of elements to observe
2. ${
isUsingAccessibilityTree
? "a hierarchical accessibility tree showing the semantic structure of the page"
? "a hierarchical accessibility tree showing the semantic structure of the page. The tree is a hybrid of the DOM and the accessibility tree."
: "a numbered list of possible elements"
}
Expand Down
2 changes: 2 additions & 0 deletions types/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export interface AXNode {
description?: { value: string };
value?: { value: string };
nodeId: string;
backendDOMNodeId?: number;
parentId?: string;
childIds?: string[];
}
Expand All @@ -17,6 +18,7 @@ export type AccessibilityNode = {
childIds?: string[];
parentId?: string;
nodeId?: string;
backendDOMNodeId?: number;
};

export interface TreeResult {
Expand Down

0 comments on commit 62a29ee

Please sign in to comment.