/**
 * ModelForm module: a create/edit form for a model record.
 *
 * The form state is held by use-inertia-form. On submit it PATCHes
 * `${rootPath}/models/${id}` when the model has an id, or POSTs
 * `${rootPath}/models` otherwise. On success it navigates to
 * `${rootPath}/models` with Inertia's `router.visit`. A ScheduleModal edits
 * `retraining_job_attributes`; saving the modal also submits the form.
 *
 * NOTE(review): this copy appears to have been extracted with its JSX
 * markup and generic type arguments stripped. For example, `Record;` and
 * `Array;` have no type arguments, and the render section has no element
 * tags. It cannot compile as-is; restore it from version control before
 * editing.
 *
 * NOTE(review): logic issues worth fixing once the markup is restored:
 *  - ModelFormProps.initialData declares `modelType` / `datasetId`, but the
 *    form reads `initialData?.model_type` and `initialData?.dataset_id`.
 *    It also reads `initialData.retraining_job.id`, which the interface
 *    does not declare. One of the two shapes is wrong; confirm against the
 *    props the server actually sends.
 *  - `save()` calls `setData('model.retraining_job_attributes.at', at)` and
 *    then immediately calls `patch`/`post`. setData only schedules a state
 *    update and does not change this render's `data`, so the request
 *    presumably still carries the un-cleaned `at` object. Build the cleaned
 *    payload and hand it to the submit call (e.g. via a transform) instead.
 *  - The effect keyed on `data.model.task` also runs after the first
 *    render. Both of its branches overwrite `objective` with
 *    'binary:logistic' or 'reg:squarederror', which discards
 *    `initialData.objective` when editing.
 *  - The effects and handleScheduleSave spread the render-time `data`
 *    object into setData. Any other field updated in the meantime can be
 *    overwritten; prefer path updates such as setData('model.objective', …).
 *  - The `at.day_of_week` default is the number 1, but the prop type
 *    declares `day_of_week?: string`.
 */
