Skip to content

Commit

Permalink
feat: check usb, pci and vgpu resource name
Browse files Browse the repository at this point in the history
Signed-off-by: Jack Yu <[email protected]>
  • Loading branch information
Yu-Jack committed Jan 15, 2025
1 parent ebcf598 commit eacf768
Show file tree
Hide file tree
Showing 10 changed files with 478 additions and 30 deletions.
23 changes: 23 additions & 0 deletions pkg/util/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ import (
"fmt"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"

"github.com/harvester/pcidevices/pkg/apis/devices.harvesterhci.io/v1beta1"
)

const (
Expand Down Expand Up @@ -77,3 +80,23 @@ func GetVFList(pfDir string) (vfList []string, err error) {
}
return
}

func VGPUDeviceByResourceName(obj *v1beta1.VGPUDevice) ([]string, error) {
return []string{
GeneratevGPUDeviceName(obj.Status.ConfiguredVGPUTypeName),
}, nil
}

func GeneratevGPUDeviceName(deviceName string) string {
deviceName = strings.TrimSpace(deviceName)
deviceName = strings.ToUpper(deviceName)
deviceName = strings.Replace(deviceName, "/", "_", -1)
deviceName = strings.Replace(deviceName, ".", "_", -1)
//deviceName = strings.Replace(deviceName, "-", "_", -1)
reg, _ := regexp.Compile(`\s+`)
deviceName = reg.ReplaceAllString(deviceName, "_")
// Removes any char other than alphanumeric and underscore
reg, _ = regexp.Compile(`^a-zA-Z0-9_-.]+`)
deviceName = reg.ReplaceAllString(deviceName, "")
return fmt.Sprintf("nvidia.com/%s", deviceName)
}
15 changes: 14 additions & 1 deletion pkg/util/fakeclients/pcidevices.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import (
)

const (
IommuGroupByNode = "pcidevice.harvesterhci.io/iommu-by-node"
IommuGroupByNode = "pcidevice.harvesterhci.io/iommu-by-node"
PCIDeviceByResourceName = "harvesterhcio.io/pcidevice-by-resource-name"
)

type PCIDevicesClient func() v1beta1.PCIDeviceInterface
Expand Down Expand Up @@ -80,6 +81,18 @@ func (p PCIDevicesCache) GetByIndex(indexName, key string) ([]*pcidevicev1beta1.
}
}
return resp, err
case PCIDeviceByResourceName:
list, err := p().List(context.TODO(), metav1.ListOptions{})
if err != nil {
return nil, err
}
var resp []*pcidevicev1beta1.PCIDevice
for i, v := range list.Items {
if key == v.Status.ResourceName {
resp = append(resp, &list.Items[i])
}
}
return resp, nil
default:
return nil, nil
}
Expand Down
22 changes: 20 additions & 2 deletions pkg/util/fakeclients/vgpudevice.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ import (
pcidevicev1beta1 "github.com/harvester/pcidevices/pkg/apis/devices.harvesterhci.io/v1beta1"
"github.com/harvester/pcidevices/pkg/generated/clientset/versioned/typed/devices.harvesterhci.io/v1beta1"
pcidevicesv1beta1ctl "github.com/harvester/pcidevices/pkg/generated/controllers/devices.harvesterhci.io/v1beta1"
"github.com/harvester/pcidevices/pkg/util/common"
)

const vGPUDeviceByResourceName = "harvesterhci.io/vgpu-device-by-resource-name"

type VGPUDeviceClient func() v1beta1.VGPUDeviceInterface

