diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -14,6 +14,7 @@
 #define LLVM_ANALYSIS_VECTORUTILS_H
 
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Support/CheckedArithmetic.h"
@@ -544,20 +545,20 @@
 /// elements, it will be padded with undefs.
 Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);
 
 /// Given a mask vector of the form <Y x i1>, Return true if all of the
 /// elements of this predicate mask are false or undef.  That is, return true
 /// if all lanes can be assumed inactive.
 bool maskIsAllZeroOrUndef(Value *Mask);
 
 /// Given a mask vector of the form <Y x i1>, Return true if all of the
 /// elements of this predicate mask are true or undef.  That is, return true
 /// if all lanes can be assumed active.
 bool maskIsAllOneOrUndef(Value *Mask);
 
-/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
-/// for each lane which may be active.
-APInt possiblyDemandedEltsInMask(Value *Mask);
-
+/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
+/// for each lane which may be active. Given a scalable vector, returns None.
+Optional<APInt> possiblyDemandedEltsInMask(Value *Mask);
+
 /// The group of interleaved loads/stores sharing the same stride and
 /// close to each other.
 ///
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -868,6 +868,8 @@
     return false;
   if (ConstMask->isNullValue() || isa<UndefValue>(ConstMask))
     return true;
+  if (isa<ScalableVectorType>(ConstMask->getType()))
+    return false;
   for (unsigned
            I = 0,
            E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
@@ -887,6 +889,8 @@
     return false;
   if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
     return true;
+  if (isa<ScalableVectorType>(ConstMask->getType()))
+    return false;
   for (unsigned
            I = 0,
            E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
@@ -901,7 +905,9 @@
 
 /// TODO: This is a lot like known bits, but for
 /// vectors.  Is there something we can common this with?
-APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
+Optional<APInt> llvm::possiblyDemandedEltsInMask(Value *Mask) {
+  if (isa<ScalableVectorType>(Mask->getType()))
+    return {};
 
   const unsigned VWidth =
       cast<FixedVectorType>(Mask->getType())->getNumElements();
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -320,11 +320,13 @@
   }
 
   // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
-  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
-  APInt UndefElts(DemandedElts.getBitWidth(), 0);
-  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0),
-                                            DemandedElts, UndefElts))
-    return replaceOperand(II, 0, V);
+  Optional<APInt> DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  if (DemandedElts) {
+    APInt UndefElts(DemandedElts->getBitWidth(), 0);
+    if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), *DemandedElts,
+                                              UndefElts))
+      return replaceOperand(II, 0, V);
+  }
 
   return nullptr;
 }
@@ -356,14 +358,16 @@
     return eraseInstFromFunction(II);
 
   // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
-  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
-  APInt UndefElts(DemandedElts.getBitWidth(), 0);
-  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0),
-                                            DemandedElts, UndefElts))
-    return replaceOperand(II, 0, V);
-  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1),
-                                            DemandedElts, UndefElts))
-    return replaceOperand(II, 1, V);
+  Optional<APInt> DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  if (DemandedElts) {
+    APInt UndefElts(DemandedElts->getBitWidth(), 0);
+    if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), *DemandedElts,
+                                              UndefElts))
+      return replaceOperand(II, 0, V);
+    if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1), *DemandedElts,
+                                              UndefElts))
+      return replaceOperand(II, 1, V);
+  }
 
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/AArch64/VectorUtils_heuristics.ll b/llvm/test/Transforms/InstCombine/AArch64/VectorUtils_heuristics.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/VectorUtils_heuristics.ll
@@ -0,0 +1,21 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-unknown-linux-gnu"
+
+; This test checks that instcombine does not crash while invoking
+; maskIsAllOneOrUndef, maskIsAllZeroOrUndef, or possiblyDemandedEltsInMask.
+
+; CHECK-LABEL: novel_algorithm
+; CHECK: unreachable
+define void @novel_algorithm() {
+entry:
+  %a = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0nxv16i8(<vscale x 16 x i8>* undef, i32 1, <vscale x 16 x i1> shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> undef, i1 true, i32 0), <vscale x 16 x i1> undef, <vscale x 16 x i32> zeroinitializer), <vscale x 16 x i8> undef)
+  %b = add <vscale x 16 x i8> undef, %a
+  call void @llvm.masked.store.nxv16i8.p0nxv16i8(<vscale x 16 x i8> %b, <vscale x 16 x i8>* undef, i32 1, <vscale x 16 x i1> shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> undef, i1 true, i32 0), <vscale x 16 x i1> undef, <vscale x 16 x i32> zeroinitializer))
+  unreachable
+}
+
+declare <vscale x 16 x i8> @llvm.masked.load.nxv16i8.p0nxv16i8(<vscale x 16 x i8>*, i32 immarg, <vscale x 16 x i1>, <vscale x 16 x i8>)
+
+declare void @llvm.masked.store.nxv16i8.p0nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>*, i32 immarg, <vscale x 16 x i1>)
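
Note on the new contract: possiblyDemandedEltsInMask now returns None when handed a scalable-vector mask, so callers must test the Optional before touching the APInt, exactly as the InstCombineCalls.cpp hunks above do. Returning `{}` in the VectorUtils.cpp hunk value-initializes the Optional<APInt>, i.e. it is equivalent to returning None. A minimal caller sketch follows (a hypothetical helper for illustration only; countPossiblyActiveLanes is an invented name and is not part of this patch):

  #include "llvm/ADT/APInt.h"
  #include "llvm/ADT/Optional.h"
  #include "llvm/Analysis/VectorUtils.h"
  #include "llvm/IR/Value.h"

  using namespace llvm;

  // Returns how many lanes of the given i1 mask may be active, or -1 when
  // the mask is a scalable vector and possiblyDemandedEltsInMask gives None.
  static int countPossiblyActiveLanes(Value *Mask) {
    Optional<APInt> DemandedElts = possiblyDemandedEltsInMask(Mask);
    if (!DemandedElts)
      return -1; // Scalable vector: no compile-time lane count exists.
    return static_cast<int>(DemandedElts->countPopulation());
  }

Callers that previously consumed the APInt unconditionally (as the removed lines in InstCombineCalls.cpp did) now simply skip the per-lane simplification when None comes back, which is the conservatively correct behavior for scalable vectors.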