
[RISCV] Custom legalize vp.merge for mask vectors.
The default legalization uses vmslt with a vector of XLen-sized elements to
compute a mask. This doesn't work if that vector type isn't legal. For fixed
vectors it will scalarize. For scalable vectors it crashes the compiler.

This patch uses an alternative strategy that promotes the i1 vector
to an i8 vector, performs the merge on that, and compares the result
against zero to recover the mask. I don't claim this to be the best
lowering; I wrote it almost 3 years ago when a crash was reported in
our downstream.

Fixes llvm#120405.
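
As a purely illustrative reproducer (not taken from this commit's tests), a
vp.merge on a scalable mask type such as <vscale x 32 x i1> would exercise the
new lowering, since a vector of EVL-sized elements with the same element count
(<vscale x 32 x i32> or <vscale x 32 x i64>) is not a legal type:

; Illustrative only; not part of the commit's tests.
declare <vscale x 32 x i1> @llvm.vp.merge.nxv32i1(<vscale x 32 x i1>, <vscale x 32 x i1>, <vscale x 32 x i1>, i32)

define <vscale x 32 x i1> @vpmerge_vv_nxv32i1(<vscale x 32 x i1> %va, <vscale x 32 x i1> %vb, <vscale x 32 x i1> %m, i32 zeroext %evl) {
  %v = call <vscale x 32 x i1> @llvm.vp.merge.nxv32i1(<vscale x 32 x i1> %m, <vscale x 32 x i1> %va, <vscale x 32 x i1> %vb, i32 %evl)
  ret <vscale x 32 x i1> %v
}

With this patch the merge is lowered roughly as: promote %va and %vb to i8
vectors of 0s and 1s, vmerge.vvm them under %m with a tail-undisturbed vsetvli
up to %evl, and vmsne.vi the result back to a mask; the new fixed-vector tests
below show the emitted sequences.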
topperc committed Dec 19, 2024
1 parent 60a2f32 commit 2e82183
Showing 4 changed files with 454 additions and 14 deletions.
72 changes: 68 additions & 4 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -758,9 +758,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
Custom);

setOperationAction(ISD::SELECT, VT, Custom);
setOperationAction(
{ISD::SELECT_CC, ISD::VSELECT, ISD::VP_MERGE, ISD::VP_SELECT}, VT,
Expand);
setOperationAction({ISD::SELECT_CC, ISD::VSELECT, ISD::VP_SELECT}, VT,
Expand);
setOperationAction(ISD::VP_MERGE, VT, Custom);

setOperationAction({ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF}, VT,
Custom);
@@ -1237,6 +1237,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::VP_SETCC, ISD::VP_TRUNCATE},
VT, Custom);

setOperationAction(ISD::VP_MERGE, VT, Custom);

setOperationAction(ISD::EXPERIMENTAL_VP_SPLICE, VT, Custom);
setOperationAction(ISD::EXPERIMENTAL_VP_REVERSE, VT, Custom);
continue;
@@ -7492,8 +7494,11 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerSET_ROUNDING(Op, DAG);
case ISD::EH_DWARF_CFA:
return lowerEH_DWARF_CFA(Op, DAG);
case ISD::VP_SELECT:
case ISD::VP_MERGE:
if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
return lowerVPMergeMask(Op, DAG);
[[fallthrough]];
case ISD::VP_SELECT:
case ISD::VP_ADD:
case ISD::VP_SUB:
case ISD::VP_MUL:
@@ -12078,6 +12083,65 @@ SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op,
return convertFromScalableVector(VT, Result, DAG, Subtarget);
}

SDValue RISCVTargetLowering::lowerVPMergeMask(SDValue Op,
                                              SelectionDAG &DAG) const {
  SDLoc DL(Op);
  MVT VT = Op.getSimpleValueType();
  MVT XLenVT = Subtarget.getXLenVT();

  SDValue Mask = Op.getOperand(0);
  SDValue TrueVal = Op.getOperand(1);
  SDValue FalseVal = Op.getOperand(2);
  SDValue VL = Op.getOperand(3);

  // Use default legalization if a vector of EVL type would be legal.
  EVT EVLVecVT = EVT::getVectorVT(*DAG.getContext(), VL.getValueType(),
                                  VT.getVectorElementCount());
  if (isTypeLegal(EVLVecVT))
    return SDValue();

  MVT ContainerVT = VT;
  if (VT.isFixedLengthVector()) {
    ContainerVT = getContainerForFixedLengthVector(VT);
    Mask = convertToScalableVector(ContainerVT, Mask, DAG, Subtarget);
    TrueVal = convertToScalableVector(ContainerVT, TrueVal, DAG, Subtarget);
    FalseVal = convertToScalableVector(ContainerVT, FalseVal, DAG, Subtarget);
  }

  // Promote to a vector of i8.
  MVT PromotedVT = ContainerVT.changeVectorElementType(MVT::i8);

  // Promote TrueVal and FalseVal using VLMax.
  // FIXME: Is there a better way to do this?
  SDValue VLMax = DAG.getRegister(RISCV::X0, XLenVT);
  SDValue SplatOne = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, PromotedVT,
                                 DAG.getUNDEF(PromotedVT),
                                 DAG.getConstant(1, DL, XLenVT), VLMax);
  SDValue SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, PromotedVT,
                                  DAG.getUNDEF(PromotedVT),
                                  DAG.getConstant(0, DL, XLenVT), VLMax);
  TrueVal = DAG.getNode(RISCVISD::VMERGE_VL, DL, PromotedVT, TrueVal, SplatOne,
                        SplatZero, DAG.getUNDEF(PromotedVT), VL);
  // Any element past VL uses FalseVal, so use VLMax
  FalseVal = DAG.getNode(RISCVISD::VMERGE_VL, DL, PromotedVT, FalseVal,
                         SplatOne, SplatZero, DAG.getUNDEF(PromotedVT), VLMax);

  // VP_MERGE the two promoted values.
  SDValue VPMerge = DAG.getNode(RISCVISD::VMERGE_VL, DL, PromotedVT, Mask,
                                TrueVal, FalseVal, FalseVal, VL);

  // Convert back to mask.
  SDValue TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, ContainerVT, VL);
  SDValue Result = DAG.getNode(
      RISCVISD::SETCC_VL, DL, ContainerVT,
      {VPMerge, DAG.getConstant(0, DL, PromotedVT), DAG.getCondCode(ISD::SETNE),
       DAG.getUNDEF(getMaskTypeFor(ContainerVT)), TrueMask, VLMax});

  if (VT.isFixedLengthVector())
    Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
  return Result;
}

SDValue
RISCVTargetLowering::lowerVPSpliceExperimental(SDValue Op,
SelectionDAG &DAG) const {
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -996,6 +996,7 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPExtMaskOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPSetCCMaskOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPMergeMask(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPSplatExperimental(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPSpliceExperimental(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPReverseExperimental(SDValue Op, SelectionDAG &DAG) const;
184 changes: 180 additions & 4 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll
@@ -58,6 +58,182 @@ define <4 x i1> @vpmerge_vv_v4i1(<4 x i1> %va, <4 x i1> %vb, <4 x i1> %m, i32 ze
ret <4 x i1> %v
}

define <8 x i1> @vpmerge_vv_v8i1(<8 x i1> %va, <8 x i1> %vb, <8 x i1> %m, i32 zeroext %evl) {
; RV32-LABEL: vpmerge_vv_v8i1:
; RV32: # %bb.0:
; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV32-NEXT: vid.v v10
; RV32-NEXT: vmsltu.vx v12, v10, a0
; RV32-NEXT: vmand.mm v9, v9, v12
; RV32-NEXT: vmandn.mm v8, v8, v9
; RV32-NEXT: vmand.mm v9, v0, v9
; RV32-NEXT: vmor.mm v0, v9, v8
; RV32-NEXT: ret
;
; RV64-LABEL: vpmerge_vv_v8i1:
; RV64: # %bb.0:
; RV64-NEXT: vsetivli zero, 8, e64, m4, ta, ma
; RV64-NEXT: vid.v v12
; RV64-NEXT: vmsltu.vx v10, v12, a0
; RV64-NEXT: vmand.mm v9, v9, v10
; RV64-NEXT: vmandn.mm v8, v8, v9
; RV64-NEXT: vmand.mm v9, v0, v9
; RV64-NEXT: vmor.mm v0, v9, v8
; RV64-NEXT: ret
;
; RV32ZVFHMIN-LABEL: vpmerge_vv_v8i1:
; RV32ZVFHMIN: # %bb.0:
; RV32ZVFHMIN-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV32ZVFHMIN-NEXT: vid.v v10
; RV32ZVFHMIN-NEXT: vmsltu.vx v12, v10, a0
; RV32ZVFHMIN-NEXT: vmand.mm v9, v9, v12
; RV32ZVFHMIN-NEXT: vmandn.mm v8, v8, v9
; RV32ZVFHMIN-NEXT: vmand.mm v9, v0, v9
; RV32ZVFHMIN-NEXT: vmor.mm v0, v9, v8
; RV32ZVFHMIN-NEXT: ret
;
; RV64ZVFHMIN-LABEL: vpmerge_vv_v8i1:
; RV64ZVFHMIN: # %bb.0:
; RV64ZVFHMIN-NEXT: vsetivli zero, 8, e64, m4, ta, ma
; RV64ZVFHMIN-NEXT: vid.v v12
; RV64ZVFHMIN-NEXT: vmsltu.vx v10, v12, a0
; RV64ZVFHMIN-NEXT: vmand.mm v9, v9, v10
; RV64ZVFHMIN-NEXT: vmandn.mm v8, v8, v9
; RV64ZVFHMIN-NEXT: vmand.mm v9, v0, v9
; RV64ZVFHMIN-NEXT: vmor.mm v0, v9, v8
; RV64ZVFHMIN-NEXT: ret
%v = call <8 x i1> @llvm.vp.merge.v8i1(<8 x i1> %m, <8 x i1> %va, <8 x i1> %vb, i32 %evl)
ret <8 x i1> %v
}

define <16 x i1> @vpmerge_vv_v16i1(<16 x i1> %va, <16 x i1> %vb, <16 x i1> %m, i32 zeroext %evl) {
; RV32-LABEL: vpmerge_vv_v16i1:
; RV32: # %bb.0:
; RV32-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; RV32-NEXT: vid.v v12
; RV32-NEXT: vmsltu.vx v10, v12, a0
; RV32-NEXT: vmand.mm v9, v9, v10
; RV32-NEXT: vmandn.mm v8, v8, v9
; RV32-NEXT: vmand.mm v9, v0, v9
; RV32-NEXT: vmor.mm v0, v9, v8
; RV32-NEXT: ret
;
; RV64-LABEL: vpmerge_vv_v16i1:
; RV64: # %bb.0:
; RV64-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; RV64-NEXT: vid.v v16
; RV64-NEXT: vmsltu.vx v10, v16, a0
; RV64-NEXT: vmand.mm v9, v9, v10
; RV64-NEXT: vmandn.mm v8, v8, v9
; RV64-NEXT: vmand.mm v9, v0, v9
; RV64-NEXT: vmor.mm v0, v9, v8
; RV64-NEXT: ret
;
; RV32ZVFHMIN-LABEL: vpmerge_vv_v16i1:
; RV32ZVFHMIN: # %bb.0:
; RV32ZVFHMIN-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; RV32ZVFHMIN-NEXT: vid.v v12
; RV32ZVFHMIN-NEXT: vmsltu.vx v10, v12, a0
; RV32ZVFHMIN-NEXT: vmand.mm v9, v9, v10
; RV32ZVFHMIN-NEXT: vmandn.mm v8, v8, v9
; RV32ZVFHMIN-NEXT: vmand.mm v9, v0, v9
; RV32ZVFHMIN-NEXT: vmor.mm v0, v9, v8
; RV32ZVFHMIN-NEXT: ret
;
; RV64ZVFHMIN-LABEL: vpmerge_vv_v16i1:
; RV64ZVFHMIN: # %bb.0:
; RV64ZVFHMIN-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; RV64ZVFHMIN-NEXT: vid.v v16
; RV64ZVFHMIN-NEXT: vmsltu.vx v10, v16, a0
; RV64ZVFHMIN-NEXT: vmand.mm v9, v9, v10
; RV64ZVFHMIN-NEXT: vmandn.mm v8, v8, v9
; RV64ZVFHMIN-NEXT: vmand.mm v9, v0, v9
; RV64ZVFHMIN-NEXT: vmor.mm v0, v9, v8
; RV64ZVFHMIN-NEXT: ret
%v = call <16 x i1> @llvm.vp.merge.v16i1(<16 x i1> %m, <16 x i1> %va, <16 x i1> %vb, i32 %evl)
ret <16 x i1> %v
}

define <32 x i1> @vpmerge_vv_v32i1(<32 x i1> %va, <32 x i1> %vb, <32 x i1> %m, i32 zeroext %evl) {
; RV32-LABEL: vpmerge_vv_v32i1:
; RV32: # %bb.0:
; RV32-NEXT: li a1, 32
; RV32-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; RV32-NEXT: vid.v v16
; RV32-NEXT: vmsltu.vx v10, v16, a0
; RV32-NEXT: vmand.mm v9, v9, v10
; RV32-NEXT: vmandn.mm v8, v8, v9
; RV32-NEXT: vmand.mm v9, v0, v9
; RV32-NEXT: vmor.mm v0, v9, v8
; RV32-NEXT: ret
;
; RV64-LABEL: vpmerge_vv_v32i1:
; RV64: # %bb.0:
; RV64-NEXT: vsetvli a1, zero, e8, m2, ta, ma
; RV64-NEXT: vmv.v.i v10, 0
; RV64-NEXT: vsetvli zero, a0, e8, m2, ta, ma
; RV64-NEXT: vmerge.vim v12, v10, 1, v0
; RV64-NEXT: vmv1r.v v0, v8
; RV64-NEXT: vsetvli a1, zero, e8, m2, ta, ma
; RV64-NEXT: vmerge.vim v10, v10, 1, v0
; RV64-NEXT: vmv1r.v v0, v9
; RV64-NEXT: vsetvli zero, a0, e8, m2, tu, ma
; RV64-NEXT: vmerge.vvm v10, v10, v12, v0
; RV64-NEXT: vsetvli a0, zero, e8, m2, ta, ma
; RV64-NEXT: vmsne.vi v0, v10, 0
; RV64-NEXT: ret
;
; RV32ZVFHMIN-LABEL: vpmerge_vv_v32i1:
; RV32ZVFHMIN: # %bb.0:
; RV32ZVFHMIN-NEXT: li a1, 32
; RV32ZVFHMIN-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; RV32ZVFHMIN-NEXT: vid.v v16
; RV32ZVFHMIN-NEXT: vmsltu.vx v10, v16, a0
; RV32ZVFHMIN-NEXT: vmand.mm v9, v9, v10
; RV32ZVFHMIN-NEXT: vmandn.mm v8, v8, v9
; RV32ZVFHMIN-NEXT: vmand.mm v9, v0, v9
; RV32ZVFHMIN-NEXT: vmor.mm v0, v9, v8
; RV32ZVFHMIN-NEXT: ret
;
; RV64ZVFHMIN-LABEL: vpmerge_vv_v32i1:
; RV64ZVFHMIN: # %bb.0:
; RV64ZVFHMIN-NEXT: vsetvli a1, zero, e8, m2, ta, ma
; RV64ZVFHMIN-NEXT: vmv.v.i v10, 0
; RV64ZVFHMIN-NEXT: vsetvli zero, a0, e8, m2, ta, ma
; RV64ZVFHMIN-NEXT: vmerge.vim v12, v10, 1, v0
; RV64ZVFHMIN-NEXT: vmv1r.v v0, v8
; RV64ZVFHMIN-NEXT: vsetvli a1, zero, e8, m2, ta, ma
; RV64ZVFHMIN-NEXT: vmerge.vim v10, v10, 1, v0
; RV64ZVFHMIN-NEXT: vmv1r.v v0, v9
; RV64ZVFHMIN-NEXT: vsetvli zero, a0, e8, m2, tu, ma
; RV64ZVFHMIN-NEXT: vmerge.vvm v10, v10, v12, v0
; RV64ZVFHMIN-NEXT: vsetvli a0, zero, e8, m2, ta, ma
; RV64ZVFHMIN-NEXT: vmsne.vi v0, v10, 0
; RV64ZVFHMIN-NEXT: ret
%v = call <32 x i1> @llvm.vp.merge.v32i1(<32 x i1> %m, <32 x i1> %va, <32 x i1> %vb, i32 %evl)
ret <32 x i1> %v
}

define <64 x i1> @vpmerge_vv_v64i1(<64 x i1> %va, <64 x i1> %vb, <64 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpmerge_vv_v64i1:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e8, m4, ta, ma
; CHECK-NEXT: vmv.v.i v12, 0
; CHECK-NEXT: vsetvli zero, a0, e8, m4, ta, ma
; CHECK-NEXT: vmerge.vim v16, v12, 1, v0
; CHECK-NEXT: vmv1r.v v0, v8
; CHECK-NEXT: vsetvli a1, zero, e8, m4, ta, ma
; CHECK-NEXT: vmerge.vim v12, v12, 1, v0
; CHECK-NEXT: vmv1r.v v0, v9
; CHECK-NEXT: vsetvli zero, a0, e8, m4, tu, ma
; CHECK-NEXT: vmerge.vvm v12, v12, v16, v0
; CHECK-NEXT: vsetvli a0, zero, e8, m4, ta, ma
; CHECK-NEXT: vmsne.vi v0, v12, 0
; CHECK-NEXT: ret
%v = call <64 x i1> @llvm.vp.merge.v64i1(<64 x i1> %m, <64 x i1> %va, <64 x i1> %vb, i32 %evl)
ret <64 x i1> %v
}

declare <2 x i8> @llvm.vp.merge.v2i8(<2 x i1>, <2 x i8>, <2 x i8>, i32)

define <2 x i8> @vpmerge_vv_v2i8(<2 x i8> %va, <2 x i8> %vb, <2 x i1> %m, i32 zeroext %evl) {
@@ -1188,10 +1364,10 @@ define <32 x double> @vpmerge_vv_v32f64(<32 x double> %va, <32 x double> %vb, <3
; CHECK-NEXT: vle64.v v8, (a0)
; CHECK-NEXT: li a1, 16
; CHECK-NEXT: mv a0, a2
; CHECK-NEXT: bltu a2, a1, .LBB79_2
; CHECK-NEXT: bltu a2, a1, .LBB83_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: li a0, 16
; CHECK-NEXT: .LBB79_2:
; CHECK-NEXT: .LBB83_2:
; CHECK-NEXT: vsetvli zero, a0, e64, m8, tu, ma
; CHECK-NEXT: vmerge.vvm v8, v8, v16, v0
; CHECK-NEXT: addi a0, a2, -16
@@ -1221,10 +1397,10 @@ define <32 x double> @vpmerge_vf_v32f64(double %a, <32 x double> %vb, <32 x i1>
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 16
; CHECK-NEXT: mv a1, a0
; CHECK-NEXT: bltu a0, a2, .LBB80_2
; CHECK-NEXT: bltu a0, a2, .LBB84_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: li a1, 16
; CHECK-NEXT: .LBB80_2:
; CHECK-NEXT: .LBB84_2:
; CHECK-NEXT: vsetvli zero, a1, e64, m8, tu, ma
; CHECK-NEXT: vfmerge.vfm v8, v8, fa0, v0
; CHECK-NEXT: addi a1, a0, -16