func (s VGPUDeviceClient) Update(d *pcidevicev1beta1.VGPUDevice) (*pcidevicev1beta1.VGPUDevice, error) {
Expand Down Expand Up @@ -72,6 +75,21 @@ func (s VGPUDeviceCache) AddIndexer(_ string, _ pcidevicesv1beta1ctl.VGPUDeviceI
panic("implement me")
}

func (s VGPUDeviceCache) GetByIndex(_, _ string) ([]*pcidevicev1beta1.VGPUDevice, error) {
panic("implement me")
func (s VGPUDeviceCache) GetByIndex(index, key string) ([]*pcidevicev1beta1.VGPUDevice, error) {
switch index {
case vGPUDeviceByResourceName:
devices, err := s.List(labels.NewSelector())
if err != nil {
return nil, err
}
for _, device := range devices {
if common.GeneratevGPUDeviceName(device.Status.ConfiguredVGPUTypeName) == key {
return []*pcidevicev1beta1.VGPUDevice{device}, nil
}
}
return nil, nil
default:
}

return nil, nil
}
13 changes: 1 addition & 12 deletions pkg/util/gpuhelper/gpuhelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"strings"

"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -295,15 +294,5 @@ func evalPhysFn(devicePath string) (string, error) {
}

func GenerateDeviceName(deviceName string) string {
deviceName = strings.TrimSpace(deviceName)
deviceName = strings.ToUpper(deviceName)
deviceName = strings.Replace(deviceName, "/", "_", -1)
deviceName = strings.Replace(deviceName, ".", "_", -1)
//deviceName = strings.Replace(deviceName, "-", "_", -1)
reg, _ := regexp.Compile(`\s+`)
deviceName = reg.ReplaceAllString(deviceName, "_")
// Removes any char other than alphanumeric and underscore
reg, _ = regexp.Compile(`^a-zA-Z0-9_-.]+`)
deviceName = reg.ReplaceAllString(deviceName, "")
return fmt.Sprintf("nvidia.com/%s", deviceName)
return common.GeneratevGPUDeviceName(deviceName)
}
4 changes: 2 additions & 2 deletions pkg/webhook/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type Clients struct {
CoreFactory *ctlcore.Factory
HarvesterFactory *ctlharvesterv1.Factory
KubevirtFactory *ctlkubevirtv1.Factory
PCIFactory *ctlpcidevices.Factory
DeviceFactory *ctlpcidevices.Factory
}

func NewClient(ctx context.Context, rest *rest.Config, threadiness int) (*Clients, error) {
Expand Down Expand Up @@ -70,6 +70,6 @@ func NewClient(ctx context.Context, rest *rest.Config, threadiness int) (*Client
HarvesterFactory: harvesterFactory,
KubevirtFactory: kubevirtFactory,
CoreFactory: coreFactory,
PCIFactory: pciFactory,
DeviceFactory: pciFactory,
}, nil
}
20 changes: 14 additions & 6 deletions pkg/webhook/indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,32 @@ import (
kubevirtv1 "kubevirt.io/api/core/v1"

"github.com/harvester/pcidevices/pkg/apis/devices.harvesterhci.io/v1beta1"
"github.com/harvester/pcidevices/pkg/util/common"
)

const (
VMByName = "harvesterhci.io/vm-by-name"
PCIDeviceByResourceName = "harvesterhcio.io/pcidevice-by-resource-name"
IommuGroupByNode = "pcidevice.harvesterhci.io/iommu-by-node"
VMByPCIDeviceClaim = "harvesterhci.io/vm-by-pcideviceclaim"
VMByVGPU = "harvesterhci.io/vm-by-vgpu"
VMByName = "harvesterhci.io/vm-by-name"
PCIDeviceByResourceName = "harvesterhcio.io/pcidevice-by-resource-name"
IommuGroupByNode = "pcidevice.harvesterhci.io/iommu-by-node"
USBDeviceByAddress = "pcidevice.harvesterhci.io/usb-device-by-address"
VMByPCIDeviceClaim = "harvesterhci.io/vm-by-pcideviceclaim"
VMByUSBDeviceClaim = "harvesterhci.io/vm-by-usbdeviceclaim"
VMByVGPU = "harvesterhci.io/vm-by-vgpu"
USBDeviceByResourceName = "harvesterhci.io/usbdevice-by-resource-name"
vGPUDeviceByResourceName = "harvesterhci.io/vgpu-device-by-resource-name"
)

func RegisterIndexers(clients *Clients) {
vmCache := clients.KubevirtFactory.Kubevirt().V1().VirtualMachine().Cache()
vmCache.AddIndexer(VMByName, vmByName)
vmCache.AddIndexer(VMByPCIDeviceClaim, vmByPCIDeviceClaim)
vmCache.AddIndexer(VMByVGPU, vmByVGPUDevice)
deviceCache := clients.PCIFactory.Devices().V1beta1().PCIDevice().Cache()
deviceCache := clients.DeviceFactory.Devices().V1beta1().PCIDevice().Cache()
deviceCache.AddIndexer(PCIDeviceByResourceName, pciDeviceByResourceName)
deviceCache.AddIndexer(IommuGroupByNode, iommuGroupByNodeName)

vgpuCache := clients.DeviceFactory.Devices().V1beta1().VGPUDevice().Cache()
vgpuCache.AddIndexer(vGPUDeviceByResourceName, common.VGPUDeviceByResourceName)
}

func vmByName(obj *kubevirtv1.VirtualMachine) ([]string, error) {
Expand Down
11 changes: 6 additions & 5 deletions pkg/webhook/mutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ import (

func Mutation(clients *Clients) (http.Handler, []types.Resource, error) {
mutators := []types.Mutator{
NewPodMutator(clients.PCIFactory.Devices().V1beta1().PCIDevice().Cache(),
NewPodMutator(clients.DeviceFactory.Devices().V1beta1().PCIDevice().Cache(),
clients.KubevirtFactory.Kubevirt().V1().VirtualMachine().Cache(),
clients.PCIFactory.Devices().V1beta1().VGPUDevice().Cache()),
NewPCIVMMutator(clients.PCIFactory.Devices().V1beta1().PCIDevice().Cache(),
clients.PCIFactory.Devices().V1beta1().PCIDeviceClaim().Cache(),
clients.PCIFactory.Devices().V1beta1().PCIDeviceClaim()),
clients.DeviceFactory.Devices().V1beta1().VGPUDevice().Cache()),
NewPCIVMMutator(clients.DeviceFactory.Devices().V1beta1().PCIDevice().Cache(),
clients.DeviceFactory.Devices().V1beta1().PCIDeviceClaim().Cache(),
clients.DeviceFactory.Devices().V1beta1().PCIDeviceClaim(),
),
}

router := webhook.NewRouter()
Expand Down
8 changes: 6 additions & 2 deletions pkg/webhook/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ import (

func Validation(clients *Clients) (http.Handler, []types.Resource, error) {
validators := []types.Validator{
NewSriovNetworkDeviceValidator(clients.PCIFactory.Devices().V1beta1().PCIDeviceClaim().Cache()),
NewPCIDeviceClaimValidator(clients.PCIFactory.Devices().V1beta1().PCIDevice().Cache(), clients.KubevirtFactory.Kubevirt().V1().VirtualMachine().Cache()),
NewSriovNetworkDeviceValidator(clients.DeviceFactory.Devices().V1beta1().PCIDeviceClaim().Cache()),
NewPCIDeviceClaimValidator(clients.DeviceFactory.Devices().V1beta1().PCIDevice().Cache(), clients.KubevirtFactory.Kubevirt().V1().VirtualMachine().Cache()),
NewVGPUValidator(clients.KubevirtFactory.Kubevirt().V1().VirtualMachine().Cache()),
NewSRIOVGPUValidator(clients.KubevirtFactory.Kubevirt().V1().VirtualMachine().Cache()),
NewDeviceHostValidation(
clients.DeviceFactory.Devices().V1beta1().PCIDevice().Cache(),
clients.DeviceFactory.Devices().V1beta1().VGPUDevice().Cache(),
),
}

router := webhook.NewRouter()
Expand Down
Loading

0 comments on commit eacf768

Please sign in to comment.