Feat/firecrawl data source (#5232)

Co-authored-by: Nicolas <nicolascamara29@gmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: takatost <takatost@gmail.com>
Jyong 2024-06-15 02:46:02 +08:00 committed by GitHub
parent 918ebe1620
commit ba5f8afaa8
36 changed files with 1174 additions and 64 deletions

View File

@ -215,4 +215,5 @@ WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_EXECUTION_TIME=1200

View File

@ -29,13 +29,13 @@ from .app import (
)
# Import auth controllers
from .auth import activate, data_source_oauth, login, oauth
from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth
# Import billing controllers
from .billing import billing
# Import datasets controllers
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website
# Import explore controllers
from .explore import (

View File

@ -0,0 +1,67 @@
from flask_login import current_user
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from ..setup import setup_required
from ..wraps import account_initialization_required
class ApiKeyAuthDataSource(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
# The current user's role must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
if data_source_api_key_bindings:
return {
'settings': [data_source_api_key_binding.to_dict() for data_source_api_key_binding in
data_source_api_key_bindings]}
return {'settings': []}
class ApiKeyAuthDataSourceBinding(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The current user's role must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {'result': 'success'}, 200
class ApiKeyAuthDataSourceBindingDelete(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, binding_id):
# The current user's role must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
return {'result': 'success'}, 200
api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
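
A hedged sketch of exercising these new console endpoints with the requests library. The /console/api prefix, base URL, and Authorization header are assumptions about the deployment; the credentials shape follows FirecrawlAuth and ApiKeyAuthService later in this commit.

import requests

# Assumed console API prefix and a placeholder auth header; substitute whatever
# your deployment uses to authenticate console requests.
BASE = 'http://localhost:5001/console/api'
HEADERS = {'Authorization': 'Bearer <console-access-token>'}

# Bind a Firecrawl API key for the current workspace (admin/owner only).
resp = requests.post(f'{BASE}/api-key-auth/data-source/binding', headers=HEADERS, json={
    'category': 'website',
    'provider': 'firecrawl',
    'credentials': {
        'auth_type': 'bearer',
        'config': {'api_key': 'fc-your-key', 'base_url': 'https://api.firecrawl.dev'},
    },
})
print(resp.json())  # {'result': 'success'} when validation passes

# List existing bindings for the workspace.
print(requests.get(f'{BASE}/api-key-auth/data-source', headers=HEADERS).json())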

View File

@ -0,0 +1,7 @@
from libs.exception import BaseHTTPException
class ApiKeyAuthFailedError(BaseHTTPException):
error_code = 'auth_failed'
description = "{message}"
code = 500

View File

@ -16,7 +16,7 @@ from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.login import login_required
from models.dataset import Document
from models.source import DataSourceBinding
from models.source import DataSourceOauthBinding
from services.dataset_service import DatasetService, DocumentService
from tasks.document_indexing_sync_task import document_indexing_sync_task
@ -29,9 +29,9 @@ class DataSourceApi(Resource):
@marshal_with(integrate_list_fields)
def get(self):
# get workspace data source integrates
data_source_integrates = db.session.query(DataSourceBinding).filter(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.disabled == False
data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False
).all()
base_url = request.url_root.rstrip('/')
@ -71,7 +71,7 @@ class DataSourceApi(Resource):
def patch(self, binding_id, action):
binding_id = str(binding_id)
action = str(action)
data_source_binding = DataSourceBinding.query.filter_by(
data_source_binding = DataSourceOauthBinding.query.filter_by(
id=binding_id
).first()
if data_source_binding is None:
@ -124,7 +124,7 @@ class DataSourceNotionListApi(Resource):
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info['notion_page_id'])
# get all authorized pages
data_source_bindings = DataSourceBinding.query.filter_by(
data_source_bindings = DataSourceOauthBinding.query.filter_by(
tenant_id=current_user.current_tenant_id,
provider='notion',
disabled=False
@ -163,12 +163,12 @@ class DataSourceNotionApi(Resource):
def get(self, workspace_id, page_id, page_type):
workspace_id = str(workspace_id)
page_id = str(page_id)
data_source_binding = DataSourceBinding.query.filter(
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:

View File

@ -315,6 +315,22 @@ class DatasetIndexingEstimateApi(Resource):
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
elif args['info_list']['data_source_type'] == 'website_crawl':
website_info_list = args['info_list']['website_info_list']
for url in website_info_list['urls']:
extract_setting = ExtractSetting(
datasource_type="website_crawl",
website_info={
"provider": website_info_list['provider'],
"job_id": website_info_list['job_id'],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": 'crawl',
"only_main_content": website_info_list['only_main_content']
},
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
else:
raise ValueError('Data source type not supported')
indexing_runner = IndexingRunner()
@ -519,6 +535,7 @@ class DatasetRetrievalSettingMockApi(Resource):
raise ValueError(f"Unsupported vector db type {vector_type}.")
class DatasetErrorDocs(Resource):
@setup_required
@login_required

View File

@ -465,6 +465,20 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == 'website_crawl':
extract_setting = ExtractSetting(
datasource_type="website_crawl",
website_info={
"provider": data_source_info['provider'],
"job_id": data_source_info['job_id'],
"url": data_source_info['url'],
"tenant_id": current_user.current_tenant_id,
"mode": data_source_info['mode'],
"only_main_content": data_source_info['only_main_content']
},
document_model=document.doc_form
)
extract_settings.append(extract_setting)
else:
raise ValueError('Data source type not supported')
@ -952,6 +966,33 @@ class DocumentRenameApi(DocumentResource):
return document
class WebsiteDocumentSyncApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""sync website document."""
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound('Document not found.')
if document.tenant_id != current_user.current_tenant_id:
raise Forbidden('No permission.')
if document.data_source_type != 'website_crawl':
raise ValueError('Document is not a website document.')
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
# sync document
DocumentService.sync_website_document(dataset_id, document)
return {'result': 'success'}, 200
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
@ -980,3 +1021,5 @@ api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uui
api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
api.add_resource(DocumentRenameApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename')
api.add_resource(WebsiteDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync')
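
A hedged one-liner for triggering the new website-sync route, reusing the assumed console prefix and placeholder auth header from the earlier sketch; the dataset and document ids are placeholders.

import requests

BASE = 'http://localhost:5001/console/api'                    # assumed console prefix
HEADERS = {'Authorization': 'Bearer <console-access-token>'}  # placeholder auth
resp = requests.get(f'{BASE}/datasets/<dataset-id>/documents/<document-id>/website-sync',
                    headers=HEADERS)
print(resp.json())  # {'result': 'success'} once the sync task is queued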

View File

@ -73,6 +73,12 @@ class InvalidMetadataError(BaseHTTPException):
code = 400
class WebsiteCrawlError(BaseHTTPException):
error_code = 'crawl_failed'
description = "{message}"
code = 500
class DatasetInUseError(BaseHTTPException):
error_code = 'dataset_in_use'
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."

View File

@ -0,0 +1,49 @@
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.website_service import WebsiteService
class WebsiteCrawlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, choices=['firecrawl'],
required=True, nullable=True, location='json')
parser.add_argument('url', type=str, required=True, nullable=True, location='json')
parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
args = parser.parse_args()
WebsiteService.document_create_args_validate(args)
# crawl url
try:
result = WebsiteService.crawl_url(args)
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200
class WebsiteCrawlStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
args = parser.parse_args()
# get crawl status
try:
result = WebsiteService.get_crawl_status(job_id, args['provider'])
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200
api.add_resource(WebsiteCrawlApi, '/website/crawl')
api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
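
A hedged sketch of starting a crawl and polling its status through the endpoints above. The base URL and auth header are placeholders, and the options keys mirror what WebsiteService.crawl_url reads later in this commit.

import time
import requests

BASE = 'http://localhost:5001/console/api'                    # assumed console prefix
HEADERS = {'Authorization': 'Bearer <console-access-token>'}  # placeholder auth

# Kick off a Firecrawl crawl job.
job = requests.post(f'{BASE}/website/crawl', headers=HEADERS, json={
    'provider': 'firecrawl',
    'url': 'https://example.com',
    'options': {
        'crawl_sub_pages': True,
        'only_main_content': True,
        'includes': '',
        'excludes': '',
        'limit': 10,
        'max_depth': 2,
    },
}).json()  # {'status': 'active', 'job_id': ...}

# Poll until Firecrawl reports the job as completed.
while True:
    status = requests.get(f"{BASE}/website/crawl/status/{job['job_id']}",
                          headers=HEADERS, params={'provider': 'firecrawl'}).json()
    if status['status'] == 'completed':
        break
    time.sleep(5)
print(status['total'], 'pages crawled')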

View File

@ -339,7 +339,7 @@ class IndexingRunner:
def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
-> list[Document]:
# load file
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
return []
data_source_info = dataset_document.data_source_info_dict
@ -375,6 +375,23 @@ class IndexingRunner:
document_model=dataset_document.doc_form
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
elif dataset_document.data_source_type == 'website_crawl':
if (not data_source_info or 'provider' not in data_source_info
or 'url' not in data_source_info or 'job_id' not in data_source_info):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type="website_crawl",
website_info={
"provider": data_source_info['provider'],
"job_id": data_source_info['job_id'],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info['url'],
"mode": data_source_info['mode'],
"only_main_content": data_source_info['only_main_content']
},
document_model=dataset_document.doc_form
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,

View File

@ -124,7 +124,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
default=float(credentials.get('presence_penalty', 0)),
min=-2,
max=2
)
),
],
pricing=PriceConfig(
input=Decimal(cred_with_endpoint.get('input_price', 0)),

View File

@ -4,3 +4,4 @@ from enum import Enum
class DatasourceType(Enum):
FILE = "upload_file"
NOTION = "notion_import"
WEBSITE = "website_crawl"

View File

@ -1,3 +1,5 @@
from typing import Optional
from pydantic import BaseModel, ConfigDict
from models.dataset import Document
@ -19,14 +21,33 @@ class NotionInfo(BaseModel):
super().__init__(**data)
class WebsiteInfo(BaseModel):
"""
website import info.
"""
provider: str
job_id: str
url: str
mode: str
tenant_id: str
only_main_content: bool = False
class Config:
arbitrary_types_allowed = True
def __init__(self, **data) -> None:
super().__init__(**data)
class ExtractSetting(BaseModel):
"""
Model class for provider response.
"""
datasource_type: str
upload_file: UploadFile = None
notion_info: NotionInfo = None
document_model: str = None
upload_file: Optional[UploadFile] = None
notion_info: Optional[NotionInfo] = None
website_info: Optional[WebsiteInfo] = None
document_model: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data) -> None:

View File

@ -11,6 +11,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.excel_extractor import ExcelExtractor
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.notion_extractor import NotionExtractor
@ -154,5 +155,17 @@ class ExtractProcessor:
tenant_id=extract_setting.notion_info.tenant_id,
)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
if extract_setting.website_info.provider == 'firecrawl':
extractor = FirecrawlWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,
tenant_id=extract_setting.website_info.tenant_id,
mode=extract_setting.website_info.mode,
only_main_content=extract_setting.website_info.only_main_content
)
return extractor.extract()
else:
raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}")
else:
raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")

View File

@ -0,0 +1,132 @@
import json
import time
import requests
from extensions.ext_storage import storage
class FirecrawlApp:
def __init__(self, api_key=None, base_url=None):
self.api_key = api_key
self.base_url = base_url or 'https://api.firecrawl.dev'
if self.api_key is None and self.base_url == 'https://api.firecrawl.dev':
raise ValueError('No API key provided')
def scrape_url(self, url, params=None) -> dict:
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
json_data = {'url': url}
if params:
json_data.update(params)
response = requests.post(
f'{self.base_url}/v0/scrape',
headers=headers,
json=json_data
)
if response.status_code == 200:
response = response.json()
if response['success'] == True:
data = response['data']
return {
'title': data.get('metadata').get('title'),
'description': data.get('metadata').get('description'),
'source_url': data.get('metadata').get('sourceURL'),
'markdown': data.get('markdown')
}
else:
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
elif response.status_code in [402, 409, 500]:
error_message = response.json().get('error', 'Unknown error occurred')
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}')
else:
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')
def crawl_url(self, url, params=None) -> str:
start_time = time.time()
headers = self._prepare_headers()
json_data = {'url': url}
if params:
json_data.update(params)
response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers)
if response.status_code == 200:
job_id = response.json().get('jobId')
return job_id
else:
self._handle_error(response, 'start crawl job')
def check_crawl_status(self, job_id) -> dict:
headers = self._prepare_headers()
response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers)
if response.status_code == 200:
crawl_status_response = response.json()
if crawl_status_response.get('status') == 'completed':
total = crawl_status_response.get('total', 0)
if total == 0:
raise Exception('Failed to check crawl status. Error: No page found')
data = crawl_status_response.get('data', [])
url_data_list = []
for item in data:
if isinstance(item, dict) and 'metadata' in item and 'markdown' in item:
url_data = {
'title': item.get('metadata').get('title'),
'description': item.get('metadata').get('description'),
'source_url': item.get('metadata').get('sourceURL'),
'markdown': item.get('markdown')
}
url_data_list.append(url_data)
if url_data_list:
file_key = 'website_files/' + job_id + '.txt'
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(url_data_list).encode('utf-8'))
return {
'status': 'completed',
'total': crawl_status_response.get('total'),
'current': crawl_status_response.get('current'),
'data': url_data_list
}
else:
return {
'status': crawl_status_response.get('status'),
'total': crawl_status_response.get('total'),
'current': crawl_status_response.get('current'),
'data': []
}
else:
self._handle_error(response, 'check crawl status')
def _prepare_headers(self):
return {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
for attempt in range(retries):
response = requests.post(url, headers=headers, json=data)
if response.status_code == 502:
time.sleep(backoff_factor * (2 ** attempt))
else:
return response
return response
def _get_request(self, url, headers, retries=3, backoff_factor=0.5):
for attempt in range(retries):
response = requests.get(url, headers=headers)
if response.status_code == 502:
time.sleep(backoff_factor * (2 ** attempt))
else:
return response
return response
def _handle_error(self, response, action):
error_message = response.json().get('error', 'Unknown error occurred')
raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}')
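
A hedged, standalone sketch of driving this client directly, assuming a valid key. crawl_url returns the Firecrawl jobId, and check_crawl_status caches completed results to storage under website_files/<job_id>.txt.

import time

# Hypothetical API key; the storage extension must be configured, because completed
# results are written to 'website_files/<job_id>.txt'.
app = FirecrawlApp(api_key='fc-your-key')
job_id = app.crawl_url('https://example.com', params={
    'crawlerOptions': {
        'limit': 2,
        'pageOptions': {'onlyMainContent': True, 'includeHtml': False},
    },
})
while True:
    status = app.check_crawl_status(job_id)
    if status['status'] == 'completed':
        break
    time.sleep(5)
for page in status['data']:
    print(page['source_url'], len(page['markdown'] or ''))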

View File

@ -0,0 +1,60 @@
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from services.website_service import WebsiteService
class FirecrawlWebExtractor(BaseExtractor):
"""
Crawl and scrape websites and return content in clean llm-ready markdown.
Args:
url: The URL to crawl or scrape.
job_id: The Firecrawl crawl job id used to look up crawl results.
tenant_id: The tenant that owns the Firecrawl data source credentials.
mode: The mode of operation. Defaults to 'crawl'. Options are 'crawl' and 'scrape'.
only_main_content: Whether to return only the main content of each page.
"""
def __init__(
self,
url: str,
job_id: str,
tenant_id: str,
mode: str = 'crawl',
only_main_content: bool = False
):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id
self.tenant_id = tenant_id
self.mode = mode
self.only_main_content = only_main_content
def extract(self) -> list[Document]:
"""Extract content from the URL."""
documents = []
if self.mode == 'crawl':
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id)
if crawl_data is None:
return []
document = Document(page_content=crawl_data.get('markdown', ''),
metadata={
'source_url': crawl_data.get('source_url'),
'description': crawl_data.get('description'),
'title': crawl_data.get('title')
}
)
documents.append(document)
elif self.mode == 'scrape':
scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id,
self.only_main_content)
document = Document(page_content=scrape_data.get('markdown', ''),
metadata={
'source_url': scrape_data.get('source_url'),
'description': scrape_data.get('description'),
'title': scrape_data.get('title')
}
)
documents.append(document)
return documents
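
A hedged sketch of using the extractor on its own, with placeholder ids. It assumes a Firecrawl API key is already bound for the tenant (via the api-key-auth endpoints) and, in 'crawl' mode, that the crawl job has completed, since the data is resolved through WebsiteService.

# Both modes resolve data through WebsiteService and therefore require a Firecrawl
# binding for this tenant and an application context.
extractor = FirecrawlWebExtractor(
    url='https://example.com',
    job_id='<job-id>',
    tenant_id='<tenant-id>',
    mode='crawl',
    only_main_content=True,
)
for doc in extractor.extract():
    print(doc.metadata['source_url'], doc.page_content[:80])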

View File

@ -9,7 +9,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Document as DocumentModel
from models.source import DataSourceBinding
from models.source import DataSourceOauthBinding
logger = logging.getLogger(__name__)
@ -345,12 +345,12 @@ class NotionExtractor(BaseExtractor):
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = DataSourceBinding.query.filter(
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
)
).first()

View File

@ -0,0 +1,64 @@
# [REVIEW] Implement if Needed? Do we need a new type of data source
from abc import abstractmethod
import requests
from api.models.source import DataSourceBearerBinding
from flask_login import current_user
from extensions.ext_database import db
class BearerDataSource:
def __init__(self, api_key: str, api_base_url: str):
self.api_key = api_key
self.api_base_url = api_base_url
@abstractmethod
def validate_bearer_data_source(self):
"""
Validate the data source
"""
class FireCrawlDataSource(BearerDataSource):
def validate_bearer_data_source(self):
TEST_CRAWL_SITE_URL = "https://www.google.com"
FIRECRAWL_API_VERSION = "v0"
test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
data = {
"url": TEST_CRAWL_SITE_URL,
}
response = requests.post(test_api_endpoint, headers=headers, json=data)
return response.json().get("status") == "success"
def save_credentials(self):
# save data source binding
data_source_binding = DataSourceBearerBinding.query.filter(
db.and_(
DataSourceBearerBinding.tenant_id == current_user.current_tenant_id,
DataSourceBearerBinding.provider == 'firecrawl',
DataSourceBearerBinding.endpoint_url == self.api_base_url,
DataSourceBearerBinding.bearer_key == self.api_key
)
).first()
if data_source_binding:
data_source_binding.disabled = False
db.session.commit()
else:
new_data_source_binding = DataSourceBearerBinding(
tenant_id=current_user.current_tenant_id,
provider='firecrawl',
endpoint_url=self.api_base_url,
bearer_key=self.api_key
)
db.session.add(new_data_source_binding)
db.session.commit()

View File

@ -4,7 +4,7 @@ import requests
from flask_login import current_user
from extensions.ext_database import db
from models.source import DataSourceBinding
from models.source import DataSourceOauthBinding
class OAuthDataSource:
@ -63,11 +63,11 @@ class NotionOAuth(OAuthDataSource):
'total': len(pages)
}
# save data source binding
data_source_binding = DataSourceBinding.query.filter(
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.access_token == access_token
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.access_token == access_token
)
).first()
if data_source_binding:
@ -75,7 +75,7 @@ class NotionOAuth(OAuthDataSource):
data_source_binding.disabled = False
db.session.commit()
else:
new_data_source_binding = DataSourceBinding(
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
@ -98,11 +98,11 @@ class NotionOAuth(OAuthDataSource):
'total': len(pages)
}
# save data source binding
data_source_binding = DataSourceBinding.query.filter(
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.access_token == access_token
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.access_token == access_token
)
).first()
if data_source_binding:
@ -110,7 +110,7 @@ class NotionOAuth(OAuthDataSource):
data_source_binding.disabled = False
db.session.commit()
else:
new_data_source_binding = DataSourceBinding(
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
@ -121,12 +121,12 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str):
# save data source binding
data_source_binding = DataSourceBinding.query.filter(
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.id == binding_id,
DataSourceBinding.disabled == False
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False
)
).first()
if data_source_binding:

View File

@ -0,0 +1,67 @@
"""add-api-key-auth-binding
Revision ID: 7b45942e39bb
Revises: 4e99a8df00ff
Create Date: 2024-05-14 07:31:29.702766
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = '7b45942e39bb'
down_revision = '4e99a8df00ff'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('data_source_api_key_auth_bindings',
sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.StringUUID(), nullable=False),
sa.Column('category', sa.String(length=255), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('credentials', sa.Text(), nullable=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('disabled', sa.Boolean(), server_default=sa.text('false'), nullable=True),
sa.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey')
)
with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
batch_op.create_index('data_source_api_key_auth_binding_provider_idx', ['provider'], unique=False)
batch_op.create_index('data_source_api_key_auth_binding_tenant_id_idx', ['tenant_id'], unique=False)
with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
batch_op.drop_index('source_binding_tenant_id_idx')
batch_op.drop_index('source_info_idx')
op.rename_table('data_source_bindings', 'data_source_oauth_bindings')
with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
batch_op.create_index('source_info_idx', ['source_info'], unique=False, postgresql_using='gin')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
batch_op.drop_index('source_info_idx', postgresql_using='gin')
batch_op.drop_index('source_binding_tenant_id_idx')
op.rename_table('data_source_oauth_bindings', 'data_source_bindings')
with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
batch_op.create_index('source_info_idx', ['source_info'], unique=False)
batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
batch_op.drop_index('data_source_api_key_auth_binding_tenant_id_idx')
batch_op.drop_index('data_source_api_key_auth_binding_provider_idx')
op.drop_table('data_source_api_key_auth_bindings')
# ### end Alembic commands ###

View File

@ -270,7 +270,7 @@ class Document(db.Model):
255), nullable=False, server_default=db.text("'text_model'::character varying"))
doc_language = db.Column(db.String(255), nullable=True)
DATA_SOURCES = ['upload_file', 'notion_import']
DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl']
@property
def display_status(self):
@ -322,7 +322,7 @@ class Document(db.Model):
'created_at': file_detail.created_at.timestamp()
}
}
elif self.data_source_type == 'notion_import':
elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl':
return json.loads(self.data_source_info)
return {}

View File

@ -1,11 +1,13 @@
import json
from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db
from models import StringUUID
class DataSourceBinding(db.Model):
__tablename__ = 'data_source_bindings'
class DataSourceOauthBinding(db.Model):
__tablename__ = 'data_source_oauth_bindings'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='source_binding_pkey'),
db.Index('source_binding_tenant_id_idx', 'tenant_id'),
@ -20,3 +22,33 @@ class DataSourceBinding(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
class DataSourceApiKeyAuthBinding(db.Model):
__tablename__ = 'data_source_api_key_auth_bindings'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'),
db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'),
db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'),
)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
category = db.Column(db.String(255), nullable=False)
provider = db.Column(db.String(255), nullable=False)
credentials = db.Column(db.Text, nullable=True) # JSON
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
def to_dict(self):
return {
'id': self.id,
'tenant_id': self.tenant_id,
'category': self.category,
'provider': self.provider,
'credentials': json.loads(self.credentials),
'created_at': self.created_at.timestamp(),
'updated_at': self.updated_at.timestamp(),
'disabled': self.disabled
}

View File

@ -78,6 +78,9 @@ CODE_MAX_STRING_LENGTH = "80000"
CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194"
CODE_EXECUTION_API_KEY="dify-sandbox"
FIRECRAWL_API_KEY = "fc-"
[tool.poetry]
name = "dify-api"

View File

View File

@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
class ApiKeyAuthBase(ABC):
def __init__(self, credentials: dict):
self.credentials = credentials
@abstractmethod
def validate_credentials(self):
raise NotImplementedError

View File

@ -0,0 +1,14 @@
from services.auth.firecrawl import FirecrawlAuth
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
if provider == 'firecrawl':
self.auth = FirecrawlAuth(credentials)
else:
raise ValueError('Invalid provider')
def validate_credentials(self):
return self.auth.validate_credentials()

View File

@ -0,0 +1,70 @@
import json
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str) -> list:
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.disabled.is_(False)
).all()
return data_source_api_key_bindings
@staticmethod
def create_provider_auth(tenant_id: str, args: dict):
auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
if auth_result:
# Encrypt the api key
api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
args['credentials']['config']['api_key'] = api_key
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
data_source_api_key_binding.tenant_id = tenant_id
data_source_api_key_binding.category = args['category']
data_source_api_key_binding.provider = args['provider']
data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
db.session.add(data_source_api_key_binding)
db.session.commit()
@staticmethod
def get_auth_credentials(tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category,
DataSourceApiKeyAuthBinding.provider == provider,
DataSourceApiKeyAuthBinding.disabled.is_(False)
).first()
if not data_source_api_key_bindings:
return None
credentials = json.loads(data_source_api_key_bindings.credentials)
return credentials
@staticmethod
def delete_provider_auth(tenant_id: str, binding_id: str):
data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.id == binding_id
).first()
if data_source_api_key_binding:
db.session.delete(data_source_api_key_binding)
db.session.commit()
@classmethod
def validate_api_key_auth_args(cls, args):
if 'category' not in args or not args['category']:
raise ValueError('category is required')
if 'provider' not in args or not args['provider']:
raise ValueError('provider is required')
if 'credentials' not in args or not args['credentials']:
raise ValueError('credentials is required')
if not isinstance(args['credentials'], dict):
raise ValueError('credentials must be a dictionary')
if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
raise ValueError('auth_type is required')
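
A hedged sketch of the service flow the console controllers drive, with a placeholder tenant id; an application context and an existing tenant are assumed. The credentials are validated against the provider, the API key is encrypted before persisting, and get_auth_credentials returns the stored, still-encrypted config.

tenant_id = '<tenant-id>'  # placeholder
args = {
    'category': 'website',
    'provider': 'firecrawl',
    'credentials': {
        'auth_type': 'bearer',
        'config': {'api_key': 'fc-your-key', 'base_url': 'https://api.firecrawl.dev'},
    },
}
ApiKeyAuthService.validate_api_key_auth_args(args)
ApiKeyAuthService.create_provider_auth(tenant_id, args)  # validates with Firecrawl, then encrypts config.api_key
creds = ApiKeyAuthService.get_auth_credentials(tenant_id, 'website', 'firecrawl')
# creds['config']['api_key'] is encrypted; callers decrypt it with encrypter.decrypt_token.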

View File

@ -0,0 +1,56 @@
import json
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get('auth_type')
if auth_type != 'bearer':
raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
self.api_key = credentials.get('config').get('api_key', None)
self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')
if not self.api_key:
raise ValueError('No API key provided')
def validate_credentials(self):
headers = self._prepare_headers()
options = {
'url': 'https://example.com',
'crawlerOptions': {
'excludes': [],
'includes': [],
'limit': 1
},
'pageOptions': {
'onlyMainContent': True
}
}
response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
if response.status_code == 200:
return True
else:
self._handle_error(response)
def _prepare_headers(self):
return {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in [402, 409, 500]:
error_message = response.json().get('error', 'Unknown error occurred')
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
else:
if response.text:
error_message = json.loads(response.text).get('error', 'Unknown error occurred')
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')

View File

@ -31,7 +31,7 @@ from models.dataset import (
DocumentSegment,
)
from models.model import UploadFile
from models.source import DataSourceBinding
from models.source import DataSourceOauthBinding
from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
@ -48,6 +48,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
class DatasetService:
@ -508,18 +509,40 @@ class DocumentService:
@staticmethod
def retry_document(dataset_id: str, documents: list[Document]):
for document in documents:
# add retry flag
retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
cache_result = redis_client.get(retry_indexing_cache_key)
if cache_result is not None:
raise ValueError("Document is being retried, please try again later")
# retry document indexing
document.indexing_status = 'waiting'
db.session.add(document)
db.session.commit()
# add retry flag
retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
redis_client.setex(retry_indexing_cache_key, 600, 1)
# trigger async task
document_ids = [document.id for document in documents]
retry_document_indexing_task.delay(dataset_id, document_ids)
@staticmethod
def sync_website_document(dataset_id: str, document: Document):
# add sync flag
sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id)
cache_result = redis_client.get(sync_indexing_cache_key)
if cache_result is not None:
raise ValueError("Document is being synced, please try again later")
# sync document indexing
document.indexing_status = 'waiting'
data_source_info = document.data_source_info_dict
data_source_info['mode'] = 'scrape'
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
db.session.add(document)
db.session.commit()
redis_client.setex(sync_indexing_cache_key, 600, 1)
sync_website_document_indexing_task.delay(dataset_id, document.id)
@staticmethod
def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
if document:
@ -545,6 +568,9 @@ class DocumentService:
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
count = count + len(notion_info['pages'])
elif document_data["data_source"]["type"] == "website_crawl":
website_info = document_data["data_source"]['info_list']['website_info_list']
count = len(website_info['urls'])
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@ -683,12 +709,12 @@ class DocumentService:
exist_document[data_source_info['notion_page_id']] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
data_source_binding = DataSourceBinding.query.filter(
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
@ -717,6 +743,28 @@ class DocumentService:
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif document_data["data_source"]["type"] == "website_crawl":
website_info = document_data["data_source"]['info_list']['website_info_list']
urls = website_info['urls']
for url in urls:
data_source_info = {
'url': url,
'provider': website_info['provider'],
'job_id': website_info['job_id'],
'only_main_content': website_info.get('only_main_content', False),
'mode': 'crawl',
}
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, url, batch)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()
# trigger async task
@ -818,12 +866,12 @@ class DocumentService:
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
data_source_binding = DataSourceBinding.query.filter(
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
@ -835,6 +883,17 @@ class DocumentService:
"notion_page_icon": page['page_icon'],
"type": page['type']
}
elif document_data["data_source"]["type"] == "website_crawl":
website_info = document_data["data_source"]['info_list']['website_info_list']
urls = website_info['urls']
for url in urls:
data_source_info = {
'url': url,
'provider': website_info['provider'],
'job_id': website_info['job_id'],
'only_main_content': website_info.get('only_main_content', False),
'mode': 'crawl',
}
document.data_source_type = document_data["data_source"]["type"]
document.data_source_info = json.dumps(data_source_info)
document.name = file_name
@ -873,6 +932,9 @@ class DocumentService:
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
count = count + len(notion_info['pages'])
elif document_data["data_source"]["type"] == "website_crawl":
website_info = document_data["data_source"]['info_list']['website_info_list']
count = len(website_info['urls'])
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@ -973,6 +1035,10 @@ class DocumentService:
if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
'notion_info_list']:
raise ValueError("Notion source info is required")
if args['data_source']['type'] == 'website_crawl':
if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
'website_info_list']:
raise ValueError("Website source info is required")
@classmethod
def process_rule_args_validate(cls, args: dict):

View File

@ -0,0 +1,171 @@
import datetime
import json
from flask_login import current_user
from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService
class WebsiteService:
@classmethod
def document_create_args_validate(cls, args: dict):
if 'url' not in args or not args['url']:
raise ValueError('url is required')
if 'options' not in args or not args['options']:
raise ValueError('options is required')
if 'limit' not in args['options'] or not args['options']['limit']:
raise ValueError('limit is required')
@classmethod
def crawl_url(cls, args: dict) -> dict:
provider = args.get('provider')
url = args.get('url')
options = args.get('options')
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
'website',
provider)
if provider == 'firecrawl':
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id,
token=credentials.get('config').get('api_key')
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
crawl_sub_pages = options.get('crawl_sub_pages', False)
only_main_content = options.get('only_main_content', False)
if not crawl_sub_pages:
params = {
'crawlerOptions': {
"includes": [],
"excludes": [],
"generateImgAltText": True,
"limit": 1,
'returnOnlyUrls': False,
'pageOptions': {
'onlyMainContent': only_main_content,
"includeHtml": False
}
}
}
else:
includes = options.get('includes').split(',') if options.get('includes') else []
excludes = options.get('excludes').split(',') if options.get('excludes') else []
params = {
'crawlerOptions': {
"includes": includes if includes else [],
"excludes": excludes if excludes else [],
"generateImgAltText": True,
"limit": options.get('limit', 1),
'returnOnlyUrls': False,
'pageOptions': {
'onlyMainContent': only_main_content,
"includeHtml": False
}
}
}
if options.get('max_depth'):
params['crawlerOptions']['maxDepth'] = options.get('max_depth')
job_id = firecrawl_app.crawl_url(url, params)
website_crawl_time_cache_key = f'website_crawl_{job_id}'
time = str(datetime.datetime.now().timestamp())
redis_client.setex(website_crawl_time_cache_key, 3600, time)
return {
'status': 'active',
'job_id': job_id
}
else:
raise ValueError('Invalid provider')
@classmethod
def get_crawl_status(cls, job_id: str, provider: str) -> dict:
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
'website',
provider)
if provider == 'firecrawl':
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id,
token=credentials.get('config').get('api_key')
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
result = firecrawl_app.check_crawl_status(job_id)
crawl_status_data = {
'status': result.get('status', 'active'),
'job_id': job_id,
'total': result.get('total', 0),
'current': result.get('current', 0),
'data': result.get('data', [])
}
if crawl_status_data['status'] == 'completed':
website_crawl_time_cache_key = f'website_crawl_{job_id}'
start_time = redis_client.get(website_crawl_time_cache_key)
if start_time:
end_time = datetime.datetime.now().timestamp()
time_consuming = abs(end_time - float(start_time))
crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
redis_client.delete(website_crawl_time_cache_key)
else:
raise ValueError('Invalid provider')
return crawl_status_data
@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
'website',
provider)
if provider == 'firecrawl':
file_key = 'website_files/' + job_id + '.txt'
if storage.exists(file_key):
data = storage.load_once(file_key)
if data:
data = json.loads(data.decode('utf-8'))
else:
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=tenant_id,
token=credentials.get('config').get('api_key')
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
result = firecrawl_app.check_crawl_status(job_id)
if result.get('status') != 'completed':
raise ValueError('Crawl job is not completed')
data = result.get('data')
if data:
for item in data:
if item.get('source_url') == url:
return item
return None
else:
raise ValueError('Invalid provider')
@classmethod
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
'website',
provider)
if provider == 'firecrawl':
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=tenant_id,
token=credentials.get('config').get('api_key')
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
params = {
'pageOptions': {
'onlyMainContent': only_main_content,
"includeHtml": False
}
}
result = firecrawl_app.scrape_url(url, params)
return result
else:
raise ValueError('Invalid provider')
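
A hedged sketch of the full service round-trip used by the controllers and the extractor, with placeholder ids. It assumes a Firecrawl key is already bound for the current workspace and, for crawl_url and get_crawl_status, a request context with a logged-in user (current_user).

# Start a crawl for the current workspace (requires a bound Firecrawl key).
job = WebsiteService.crawl_url({
    'provider': 'firecrawl',
    'url': 'https://example.com',
    'options': {'crawl_sub_pages': True, 'limit': 5, 'only_main_content': True},
})

# Later: poll the job, then pull a single page's cached data for indexing.
status = WebsiteService.get_crawl_status(job['job_id'], 'firecrawl')
if status['status'] == 'completed':
    page = WebsiteService.get_crawl_url_data(job['job_id'], 'firecrawl',
                                             'https://example.com', '<tenant-id>')
    if page:
        print(page['title'], len(page['markdown'] or ''))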

View File

@ -11,7 +11,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.source import DataSourceBinding
from models.source import DataSourceOauthBinding
@shared_task(queue='dataset')
@ -43,12 +43,12 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
page_id = data_source_info['notion_page_id']
page_type = data_source_info['type']
page_edited_time = data_source_info['last_edited_time']
data_source_binding = DataSourceBinding.query.filter(
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == document.tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:

View File

@ -0,0 +1,90 @@
import datetime
import logging
import time
import click
from celery import shared_task
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@shared_task(queue='dataset')
def sync_website_document_indexing_task(dataset_id: str, document_id: str):
"""
Async process document
:param dataset_id:
:param document_id:
Usage: sync_website_document_indexing_task.delay(dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id)
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError("Your total number of documents plus the number of uploads have over the limit of "
"your subscription.")
except Exception as e:
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
if document:
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
redis_client.delete(sync_indexing_cache_key)
return
logging.info(click.style('Start sync website document: {}'.format(document_id), fg='green'))
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
try:
if document:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = 'parsing'
document.processing_started_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(sync_indexing_cache_key)
except Exception as ex:
document.indexing_status = 'error'
document.error = str(ex)
document.stopped_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg='yellow'))
redis_client.delete(sync_indexing_cache_key)
end_at = time.perf_counter()
logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))

View File

@ -0,0 +1,33 @@
import os
from unittest import mock
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.models.document import Document
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
def test_firecrawl_web_extractor_crawl_mode(mocker):
url = "https://firecrawl.dev"
api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-'
base_url = 'https://api.firecrawl.dev'
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=base_url)
params = {
'crawlerOptions': {
"includes": [],
"excludes": [],
"generateImgAltText": True,
"maxDepth": 1,
"limit": 1,
'returnOnlyUrls': False,
}
}
mocked_firecrawl = {
"jobId": "test",
}
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
job_id = firecrawl_app.crawl_url(url, params)
print(job_id)
assert isinstance(job_id, str)

View File