diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -242,6 +242,9 @@ Type type, Value source, Value pad, ArrayRef low, ArrayRef high, Location loc, OpBuilder & builder); + // Return the pad value if it is a constant. Return null value otherwise. + Value getConstantPaddingValue(); + // Return a vector of all the static or dynamic values (low/high padding) of // the op. inline SmallVector getMixedPadImpl(ArrayAttr staticAttrs, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1141,6 +1141,30 @@ results.add(context); } +/// Return the padding value of the PadTensorOp if it is constant. In this context, +/// "constant" means an actual constant or "defined outside of the block". +/// +/// Values are considered constant in three cases: +/// - A ConstantLike value. +/// - A basic block argument from a different block. +/// - A value defined outside of the block. +/// +/// If the padding value is not constant, an empty Value is returned. +Value PadTensorOp::getConstantPaddingValue() { + auto yieldOp = dyn_cast(getRegion().front().getTerminator()); + if (!yieldOp || yieldOp.values().size() != 1) + return {}; + Value padValue = yieldOp.values().front(); + // Check if yield value is a constant. + if (matchPattern(padValue, m_Constant())) + return padValue; + // Check if yield value is defined inside the PadTensorOp block. + if (padValue.getParentBlock() == &getRegion().front()) + return {}; + // Else: Yield value defined outside of the PadTensorOp block. 
+ return padValue; +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -650,31 +650,6 @@ // Misc. vectorization patterns. //----------------------------------------------------------------------------// -/// Given a block, return the Value that the block yields if that Value is -/// constant. In this context, "constant" means "defined outside of the block". -/// Should not be called on blocks that yield more than one value. -/// -/// Values are considered constant in two cases: -/// - A basic block argument from a different block. -/// - A value defined outside of the block. -/// -/// If the yielded value is not constant, an empty Value is returned. -static Value getConstantYieldValueFromBlock(Block &block) { - auto yieldOp = cast(block.getTerminator()); - assert(yieldOp.getNumOperands() == 1 && "expected single operand yield"); - Value result = yieldOp.values().front(); - Operation *definingOp = result.getDefiningOp(); - - // Check if yield value is defined inside the block. - if (definingOp && definingOp->getBlock() == &block) - return Value(); - // Check if the yield value is a BB arg of the block. - if (!definingOp && result.cast().getOwner() == &block) - return Value(); - - return result; -} - /// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and /// TransferWriteOp. For now, this only applies when all low and high paddings /// are determined to be zero. @@ -693,7 +668,7 @@ // High padding must be static 0. if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure(); // Pad value must be a constant. 
- auto padValue = getConstantYieldValueFromBlock(padOp.region().front()); + auto padValue = padOp.getConstantPaddingValue(); if (!padValue) return failure(); // Bail on non-static shapes.