Skip to content

Commit

Permalink
Include replace_jinja_tokens method
Browse files Browse the repository at this point in the history
  • Loading branch information
joaopamaral committed Nov 28, 2024
1 parent 01932bb commit 86f5ad6
Showing 1 changed file with 47 additions and 3 deletions.
50 changes: 47 additions & 3 deletions src/preset_cli/cli/superset/sync/dbt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
import logging
import re
from collections import defaultdict
from copy import deepcopy

from sqlglot.tokens import Token
from typing import Dict, List, Optional, Set

import sqlglot
from sqlglot import Expression, ParseError, exp, parse_one
from sqlglot import Expression, ParseError, exp, parse_one, TokenType
from sqlglot.expressions import (
Alias,
Case,
Expand Down Expand Up @@ -94,9 +97,14 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, OGMetricSchema])
return sql.strip()

try:
expression = sqlglot.parse_one(sql, dialect=metric["dialect"])
tokens = expression.find_all(exp.Column)
dialect = sqlglot.Dialect.get_or_raise(metric["dialect"])
tokens = replace_jinja_tokens(dialect.tokenize(sql))
result = dialect.parser().parse(tokens, sql)
expression = result[0] if result else None
if not expression:
raise ParseError(f"No expression was parsed from '{sql}'")

tokens = expression.find_all(exp.Column)
for token in tokens:
if token.sql() in metrics:
parent_sql = get_metric_expression(token.sql(), metrics)
Expand All @@ -115,6 +123,42 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, OGMetricSchema])
raise Exception(f"Unable to generate metric expression from: {sorted_metric}")


def replace_jinja_tokens(tokens: List[Token]) -> List[Token]:
"""
Replaces Jinja-style `{{` and `}}` as block start/end.
Args:
tokens (List[Token]): List of tokens to process.
Returns:
List[Token]: List of tokens with Jinja blocks replaced.
"""
merged_tokens = []
i = 0

while i < len(tokens):
if i < len(tokens) - 1:
if tokens[i].token_type == TokenType.L_BRACE and tokens[i + 1].token_type == TokenType.L_BRACE:
block_start = tokens[i]
block_start.text += tokens[i + 1].text
block_start.token_type = TokenType.BLOCK_START
merged_tokens.append(block_start)
i += 2
continue
if tokens[i].token_type == TokenType.R_BRACE and tokens[i + 1].token_type == TokenType.R_BRACE:
block_end = tokens[i]
block_end.text += tokens[i + 1].text
block_end.token_type = TokenType.BLOCK_END
merged_tokens.append(block_end)
i += 2
continue

merged_tokens.append(tokens[i])
i += 1

return merged_tokens


def apply_filters(sql: str, filters: List[FilterSchema]) -> str:
"""
Apply filters to SQL expression.
Expand Down

0 comments on commit 86f5ad6

Please sign in to comment.