diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9797,6 +9797,8 @@
   MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
   SDValue Mask = MST->getMask();
   SDValue Chain = MST->getChain();
+  SDValue Value = MST->getValue();
+  SDValue Ptr = MST->getBasePtr();
   SDLoc DL(N);
 
   // Zap masked stores with a zero mask.
@@ -9815,8 +9817,53 @@
   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
     return SDValue(N, 0);
 
-  return SDValue();
-}
+  if (MST->isTruncatingStore() && MST->isUnindexed() &&
+      Value.getValueType().isInteger() &&
+      (!isa<ConstantSDNode>(Value) ||
+       !cast<ConstantSDNode>(Value)->isOpaque())) {
+    APInt TruncDemandedBits =
+        APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
+                             MST->getMemoryVT().getScalarSizeInBits());
+
+    // See if we can simplify the input to this truncstore with knowledge that
+    // only the low bits are being used. For example:
+    // "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8"
+    AddToWorklist(Value.getNode());
+    if (SDValue Shorter = DAG.GetDemandedBits(Value, TruncDemandedBits))
+      return DAG.getMaskedStore(Chain, SDLoc(N), Shorter, Ptr,
+                                MST->getOffset(), MST->getMask(),
+                                MST->getMemoryVT(), MST->getMemOperand(),
+                                MST->getAddressingMode(), /*IsTruncating=*/true);
+
+    // Otherwise, see if we can simplify the operation with
+    // SimplifyDemandedBits, which only works if the value has a single use.
+    if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
+      // Re-visit the store if anything changed and the store hasn't been
+      // merged with another node (N is deleted). SimplifyDemandedBits will
+      // add Value's node back to the worklist if necessary, but we also need
+      // to re-visit the Store node itself.
+      if (N->getOpcode() != ISD::DELETED_NODE)
+        AddToWorklist(N);
+      return SDValue(N, 0);
+    }
+  }
+
+
+  // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
+  // truncating store. We can do this even if this is already a truncstore.
+  if ((Value.getOpcode() == ISD::FP_ROUND ||
+       Value.getOpcode() == ISD::TRUNCATE) &&
+      Value.getNode()->hasOneUse() && MST->isUnindexed() &&
+      TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
+                               MST->getMemoryVT(), LegalOperations)) {
+    return DAG.getMaskedStore(
+        Chain, SDLoc(N), Value.getOperand(0), Ptr, MST->getOffset(),
+        MST->getMask(), MST->getMemoryVT(), MST->getMemOperand(),
+        MST->getAddressingMode(), /*IsTruncating=*/true);
+  }
+
+  return SDValue();
+}
 
 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18152,16 +18152,19 @@
 SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
     SDValue Op, SelectionDAG &DAG) const {
   auto Store = cast<MaskedStoreSDNode>(Op);
-
-  if (Store->isTruncatingStore())
-    return SDValue();
+  SDValue Mask = Store->getMask();
 
   SDLoc DL(Op);
   EVT VT = Store->getValue().getValueType();
   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
 
+  if (Store->isTruncatingStore()) {
+    Mask = DAG.getNode(
+        ISD::ZERO_EXTEND, DL,
+        VT.changeVectorElementType(ContainerVT.getVectorElementType()), Mask);
+  }
+  Mask = convertFixedMaskToScalableVector(Mask, DAG);
   auto NewValue = convertToScalableVector(DAG, ContainerVT, Store->getValue());
-  SDValue Mask = convertFixedMaskToScalableVector(Store->getMask(), DAG);
 
   return DAG.getMaskedStore(
       Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-stores.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-stores.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-stores.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-stores.ll
@@ -151,14 +151,7 @@
 ; VBITS_GE_512-NEXT: ld1d { [[Z0:z[0-9]+]].d }, p0/z, [x0]
 ; VBITS_GE_512-NEXT: ld1d { [[Z1:z[0-9]+]].d }, p0/z, [x1]
 ; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].d, p[[P0]]/z, [[Z0]].d, [[Z1]].d
-; VBITS_GE_512-DAG: uzp1 [[Z1]].s, [[Z1]].s, [[Z1]].s
-; VBITS_GE_512-DAG: uzp1 [[Z1]].h, [[Z1]].h, [[Z1]].h
-; VBITS_GE_512-DAG: uzp1 [[Z1]].b, [[Z1]].b, [[Z1]].b
-; VBITS_GE_512-DAG: cmpne p[[P2:[0-9]+]].b, p{{[0-9]+}}/z, [[Z1]].b, #0
-; VBITS_GE_512-DAG: uzp1 [[Z0]].s, [[Z0]].s, [[Z0]].s
-; VBITS_GE_512-DAG: uzp1 [[Z0]].h, [[Z0]].h, [[Z0]].h
-; VBITS_GE_512-DAG: uzp1 [[Z0]].b, [[Z0]].b, [[Z0]].b
-; VBITS_GE_512-NEXT: st1b { [[Z0]].b }, p[[P2]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: st1b { [[Z0]].d }, p[[P1]], [x{{[0-9]+}}]
 ; VBITS_GE_512-NEXT: ret
   %a = load <8 x i64>, <8 x i64>* %ap
   %b = load <8 x i64>, <8 x i64>* %bp
@@ -173,15 +166,8 @@
 ; VBITS_GE_512: ptrue p[[P0:[0-9]+]].d, vl8
 ; VBITS_GE_512-NEXT: ld1d { [[Z0:z[0-9]+]].d }, p0/z, [x0]
 ; VBITS_GE_512-NEXT: ld1d { [[Z1:z[0-9]+]].d }, p0/z, [x1]
-; VBITS_GE_512-DAG: ptrue p{{[0-9]+}}.h, vl8
-; VBITS_GE_512-DAG: cmpeq p[[P1:[0-9]+]].d, p[[P0]]/z, [[Z0]].d, [[Z1]].d
-; VBITS_GE_512-NEXT: mov [[Z1]].d, p[[P0]]/z, #-1
-; VBITS_GE_512-DAG: uzp1 [[Z1]].s, [[Z1]].s, [[Z1]].s
-; VBITS_GE_512-DAG: uzp1 [[Z1]].h, [[Z1]].h, [[Z1]].h
-; VBITS_GE_512-DAG: cmpne p[[P2:[0-9]+]].h, p{{[0-9]+}}/z, [[Z1]].h, #0
-; VBITS_GE_512-DAG: uzp1 [[Z0]].s, [[Z0]].s, [[Z0]].s
-; VBITS_GE_512-DAG: uzp1 [[Z0]].h, [[Z0]].h, [[Z0]].h
-; VBITS_GE_512-NEXT: st1h { [[Z0]].h }, p[[P2]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].d, p[[P0]]/z, [[Z0]].d, [[Z1]].d
+; VBITS_GE_512-NEXT: st1h { [[Z0]].d }, p[[P1]], [x{{[0-9]+}}]
 ; VBITS_GE_512-NEXT: ret
   %a = load <8 x i64>, <8 x i64>* %ap
   %b = load <8 x i64>, <8 x i64>* %bp
@@ -196,13 +182,8 @@
 ; VBITS_GE_512: ptrue p[[P0:[0-9]+]].d, vl8
 ; VBITS_GE_512-NEXT: ld1d { [[Z0:z[0-9]+]].d }, p0/z, [x0]
 ; VBITS_GE_512-NEXT: ld1d { [[Z1:z[0-9]+]].d }, p0/z, [x1]
-; VBITS_GE_512-DAG: ptrue p{{[0-9]+}}.s, vl8
-; VBITS_GE_512-DAG: cmpeq p[[P1:[0-9]+]].d, p[[P0]]/z, [[Z0]].d, [[Z1]].d
-; VBITS_GE_512-NEXT: mov [[Z1]].d, p[[P0]]/z, #-1
-; VBITS_GE_512-DAG: uzp1 [[Z1]].s, [[Z1]].s, [[Z1]].s
-; VBITS_GE_512-DAG: cmpne p[[P2:[0-9]+]].s, p{{[0-9]+}}/z, [[Z1]].s, #0
-; VBITS_GE_512-DAG: uzp1 [[Z0]].s, [[Z0]].s, [[Z0]].s
-; VBITS_GE_512-NEXT: st1w { [[Z0]].s }, p[[P2]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].d, p[[P0]]/z, [[Z0]].d, [[Z1]].d
+; VBITS_GE_512-NEXT: st1w { [[Z0]].d }, p[[P1]], [x{{[0-9]+}}]
 ; VBITS_GE_512-NEXT: ret
   %a = load <8 x i64>, <8 x i64>* %ap
   %b = load <8 x i64>, <8 x i64>* %bp
@@ -217,15 +198,8 @@
 ; VBITS_GE_512: ptrue p[[P0:[0-9]+]].s, vl16
 ; VBITS_GE_512-NEXT: ld1w { [[Z0:z[0-9]+]].s }, p0/z, [x0]
 ; VBITS_GE_512-NEXT: ld1w { [[Z1:z[0-9]+]].s }, p0/z, [x1]
-; VBITS_GE_512-DAG: ptrue p{{[0-9]+}}.b, vl16
-; VBITS_GE_512-DAG: cmpeq p[[P1:[0-9]+]].s, p[[P0]]/z, [[Z0]].s, [[Z1]].s
-; VBITS_GE_512-NEXT: mov [[Z1]].s, p[[P0]]/z, #-1
-; VBITS_GE_512-DAG: uzp1 [[Z1]].h, [[Z1]].h, [[Z1]].h
-; VBITS_GE_512-DAG: uzp1 [[Z1]].b, [[Z1]].b, [[Z1]].b
-; VBITS_GE_512-DAG: cmpne p[[P2:[0-9]+]].b, p{{[0-9]+}}/z, [[Z1]].b, #0
-; VBITS_GE_512-DAG: uzp1 [[Z0]].h, [[Z0]].h, [[Z0]].h
-; VBITS_GE_512-DAG: uzp1 [[Z0]].b, [[Z0]].b, [[Z0]].b
-; VBITS_GE_512-NEXT: st1b { [[Z0]].b }, p[[P2]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].s, p[[P0]]/z, [[Z0]].s, [[Z1]].s
+; VBITS_GE_512-NEXT: st1b { [[Z0]].s }, p[[P1]], [x{{[0-9]+}}]
 ; VBITS_GE_512-NEXT: ret
   %a = load <16 x i32>, <16 x i32>* %ap
   %b = load <16 x i32>, <16 x i32>* %bp
@@ -240,13 +214,8 @@
 ; VBITS_GE_512: ptrue p[[P0:[0-9]+]].s, vl16
 ; VBITS_GE_512-NEXT: ld1w { [[Z0:z[0-9]+]].s }, p0/z, [x0]
 ; VBITS_GE_512-NEXT: ld1w { [[Z1:z[0-9]+]].s }, p0/z, [x1]
-; VBITS_GE_512-DAG: ptrue p{{[0-9]+}}.h, vl16
-; VBITS_GE_512-DAG: cmpeq p[[P1:[0-9]+]].s, p[[P0]]/z, [[Z0]].s, [[Z1]].s
-; VBITS_GE_512-NEXT: mov [[Z1]].s, p[[P0]]/z, #-1
-; VBITS_GE_512-DAG: uzp1 [[Z1]].h, [[Z1]].h, [[Z1]].h
-; VBITS_GE_512-DAG: cmpne p[[P2:[0-9]+]].h, p{{[0-9]+}}/z, [[Z1]].h, #0
-; VBITS_GE_512-DAG: uzp1 [[Z0]].h, [[Z0]].h, [[Z0]].h
-; VBITS_GE_512-NEXT: st1h { [[Z0]].h }, p[[P2]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].s, p[[P0]]/z, [[Z0]].s, [[Z1]].s
+; VBITS_GE_512-NEXT: st1h { [[Z0]].s }, p[[P1]], [x{{[0-9]+}}]
 ; VBITS_GE_512-NEXT: ret
   %a = load <16 x i32>, <16 x i32>* %ap
   %b = load <16 x i32>, <16 x i32>* %bp
@@ -261,13 +230,8 @@
 ; VBITS_GE_512: ptrue p[[P0:[0-9]+]].h, vl32
 ; VBITS_GE_512-NEXT: ld1h { [[Z0:z[0-9]+]].h }, p0/z, [x0]
 ; VBITS_GE_512-NEXT: ld1h { [[Z1:z[0-9]+]].h }, p0/z, [x1]
-; VBITS_GE_512-DAG: ptrue p{{[0-9]+}}.b, vl32
-; VBITS_GE_512-DAG: cmpeq p[[P1:[0-9]+]].h, p[[P0]]/z, [[Z0]].h, [[Z1]].h
-; VBITS_GE_512-NEXT: mov [[Z1]].h, p[[P0]]/z, #-1
-; VBITS_GE_512-DAG: uzp1 [[Z1]].b, [[Z1]].b, [[Z1]].b
-; VBITS_GE_512-DAG: cmpne p[[P2:[0-9]+]].b, p{{[0-9]+}}/z, [[Z1]].b, #0
-; VBITS_GE_512-DAG: uzp1 [[Z0]].b, [[Z0]].b, [[Z0]].b
-; VBITS_GE_512-NEXT: st1b { [[Z0]].b }, p[[P2]], [x{{[0-9]+}}]
+; VBITS_GE_512-NEXT: cmpeq p[[P1:[0-9]+]].h, p[[P0]]/z, [[Z0]].h, [[Z1]].h
+; VBITS_GE_512-NEXT: st1b { [[Z0]].h }, p[[P1]], [x{{[0-9]+}}]
 ; VBITS_GE_512-NEXT: ret
   %a = load <32 x i16>, <32 x i16>* %ap
   %b = load <32 x i16>, <32 x i16>* %bp
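
For context, a reduced IR reproducer in the style of the tests above; the function name, element types, and alignment below are illustrative (assumed, not copied verbatim from the test file). It shows the pattern the patch targets: a compare-produced mask feeding a truncating masked store, which should now lower to a single predicated st1b instead of an uzp1/cmpne mask-repacking sequence.

; Reduced sketch (assumed names/alignment); needs +sve and a suitable fixed vector width.
define void @masked_trunc_store_example(<8 x i64>* %ap, <8 x i64>* %bp, <8 x i8>* %dest) #0 {
  %a = load <8 x i64>, <8 x i64>* %ap
  %b = load <8 x i64>, <8 x i64>* %bp
  %mask = icmp eq <8 x i64> %a, %b
  %val = trunc <8 x i64> %a to <8 x i8>
  call void @llvm.masked.store.v8i8.p0v8i8(<8 x i8> %val, <8 x i8>* %dest, i32 8, <8 x i1> %mask)
  ret void
}

declare void @llvm.masked.store.v8i8.p0v8i8(<8 x i8>, <8 x i8>*, i32 immarg, <8 x i1>)

attributes #0 = { "target-features"="+sve" }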