modularize package for siemrules (#15)
* modularize package for siem rules

* Create create_release.yml

---------

Co-authored-by: David G <[email protected]>
fqrious and himynamesdave authored Jan 16, 2025
1 parent f928abd commit 015ec23
Showing 7 changed files with 123 additions and 21 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/create_release.yml
@@ -0,0 +1,30 @@
name: Create Release
run-name: creating release
on:
  workflow_dispatch:
  push:
    branches:
      - main

jobs:
  create-release:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install pypa/build
        run: python3 -m pip install build --user

      - name: Build a binary wheel and a source tarball
        run: python3 -m build

      - name: Make release
        env:
          GITHUB_TOKEN: ${{ github.token }}
        run: |
          REF_NAME="${{ github.ref_name }}-$(date +"%Y-%m-%d-%H-%M-%S")"
          gh release create "$REF_NAME" --repo '${{ github.repository }}' --notes ""
          gh release upload "$REF_NAME" dist/** --repo '${{ github.repository }}'
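
This workflow builds a wheel and source tarball with pypa/build on every push to main, then uploads both to a GitHub release whose tag is the ref name plus a timestamp, so repeated pushes never collide on a tag. A minimal Python sketch of the tag scheme (illustrative only, not part of the commit):

# Rough local equivalent of REF_NAME above; mirrors `date +"%Y-%m-%d-%H-%M-%S"`.
from datetime import datetime

def release_tag(ref_name: str = "main") -> str:
    return f"{ref_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"

print(release_tag())  # e.g. main-2025-01-16-09-30-00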
38 changes: 38 additions & 0 deletions pyproject.toml
@@ -0,0 +1,38 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "txt2detection"

version = "0.0.1"
authors = [
    { name="DOGESEC", email="[email protected]" },
]
description = "txt2detection is a tool"
readme = "README.md"
requires-python = ">=3.11"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
]
dependencies = [
"stix2",
"python-arango>=8.1.3; python_version >= '3.8'",
"tqdm>=4.66.4; python_version >= '3.7'",
"jsonschema>=4.22.0; python_version >= '3.8'",
"requests>=2.31.0; python_version >= '3.7'",
"python-dotenv>=1.0.1",
"pyyaml",
]
[project.urls]
Homepage = "https://github.com/muchdogesec/txt2detection"
Issues = "https://github.com/muchdogesec/txt2detection/issues"

[project.scripts]
txt2detection = "txt2detection.__main__:main"


[tool.hatch.build.targets.wheel.force-include]
"config" = "txt2detection/config"
29 changes: 19 additions & 10 deletions txt2detection/__main__.py
@@ -4,12 +4,16 @@
from dataclasses import dataclass
from functools import partial
from itertools import chain
import json
import os
from pathlib import Path
import logging
import sys
import uuid
from stix2 import Identity, parse as parse_stix

from txt2detection.ai_extractor.base import BaseAIExtractor
from txt2detection.utils import validate_token_count

def configureLogging():
    # Configure logging
@@ -37,18 +41,23 @@ def setLogFile(logger, file: Path):
from dotenv import load_dotenv
from .bundler import Bundler


from .utils import load_detection_languages, parse_model

def parse_identity(str):
    return Identity(**json.loads(str))

@dataclass
class Args:
    input_file: str
    name: str
    tlp_level: str
    labels: list[str]
    created: datetime
    use_identity: str
    use_identity: Identity
    detection_language: str
    ai_provider: BaseAIExtractor
    report_id: uuid.UUID

def parse_created(value):
"""Convert the created timestamp to a datetime object."""
@@ -62,14 +71,15 @@ def parse_args():
    parser = argparse.ArgumentParser(description='Convert text file to detection format.')

    parser.add_argument('--input_file', required=True, help='The file to be converted. Must be .txt')
    parser.add_argument('--report_id', type=uuid.UUID, help='report_id to use for generated report')
    parser.add_argument('--name', required=True, help='Name of file, max 72 chars. Will be used in the STIX Report Object created.')
    parser.add_argument('--tlp_level', choices=['clear', 'green', 'amber', 'amber_strict', 'red'], default='clear',
                        help='Options are clear, green, amber, amber_strict, red. Default is clear if not passed.')
    parser.add_argument('--labels', type=lambda s: s.split(','),
                        help='Comma-separated list of labels. Case-insensitive (will be converted to lower-case). Allowed a-z, 0-9.')
    parser.add_argument('--created', type=parse_created,
                        help='Explicitly set created time in format YYYY-MM-DDTHH:MM:SS.sssZ. Default is current time.')
    parser.add_argument('--use_identity',
    parser.add_argument('--use_identity', type=parse_identity,
                        help='Pass a full STIX 2.1 identity object (properly escaped). Validated by the STIX2 library. Default is SIEM Rules identity.')
    parser.add_argument('--detection_language', required=True,
                        help='Detection rule language for the output. Check config/detection_languages.yaml for available keys.',
@@ -86,26 +96,25 @@ def parse_args():
    if not os.path.isfile(args.input_file):
        parser.error(f"The specified input file does not exist: {args.input_file}")

    if not args.report_id:
        args.report_id = Bundler.generate_report_id(args.use_identity.id if args.use_identity else None, args.created, args.name)

    return args

def validate_token_count(max_tokens, input, extractor: BaseAIExtractor):
    logging.info('INPUT_TOKEN_LIMIT = %d', max_tokens)
    token_count = extractor.count_tokens(input)
    logging.info('TOKEN COUNT FOR %s: %d', extractor.extractor_name, token_count)
    if token_count > max_tokens:
        raise Exception(f"{extractor.extractor_name}: input_file token count ({token_count}) exceeds INPUT_TOKEN_LIMIT ({max_tokens})")


def main(args: Args):
    setLogFile(logging.root, Path(f"logs/log-{int(args.created.timestamp())}.log"))
    setLogFile(logging.root, Path(f"logs/log-{args.report_id}.log"))
    input_str = Path(args.input_file).read_text()
    validate_token_count(int(os.getenv('INPUT_TOKEN_LIMIT', 0)), input_str, args.ai_provider)
    detections = args.ai_provider.get_detections(input_str, detection_language=args.detection_language)
    bundler = Bundler(args.name, args.detection_language.slug, args.use_identity, args.tlp_level, input_str, 0, args.labels)
    bundler = Bundler(args.name, args.detection_language.slug, args.use_identity, args.tlp_level, input_str, 0, args.labels, report_id=args.report_id)
    bundler.bundle_detections(detections)
    out = bundler.to_json()

    output_path = Path("./output")/f"{bundler.bundle.id}.json"
    data_path = output_path.with_name(f"data--{args.report_id}.json")
    output_path.parent.mkdir(exist_ok=True)
    output_path.write_text(out)
    data_path.write_text(detections.model_dump_json(indent=4))
    logging.info(f"Writing bundle output to `{output_path}`")
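
--use_identity is now parsed into a stix2 Identity at argument-parsing time rather than passed around as a raw string, so malformed JSON fails fast. A sketch of an acceptable value (the id and name below are made up for illustration):

# Illustrative --use_identity payload; parse_identity() performs the same parse.
import json
from stix2 import Identity

raw = '{"type": "identity", "spec_version": "2.1", "id": "identity--8ef05850-cb0d-51f7-80be-50e4376dbe63", "name": "Example Org"}'
identity = Identity(**json.loads(raw))
print(identity.id)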
9 changes: 6 additions & 3 deletions txt2detection/ai_extractor/__init__.py
@@ -5,6 +5,9 @@
from .base import _ai_extractor_registry as ALL_AI_EXTRACTORS

from .base import BaseAIExtractor
from .openai import OpenAIExtractor
from .anthropic import AnthropicAIExtractor
from .gemini import GeminiAIExtractor

for path in ["openai", "anthropic", "gemini"]:
    try:
        __import__(__package__ + "." + path)
    except Exception as e:
        logging.warning("%s not installed", path, exc_info=True)
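
Provider modules are now imported dynamically, so a missing optional dependency (say, anthropic) only logs a warning instead of breaking `import txt2detection`. Each module that does import registers its extractor, which parse_model() then looks up. A quick sketch (the registry key names are assumed):

# Inspect which AI providers actually registered on import.
from txt2detection.ai_extractor import ALL_AI_EXTRACTORS

print(sorted(ALL_AI_EXTRACTORS))  # e.g. ['anthropic', 'gemini', 'openai']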
2 changes: 1 addition & 1 deletion txt2detection/ai_extractor/base.py
@@ -89,7 +89,7 @@ class BaseAIExtractor():
))


    def get_detections(self, input_text, detection_language):
    def get_detections(self, input_text, detection_language) -> DetectionContainer:
        logging.info('getting detections')
        return self.llm.structured_predict(
            prompt=self.detection_template,
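
The new return annotation documents that structured_predict yields a DetectionContainer, which __main__.py relies on when it calls detections.model_dump_json(). A hedged sketch of the shape the annotation implies, with field names invented for illustration:

# Assumed pydantic shape; the real fields live in txt2detection's models.
from pydantic import BaseModel

class Detection(BaseModel):
    name: str
    rule: str

class DetectionContainer(BaseModel):
    detections: list[Detection]

print(DetectionContainer(detections=[]).model_dump_json(indent=4))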
14 changes: 11 additions & 3 deletions txt2detection/bundler.py
@@ -159,6 +159,14 @@ class Bundler:
]
})

    @classmethod
    def generate_report_id(cls, created_by_ref, created, name):
        if not created_by_ref:
            created_by_ref = cls.default_identity['id']
        return str(
            uuid.uuid5(UUID_NAMESPACE, f"{created_by_ref}+{created}+{name}")
        )

    def __init__(
        self,
        name,
@@ -169,13 +177,13 @@ def __init__(
        confidence,
        labels,
        created=dt.now(),
        report_id=None,
    ) -> None:
        self.created = created
        self.identity = identity or self.default_identity
        self.tlp_level = TLP_LEVEL.get(tlp_level)
        self.uuid = str(
            uuid.uuid5(UUID_NAMESPACE, f"{self.identity.id}+{self.created}+{name}")
        )
        self.uuid = report_id or self.generate_report_id(self.identity.id, self.created, name)

        self.detection_language = detection_language
        self.job_id = f"report--{self.uuid}"
        self.report = Report(
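
generate_report_id makes report IDs deterministic: the same creator, created timestamp, and name always hash to the same UUIDv5, so parse_args can compute the ID before a Bundler exists and re-runs stay idempotent. A standalone sketch; the namespace below is a placeholder, not the UUID_NAMESPACE defined in bundler.py:

# Deterministic report id, keyed on creator + created + name.
import uuid

UUID_NAMESPACE = uuid.UUID("00000000-0000-0000-0000-000000000000")  # placeholder

def generate_report_id(created_by_ref: str, created, name: str) -> str:
    return str(uuid.uuid5(UUID_NAMESPACE, f"{created_by_ref}+{created}+{name}"))

print(generate_report_id("identity--8ef05850-cb0d-51f7-80be-50e4376dbe63", "2025-01-16", "demo report"))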
22 changes: 18 additions & 4 deletions txt2detection/utils.py
@@ -1,14 +1,19 @@
from pathlib import Path
from types import SimpleNamespace
import yaml
from .ai_extractor import ALL_AI_EXTRACTORS
from .ai_extractor import ALL_AI_EXTRACTORS, BaseAIExtractor
from importlib import resources
import txt2detection
import logging

class DetectionLanguage(SimpleNamespace):
    pass

def load_detection_languages():
def load_detection_languages(path = Path("config/detection_languages.yaml")):
    if not path.exists():
        path = resources.files(txt2detection) / "config/detection_languages.yaml"
    langs = {}
    for k, v in yaml.safe_load(Path("config/detection_languages.yaml").open()).items():
    for k, v in yaml.safe_load(path.open()).items():
        v["slug"] = k
        langs[k] = DetectionLanguage(**v)
    return langs
@@ -21,4 +26,13 @@ def parse_model(value: str):
    provider = ALL_AI_EXTRACTORS[provider]
    if len(splits) == 2:
        return provider(model=splits[1])
    return provider()
    return provider()



def validate_token_count(max_tokens, input, extractor: BaseAIExtractor):
    logging.info('INPUT_TOKEN_LIMIT = %d', max_tokens)
    token_count = extractor.count_tokens(input)
    logging.info('TOKEN COUNT FOR %s: %d', extractor.extractor_name, token_count)
    if token_count > max_tokens:
        raise Exception(f"{extractor.extractor_name}: input_file token count ({token_count}) exceeds INPUT_TOKEN_LIMIT ({max_tokens})")
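
With this move, load_detection_languages falls back to the YAML packaged inside the wheel when no local config directory exists, and validate_token_count is importable by any entry point instead of living in __main__.py. Usage sketch (the available language keys depend on the shipped YAML):

# List available detection languages; works from a checkout or an install.
from txt2detection.utils import load_detection_languages

langs = load_detection_languages()
print(sorted(langs))  # e.g. ['sigma', ...]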
