diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
@@ -100,8 +100,8 @@
   unsigned getCacheLineSize() const override;
   unsigned getPrefetchDistance() const override;
   unsigned getMaxInterleaveFactor(unsigned VF);
-  InstructionCost vectorCostAdjustment(InstructionCost Cost, unsigned Opcode,
-                                       Type *Ty1, Type *Ty2);
+  InstructionCost vectorCostAdjustmentFactor(unsigned Opcode, Type *Ty1,
+                                             Type *Ty2);
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueKind Opd1Info = TTI::OK_AnyValue,
diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
@@ -318,9 +318,20 @@
   return PPCTTIImpl::getIntImmCost(Imm, Ty, CostKind);
 }
 
+// Check if the current Type is an MMA vector type. Valid MMA types are
+// v256i1 and v512i1.
+static bool isMMAType(Type *Ty) {
+  return Ty->isVectorTy() && (Ty->getScalarSizeInBits() == 1) &&
+         (Ty->getPrimitiveSizeInBits() > 128);
+}
+
 InstructionCost PPCTTIImpl::getUserCost(const User *U,
                                         ArrayRef<const Value *> Operands,
                                         TTI::TargetCostKind CostKind) {
+  // Set the max cost if an MMA type is present (v256i1, v512i1).
+  if (isMMAType(U->getType()))
+    return InstructionCost::getMax();
+
   // We already implement getCastInstrCost and getMemoryOpCost where we perform
   // the vector adjustment there.
   if (isa<CastInst>(U) || isa<LoadInst>(U) || isa<StoreInst>(U))
@@ -942,32 +953,39 @@
   return 2;
 }
 
-// Adjust the cost of vector instructions on targets which there is overlap
-// between the vector and scalar units, thereby reducing the overall throughput
-// of vector code wrt. scalar code.
-InstructionCost PPCTTIImpl::vectorCostAdjustment(InstructionCost Cost,
-                                                 unsigned Opcode, Type *Ty1,
-                                                 Type *Ty2) {
+// Returns a cost adjustment factor for vector instructions on targets on
+// which there is overlap between the vector and scalar units, thereby
+// reducing the overall throughput of vector code wrt. scalar code.
+// An invalid instruction cost is returned if the type is an MMA vector type.
+InstructionCost PPCTTIImpl::vectorCostAdjustmentFactor(unsigned Opcode,
+                                                       Type *Ty1, Type *Ty2) {
+  // If the vector type is of an MMA type (v256i1, v512i1), an invalid
+  // instruction cost is returned. This signals to the other cost-computing
+  // functions to return the maximum instruction cost in order to prevent any
+  // opportunities for the optimizer to produce MMA types within the IR.
+  if (isMMAType(Ty1))
+    return InstructionCost::getInvalid();
+
   if (!ST->vectorsUseTwoUnits() || !Ty1->isVectorTy())
-    return Cost;
+    return InstructionCost(1);
 
   std::pair<InstructionCost, MVT> LT1 = TLI->getTypeLegalizationCost(DL, Ty1);
   // If type legalization involves splitting the vector, we don't want to
   // double the cost at every step - only the last step.
   if (LT1.first != 1 || !LT1.second.isVector())
-    return Cost;
+    return InstructionCost(1);
 
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   if (TLI->isOperationExpand(ISD, LT1.second))
-    return Cost;
+    return InstructionCost(1);
 
   if (Ty2) {
     std::pair<InstructionCost, MVT> LT2 = TLI->getTypeLegalizationCost(DL, Ty2);
     if (LT2.first != 1 || !LT2.second.isVector())
-      return Cost;
+      return InstructionCost(1);
   }
 
-  return Cost * 2;
+  return InstructionCost(2);
 }
 
 InstructionCost PPCTTIImpl::getArithmeticInstrCost(
@@ -977,6 +995,11 @@
     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
     const Instruction *CxtI) {
   assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode");
+
+  InstructionCost CostFactor = vectorCostAdjustmentFactor(Opcode, Ty, nullptr);
+  if (!CostFactor.isValid())
+    return InstructionCost::getMax();
+
   // TODO: Handle more cost kinds.
   if (CostKind != TTI::TCK_RecipThroughput)
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
@@ -986,12 +1009,18 @@
   // Fallback to the default implementation.
   InstructionCost Cost = BaseT::getArithmeticInstrCost(
       Opcode, Ty, CostKind, Op1Info, Op2Info, Opd1PropInfo, Opd2PropInfo);
-  return vectorCostAdjustment(Cost, Opcode, Ty, nullptr);
+  return Cost * CostFactor;
 }
 
 InstructionCost PPCTTIImpl::getShuffleCost(TTI::ShuffleKind Kind, Type *Tp,
                                            ArrayRef<int> Mask, int Index,
                                            Type *SubTp) {
+
+  InstructionCost CostFactor =
+      vectorCostAdjustmentFactor(Instruction::ShuffleVector, Tp, nullptr);
+  if (!CostFactor.isValid())
+    return InstructionCost::getMax();
+
   // Legalize the type.
   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
 
@@ -1000,8 +1029,7 @@
   // instruction). We need one such shuffle instruction for each actual
   // register (this is not true for arbitrary shuffles, but is true for the
   // structured types of shuffles covered by TTI::ShuffleKind).
-  return vectorCostAdjustment(LT.first, Instruction::ShuffleVector, Tp,
-                              nullptr);
+  return LT.first * CostFactor;
 }
 
 InstructionCost PPCTTIImpl::getCFInstrCost(unsigned Opcode,
@@ -1020,9 +1048,13 @@
                                            const Instruction *I) {
   assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode");
 
+  InstructionCost CostFactor = vectorCostAdjustmentFactor(Opcode, Dst, Src);
+  if (!CostFactor.isValid())
+    return InstructionCost::getMax();
+
   InstructionCost Cost =
       BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
-  Cost = vectorCostAdjustment(Cost, Opcode, Dst, Src);
+  Cost *= CostFactor;
   // TODO: Allow non-throughput costs that aren't binary.
   if (CostKind != TTI::TCK_RecipThroughput)
     return Cost == 0 ? 0 : 1;
@@ -1034,12 +1066,17 @@
                                                CmpInst::Predicate VecPred,
                                                TTI::TargetCostKind CostKind,
                                                const Instruction *I) {
+  InstructionCost CostFactor =
+      vectorCostAdjustmentFactor(Opcode, ValTy, nullptr);
+  if (!CostFactor.isValid())
+    return InstructionCost::getMax();
+
   InstructionCost Cost =
       BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
   // TODO: Handle other cost kinds.
   if (CostKind != TTI::TCK_RecipThroughput)
     return Cost;
-  return vectorCostAdjustment(Cost, Opcode, ValTy, nullptr);
+  return Cost * CostFactor;
 }
 
 InstructionCost PPCTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
@@ -1049,8 +1086,12 @@
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
+  InstructionCost CostFactor = vectorCostAdjustmentFactor(Opcode, Val, nullptr);
+  if (!CostFactor.isValid())
+    return InstructionCost::getMax();
+
   InstructionCost Cost = BaseT::getVectorInstrCost(Opcode, Val, Index);
-  Cost = vectorCostAdjustment(Cost, Opcode, Val, nullptr);
+  Cost *= CostFactor;
 
   if (ST->hasVSX() && Val->getScalarType()->isDoubleTy()) {
     // Double-precision scalars are already located in index #0 (or #1 if LE).
@@ -1065,7 +1106,7 @@
     if (ISD == ISD::INSERT_VECTOR_ELT)
       // A move-to VSR and a permute/insert.  Assume vector operation cost
       // for both (cost will be 2x on P9).
-      return vectorCostAdjustment(2, Opcode, Val, nullptr);
+      return 2 * CostFactor;
 
     // It's an extract.  Maybe we can do a cheap move-from VSR.
     unsigned EltSize = Val->getScalarSizeInBits();
@@ -1082,7 +1123,7 @@
       // We need a vector extract (or mfvsrld).  Assume vector operation cost.
       // The cost of the load constant for a vector extract is disregarded
       // (invariant, easily schedulable).
-      return vectorCostAdjustment(1, Opcode, Val, nullptr);
+      return CostFactor;
 
     } else if (ST->hasDirectMove())
       // Assume permute has standard cost.
@@ -1114,6 +1155,11 @@
                                            unsigned AddressSpace,
                                            TTI::TargetCostKind CostKind,
                                            const Instruction *I) {
+
+  InstructionCost CostFactor = vectorCostAdjustmentFactor(Opcode, Src, nullptr);
+  if (!CostFactor.isValid())
+    return InstructionCost::getMax();
+
   if (TLI->getValueType(DL, Src, true) == MVT::Other)
     return BaseT::getMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
                                   CostKind);
@@ -1128,7 +1174,7 @@
   if (CostKind != TTI::TCK_RecipThroughput)
     return Cost;
 
-  Cost = vectorCostAdjustment(Cost, Opcode, Src, nullptr);
+  Cost *= CostFactor;
 
   bool IsAltivecType = ST->hasAltivec() &&
                        (LT.second == MVT::v16i8 || LT.second == MVT::v8i16 ||
@@ -1194,6 +1240,11 @@
     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
     bool UseMaskForCond, bool UseMaskForGaps) {
+  InstructionCost CostFactor =
+      vectorCostAdjustmentFactor(Opcode, VecTy, nullptr);
+  if (!CostFactor.isValid())
+    return InstructionCost::getMax();
+
   if (UseMaskForCond || UseMaskForGaps)
     return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
                                              Alignment, AddressSpace, CostKind,
diff --git a/llvm/test/Transforms/PGOProfile/ppc-prevent-mma-types.ll b/llvm/test/Transforms/PGOProfile/ppc-prevent-mma-types.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/PGOProfile/ppc-prevent-mma-types.ll
@@ -0,0 +1,204 @@
+; RUN: opt --vec-extabi=true -passes='default<O3>' -mcpu=pwr10 \
+; RUN:   -pgo-kind=pgo-instr-gen-pipeline -mtriple=powerpc-ibm-aix -S < %s | \
+; RUN:   FileCheck %s
+; RUN: opt -passes='default<O3>' -mcpu=pwr10 -pgo-kind=pgo-instr-gen-pipeline \
+; RUN:   -mtriple=powerpc64le-unknown-linux-gnu -S < %s | FileCheck %s
+
+; When running this test case under opt + PGO, the SLPVectorizer previously
+; had an opportunity to produce wide vector types (such as <256 x i1>) within
+; the IR, as it deemed these wide vector types to be cheap enough to produce.
+; This test ensures that the optimizer no longer generates wide vector types
+; within the IR.
+
+%0 = type <{ double }>
+%1 = type <{ [1 x %0]*, i8, i8, i8, i8, i32, i32, i32, [1 x i32], [1 x i32], [1 x i32], [24 x i8] }>
+
+declare i8* @__malloc()
+
+; CHECK-NOT: <256 x i1>
+; CHECK-NOT: <512 x i1>
+define dso_local void @test([0 x %0]* %arg, i32* %arg1, i32* %arg2, i32* %arg3, i32* %arg4) {
+  %i = alloca [0 x %0]*, align 4
+  store [0 x %0]* %arg, [0 x %0]** %i, align 4
+  %i7 = alloca i32*, align 4
+  store i32* %arg1, i32** %i7, align 4
+  %i9 = alloca i32*, align 4
+  store i32* %arg2, i32** %i9, align 4
+  %i10 = alloca i32*, align 4
+  store i32* %arg3, i32** %i10, align 4
+  %i11 = alloca i32*, align 4
+  store i32* %arg4, i32** %i11, align 4
+  %i14 = alloca %1, align 4
+  %i15 = alloca i32, align 4
+  %i16 = alloca i32, align 4
+  %i17 = alloca i32, align 4
+  %i18 = alloca i32, align 4
+  %i20 = alloca i32, align 4
+  %i21 = alloca i32, align 4
+  %i22 = alloca i32, align 4
+  %i23 = alloca i32, align 4
+  %i25 = alloca double, align 8
+  %i26 = load i32*, i32** %i9, align 4
+  %i27 = load i32, i32* %i26, align 4
+  %i28 = select i1 false, i32 0, i32 %i27
+  store i32 %i28, i32* %i15, align 4
+  %i29 = load i32*, i32** %i7, align 4
+  %i30 = load i32, i32* %i29, align 4
+  %i31 = select i1 false, i32 0, i32 %i30
+  store i32 %i31, i32* %i16, align 4
+  %i32 = load i32, i32* %i15, align 4
+  %i33 = mul i32 8, %i32
+  store i32 %i33, i32* %i17, align 4
+  %i34 = load i32, i32* %i17, align 4
+  %i35 = load i32, i32* %i16, align 4
+  %i36 = mul i32 %i34, %i35
+  store i32 %i36, i32* %i18, align 4
+  %i37 = load i32*, i32** %i9, align 4
+  %i38 = load i32, i32* %i37, align 4
+  %i39 = select i1 false, i32 0, i32 %i38
+  store i32 %i39, i32* %i22, align 4
+  %i40 = load i32*, i32** %i10, align 4
+  %i41 = load i32, i32* %i40, align 4
+  %i42 = select i1 false, i32 0, i32 %i41
+  store i32 %i42, i32* %i23, align 4
+  %i43 = getelementptr inbounds %1, %1* %i14, i32 0, i32 10
+  %i44 = bitcast [1 x i32]* %i43 to i8*
+  %i45 = getelementptr i8, i8* %i44, i32 -12
+  %i46 = getelementptr inbounds i8, i8* %i45, i32 12
+  %i47 = bitcast i8* %i46 to i32*
+  %i48 = load i32, i32* %i23, align 4
+  %i49 = select i1 false, i32 0, i32 %i48
+  %i50 = load i32, i32* %i22, align 4
+  %i51 = select i1 false, i32 0, i32 %i50
+  %i52 = mul i32 %i51, 8
+  %i53 = mul i32 %i49, %i52
+  store i32 %i53, i32* %i47, align 4
+  %i54 = getelementptr inbounds %1, %1* %i14, i32 0, i32 10
+  %i55 = bitcast [1 x i32]* %i54 to i8*
+  %i56 = getelementptr i8, i8* %i55, i32 -12
+  %i57 = getelementptr inbounds i8, i8* %i56, i32 36
+  %i58 = bitcast i8* %i57 to i32*
+  store i32 8, i32* %i58, align 4
+  %i60 = getelementptr inbounds %1, %1* %i14, i32 0, i32 0
+  %i61 = call i8* @__malloc()
+  %i62 = bitcast [1 x %0]** %i60 to i8**
+  store i8* %i61, i8** %i62, align 4
+  br label %bb63
+
+bb63:                                             ; preds = %bb66, %bb
+  %i64 = load i32*, i32** %i11, align 4
+  %i65 = load i32, i32* %i64, align 4
+  br label %bb66
+
+bb66:                                             ; preds = %bb165, %bb63
+  %i67 = load i32, i32* %i21, align 4
+  %i68 = icmp sle i32 %i67, %i65
+  br i1 %i68, label %bb69, label %bb63
+
+bb69:                                             ; preds = %bb66
+  store i32 1, i32* %i20, align 4
+  br label %bb70
+
+bb70:                                             ; preds = %bb163, %bb69
+  %i71 = load i32, i32* %i20, align 4
+  %i72 = icmp sle i32 %i71, 11
+  br i1 %i72, label %bb73, label %bb165
+
+bb73:                                             ; preds = %bb70
+  %i74 = load i32, i32* %i21, align 4
+  %i76 = mul i32 %i74, 8
+  %i77 = getelementptr inbounds i8, i8* null, i32 %i76
+  %i78 = bitcast i8* %i77 to double*
+  %i79 = load double, double* %i78, align 8
+  %i80 = fcmp fast olt double %i79, 0.000000e+00
+  %i81 = zext i1 %i80 to i32
+  %i82 = trunc i32 %i81 to i1
+  br i1 %i82, label %bb83, label %bb102
+
+bb83:                                             ; preds = %bb73
+  %i84 = getelementptr inbounds %1, %1* %i14, i32 0, i32 0
+  %i85 = load [1 x %0]*, [1 x %0]** %i84, align 4
+  %i86 = bitcast [1 x %0]* %i85 to i8*
+  %i87 = getelementptr i8, i8* %i86, i32 0
+  %i88 = load i32, i32* %i20, align 4
+  %i89 = getelementptr inbounds %1, %1* %i14, i32 0, i32 10
+  %i90 = getelementptr inbounds [1 x i32], [1 x i32]* %i89, i32 0, i32 0
+  %i91 = load i32, i32* %i90, align 4
+  %i92 = mul i32 %i88, %i91
+  %i93 = getelementptr inbounds i8, i8* %i87, i32 %i92
+  %i94 = getelementptr inbounds i8, i8* %i93, i32 0
+  %i95 = load i32, i32* %i21, align 4
+  %i96 = getelementptr inbounds %1, %1* %i14, i32 0, i32 10
+  %i97 = getelementptr inbounds [1 x i32], [1 x i32]* %i96, i32 0, i32 6
+  %i98 = load i32, i32* %i97, align 4
+  %i99 = mul i32 %i95, %i98
+  %i100 = getelementptr inbounds i8, i8* %i94, i32 %i99
+  %i101 = bitcast i8* %i100 to double*
+  store double 0.000000e+00, double* %i101, align 8
+  br label %bb163
+
+bb102:                                            ; preds = %bb73
+  %i103 = getelementptr i8, i8* null, i32 -8
+  %i104 = getelementptr inbounds i8, i8* %i103, i32 undef
+  %i105 = bitcast i8* %i104 to double*
+  %i106 = load double, double* %i105, align 8
+  %i107 = load [0 x %0]*, [0 x %0]** %i, align 4
+  %i108 = bitcast [0 x %0]* %i107 to i8*
+  %i109 = getelementptr i8, i8* %i108, i32 -8
+  %i110 = getelementptr inbounds i8, i8* %i109, i32 undef
+  %i111 = bitcast i8* %i110 to double*
+  %i112 = load double, double* %i111, align 8
+  %i113 = fmul fast double %i106, %i112
+  %i114 = fcmp fast ogt double 0.000000e+00, %i113
+  %i115 = zext i1 %i114 to i32
+  %i116 = trunc i32 %i115 to i1
+  br i1 %i116, label %bb117, label %bb136
+
+bb117:                                            ; preds = %bb102
+  %i118 = getelementptr inbounds %1, %1* %i14, i32 0, i32 0
+  %i119 = load [1 x %0]*, [1 x %0]** %i118, align 4
+  %i120 = bitcast [1 x %0]* %i119 to i8*
+  %i121 = getelementptr i8, i8* %i120, i32 0
+  %i122 = load i32, i32* %i20, align 4
+  %i123 = getelementptr inbounds %1, %1* %i14, i32 0, i32 10
+  %i124 = getelementptr inbounds [1 x i32], [1 x i32]* %i123, i32 0, i32 0
+  %i125 = load i32, i32* %i124, align 4
+  %i126 = mul i32 %i122, %i125
+  %i127 = getelementptr inbounds i8, i8* %i121, i32 %i126
+  %i128 = getelementptr inbounds i8, i8* %i127, i32 0
+  %i129 = load i32, i32* %i21, align 4
+  %i130 = getelementptr inbounds %1, %1* %i14, i32 0, i32 10
+  %i131 = getelementptr inbounds [1 x i32], [1 x i32]* %i130, i32 0, i32 6
+  %i132 = load i32, i32* %i131, align 4
+  %i133 = mul i32 %i129, %i132
+  %i134 = getelementptr inbounds i8, i8* %i128, i32 %i133
+  %i135 = bitcast i8* %i134 to double*
+  store double 0.000000e+00, double* %i135, align 8
+  br label %bb163
+
+bb136:                                            ; preds = %bb102
+  %i137 = load double, double* null, align 8
+  %i138 = load double, double* null, align 8
+  %i139 = fmul fast double %i137, %i138
+  %i140 = fsub fast double 0.000000e+00, %i139
+  store double %i140, double* %i25, align 8
+  %i141 = load i32, i32* %i21, align 4
+  %i143 = getelementptr inbounds [1 x i32], [1 x i32]* null, i32 0, i32 6
+  %i144 = load i32, i32* %i143, align 4
+  %i145 = mul i32 %i141, %i144
+  %i146 = getelementptr inbounds i8, i8* null, i32 %i145
+  %i147 = bitcast i8* %i146 to double*
+  %i148 = load i32, i32* %i20, align 4
+  %i149 = load i32, i32* %i18, align 4
+  %i151 = mul i32 %i148, %i149
+  %i152 = getelementptr i8, i8* null, i32 %i151
+  %i153 = getelementptr i8, i8* %i152, i32 0
+  %i154 = getelementptr inbounds i8, i8* %i153, i32 0
+  %i155 = bitcast i8* %i154 to double*
+  %i156 = load double, double* %i155, align 8
+  %i157 = load double, double* %i25, align 8
+  %i158 = fmul fast double %i156, %i157
+  %i159 = fadd fast double 0.000000e+00, %i158
+  %i160 = load double, double* %i25, align 8
+  %i161 = fadd fast double 0.000000e+00, %i160
+  %i162 = fdiv fast double %i159, %i161
+  store double %i162, double* %i147, align 8
+  br label %bb163
+
+bb163:                                            ; preds = %bb136, %bb117, %bb83
+  %i164 = add nsw i32 %i71, 1
+  store i32 %i164, i32* %i20, align 4
+  br label %bb70
+
+bb165:                                            ; preds = %bb70
+  %i166 = add nsw i32 %i67, 1
+  store i32 %i166, i32* %i21, align 4
+  br label %bb66
+}
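
---
Reviewer note (not part of the patch): the core idiom above is that
vectorCostAdjustmentFactor() returns an invalid InstructionCost for MMA types,
and every get*Cost hook converts that invalid factor into the maximum cost.
Below is a minimal standalone C++ sketch of that idiom. The `Cost` class is a
simplified stand-in for llvm::InstructionCost (which additionally models
saturating arithmetic), and `costFactor`/`instrCost` are illustrative names,
not LLVM APIs.

// cost_sketch.cpp -- standalone sketch of the invalid-cost-as-sentinel idiom.
#include <cassert>
#include <climits>
#include <optional>

class Cost {
  std::optional<int> Value; // empty == invalid cost
  Cost() = default;         // invalid state; reachable only via getInvalid()

public:
  Cost(int V) : Value(V) {}
  static Cost getInvalid() { return Cost(); }
  static Cost getMax() { return Cost(INT_MAX); }
  bool isValid() const { return Value.has_value(); }
  int get() const { return *Value; }
  Cost operator*(const Cost &RHS) const {
    if (!isValid() || !RHS.isValid())
      return getInvalid();
    return Cost(*Value * *RHS.Value);
  }
};

// Mirrors vectorCostAdjustmentFactor(): MMA-like types yield an invalid
// factor; targets whose vectors use two units double the cost; otherwise 1x.
static Cost costFactor(bool IsMMAType, bool VectorsUseTwoUnits) {
  if (IsMMAType)
    return Cost::getInvalid();
  return VectorsUseTwoUnits ? Cost(2) : Cost(1);
}

// Mirrors the pattern added to each get*Cost hook: an invalid factor is
// converted to the maximum cost, making MMA-sized vector types look
// prohibitively expensive to the vectorizers.
static Cost instrCost(Cost BaseCost, bool IsMMAType, bool VectorsUseTwoUnits) {
  Cost Factor = costFactor(IsMMAType, VectorsUseTwoUnits);
  if (!Factor.isValid())
    return Cost::getMax();
  return BaseCost * Factor;
}

int main() {
  // Ordinary vector op on a two-unit target: the base cost is doubled.
  assert(instrCost(Cost(3), false, true).get() == 6);
  // MMA-sized type: the hook reports the maximum cost instead.
  assert(instrCost(Cost(3), true, true).get() == INT_MAX);
  return 0;
}

Returning a multiplicative factor, rather than adjusting an already-computed
cost in place as the old vectorCostAdjustment() did, is what allows each hook
to check the invalid state once at the top, before any base-cost computation.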