import React, { useState, useEffect } from 'react'; // import { useNavigate } from 'react-router-dom'; import { TrainFront, Lock, AlertCircle } from 'lucide-react'; import { SearchableSelect } from './SearchableSelect'; import { ScheduleModal } from './ScheduleModal'; import { router } from '@inertiajs/react'; import { useInertiaForm } from 'use-inertia-form'; import { usePage } from '@inertiajs/react'; import type { Dataset } from '../types'; interface ModelFormProps { initialData?: { id: number; name: string; modelType: string; datasetId: number; task: string; objective?: string; metrics?: string[]; retraining_job?: { frequency: string; at: { hour: number; day_of_week?: string; day_of_month?: number; }; batch_mode?: string; batch_size?: number; batch_overlap?: number; batch_key?: string; tuning_frequency?: string; active: boolean; metric?: string; threshold?: number; tuner_config?: { n_trials: number; objective: string; config: Record; }; tuning_enabled?: boolean; }; }; datasets: Array; constants: { tasks: { value: string; label: string }[]; objectives: Record; metrics: Record; timezone: string; retraining_job_constants: any; tuner_job_constants: any; }; isEditing?: boolean; errors?: any; } /* Renders `error` when one is given; renders nothing otherwise. */ const ErrorDisplay = ({ error }: { error?: string }) => ( error ? (
{error}
) : null ); /* Create/edit form for a model. Submits through use-inertia-form (PATCH when data.model.id is set, POST otherwise). */ export function ModelForm({ initialData, datasets, constants, isEditing, errors: initialErrors }: ModelFormProps) { const { rootPath } = usePage().props; const [showScheduleModal, setShowScheduleModal] = useState(false); /* Set by handleScheduleSave so that an effect calls save() on the next render, after the schedule has been written into form state. */ const [isDataSet, setIsDataSet] = useState(false); /* Form state uses snake_case field names. retraining_job_attributes is only present when initialData has a retraining_job. */ const form = useInertiaForm({ model: { id: initialData?.id, name: initialData?.name || '', model_type: initialData?.model_type || 'xgboost', dataset_id: initialData?.dataset_id || '', task: initialData?.task || 'classification', objective: initialData?.objective || 'binary:logistic', metrics: initialData?.metrics || ['accuracy'], retraining_job_attributes: initialData?.retraining_job ? { id: initialData.retraining_job.id, frequency: initialData.retraining_job.frequency, tuning_frequency: initialData.retraining_job.tuning_frequency || 'month', batch_mode: initialData.retraining_job.batch_mode, batch_size: initialData.retraining_job.batch_size, batch_overlap: initialData.retraining_job.batch_overlap, batch_key: initialData.retraining_job.batch_key, at: { hour: initialData.retraining_job.at?.hour ?? 2, day_of_week: initialData.retraining_job.at?.day_of_week ?? 1, day_of_month: initialData.retraining_job.at?.day_of_month ?? 
1 }, active: initialData.retraining_job.active, metric: initialData.retraining_job.metric, threshold: initialData.retraining_job.threshold, tuner_config: initialData.retraining_job.tuner_config, tuning_enabled: initialData.retraining_job.tuning_enabled || false, } : undefined } }); const { data, setData, post, patch, processing, errors: formErrors } = form; /* Client-side form errors take precedence over server errors with the same key that were passed in as props. */ const errors = { ...initialErrors, ...formErrors }; /* Objective options for the selected model type and task; empty when none are configured. */ const objectives: { value: string; label: string; description?: string }[] = constants.objectives[data.model.model_type]?.[data.model.task] || []; /* Runs after the first render and whenever the task changes. It resets the objective to the task's default. When initialData supplied no metrics, it also selects every metric listed for the task. NOTE(review): this overwrites initialData.objective when editing. When initialData.metrics is set, metrics chosen for a previous task are kept after a task change. */ useEffect(() => { // Only set default metrics if none were provided from the backend if (!initialData?.metrics) { const availableMetrics = constants.metrics[data.model.task]?.map(metric => metric.value) || []; setData({ ...data, model: { ...data.model, objective: data.model.task === 'classification' ? 'binary:logistic' : 'reg:squarederror', metrics: availableMetrics } }); } else { setData({ ...data, model: { ...data.model, objective: data.model.task === 'classification' ? 
'binary:logistic' : 'reg:squarederror' } }); } }, [data.model.task]); /* Deferred submit: runs save() on the render after handleScheduleSave stored the schedule, so that save() sees the updated form data. */ useEffect(() => { if (isDataSet) { save(); setIsDataSet(false); // Reset the flag } }, [isDataSet]); const handleScheduleSave = (scheduleData: any) => { setData({ ...data, model: { ...data.model, retraining_job_attributes: scheduleData.retraining_job_attributes } }); setIsDataSet(true); }; /* Trims `at` down to the fields relevant to the schedule frequency. It then PATCHes an existing model or POSTs a new one, and visits the models list on success. NOTE(review): the setData call inside does not change the `data` this render's patch/post submit, so the un-cleaned `at` is presumably what gets sent. */ const save = () => { if (data.model.retraining_job_attributes) { const at: any = { hour: data.model.retraining_job_attributes.at.hour }; // Only include relevant date attributes based on frequency switch (data.model.retraining_job_attributes.frequency) { case 'day': // For daily frequency, only include hour break; case 'week': // For weekly frequency, include hour and day_of_week at.day_of_week = data.model.retraining_job_attributes.at.day_of_week; break; case 'month': // For monthly frequency, include hour and day_of_month at.day_of_month = data.model.retraining_job_attributes.at.day_of_month; break; } // Update the form data with the cleaned at object setData('model.retraining_job_attributes.at', at); } if (data.model.id) { patch(`${rootPath}/models/${data.model.id}`, { onSuccess: () => { router.visit(`${rootPath}/models`); }, }); } else { post(`${rootPath}/models`, { onSuccess: () => { router.visit(`${rootPath}/models`); }, }); } } const handleSubmit = (e: React.FormEvent) => { e.preventDefault(); save(); }; /* NOTE(review): debug logging that runs on every render; remove. */ console.log(data.model) const selectedDataset = datasets.find(d => d.id === data.model.dataset_id); const filteredTunerJobConstants = constants.tuner_job_constants[data.model.model_type] || {}; return (

Model Configuration

setData('model.name', e.target.value)} className="block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500 py-2 px-4 shadow-sm border-gray-300 border" />
setData('model.model_type', value as string)} placeholder="Select model type" />
{isEditing ? (
{selectedDataset?.name}
) : ( ({ value: dataset.id, label: dataset.name, description: `${dataset.num_rows.toLocaleString()} rows` }))} value={data.model.dataset_id} onChange={(value) => setData('model.dataset_id', value)} placeholder="Select dataset" /> )}
setData('model.task', value as string)} placeholder="Select task" />
setData('model.objective', value as string)} placeholder="Select objective" />
{constants.metrics[data.model.task]?.map(metric => ( ))}
{data.model.retraining_job_attributes && data.model.retraining_job_attributes.batch_mode && ( <>
setData('model', { ...data.model, retraining_job_attributes: { ...data.model.retraining_job_attributes, batch_key: value } })} options={selectedDataset?.columns?.map(column => ({ value: column.name, label: column.name })) || []} placeholder="Select a column for batch key" />
)}
setShowScheduleModal(false)} onSave={handleScheduleSave} initialData={{ task: data.model.task, metrics: data.model.metrics, modelType: data.model.model_type, dataset: selectedDataset, retraining_job: data.model.retraining_job_attributes }} tunerJobConstants={filteredTunerJobConstants} timezone={constants.timezone} retrainingJobConstants={constants.retraining_job_constants} /> ); }