diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "llvm/ADT/StringSet.h" #include <optional> @@ -461,18 +462,10 @@ /// Returns an attribute list that excludes pre-defined attributes. template <typename OpTy> SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) { - llvm::StringSet<> elidedAttrs; - elidedAttrs.insert(op.getAttributeNames().begin(), - op.getAttributeNames().end()); + auto elidedAttrs = llvm::to_vector(op.getAttributeNames()); if (isa<linalg::LinalgOp>(op.getOperation())) - elidedAttrs.insert(LinalgDialect::kMemoizedIndexingMapsAttrName); - SmallVector<NamedAttribute> attrs; - for (auto attr : op->getAttrs()) { - if (elidedAttrs.count(attr.getName())) - continue; - attrs.push_back(attr); - } - return attrs; + elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName); + return getPrunedAttributeList(op, elidedAttrs); } } // namespace linalg diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1295,13 +1295,13 @@ let builders = [ // Build a PadOp with mixed static and dynamic entries. - OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow, - "ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high, - CArg<"bool", "false">:$nofold, + OpBuilder<(ins "Type":$resultType, "Value":$source, + "ArrayRef<int64_t>":$staticLow, "ArrayRef<int64_t>":$staticHigh, + "ValueRange":$low, "ValueRange":$high, CArg<"bool", "false">:$nofold, CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>, // Build a PadOp with all dynamic entries. 
- OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high, - CArg<"bool", "false">:$nofold, + OpBuilder<(ins "Type":$resultType, "Value":$source, "ValueRange":$low, + "ValueRange":$high, CArg<"bool", "false">:$nofold, CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>, // Build a PadOp with mixed static and dynamic entries and custom // result type. If the type passed is nullptr, it is inferred. diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -123,6 +123,11 @@ TypeRange newResultTypes, ValueRange newOperands); +// Get the list of attributes associated with the op, ignoring +// those with the provided name. +SmallVector<NamedAttribute> +getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs); + } // namespace mlir #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2518,26 +2518,27 @@ return RankedTensorType::get(inferredShape, sourceType.getElementType()); } -void PadOp::build(OpBuilder &b, OperationState &result, Value source, - ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh, - ValueRange low, ValueRange high, bool nofold, - ArrayRef<NamedAttribute> attrs) { +void PadOp::build(OpBuilder &b, OperationState &result, Type resultType, + Value source, ArrayRef<int64_t> staticLow, + ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high, + bool nofold, ArrayRef<NamedAttribute> attrs) { auto sourceType = source.getType().cast<RankedTensorType>(); - auto resultType = inferResultType(sourceType, staticLow, staticHigh); + if (!resultType) + resultType = inferResultType(sourceType, staticLow, staticHigh); build(b, result, resultType, source, low, high, b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh), nofold ? 
b.getUnitAttr() : UnitAttr()); result.addAttributes(attrs); } -void PadOp::build(OpBuilder &b, OperationState &result, Value source, - ValueRange low, ValueRange high, bool nofold, +void PadOp::build(OpBuilder &b, OperationState &result, Type resultType, + Value source, ValueRange low, ValueRange high, bool nofold, ArrayRef<NamedAttribute> attrs) { auto sourceType = source.getType().cast<RankedTensorType>(); unsigned rank = sourceType.getRank(); SmallVector<int64_t> staticVector(rank, ShapedType::kDynamic); - build(b, result, source, staticVector, staticVector, low, high, nofold, - attrs); + build(b, result, resultType, source, staticVector, staticVector, low, high, + nofold, attrs); } void PadOp::build(OpBuilder &b, OperationState &result, Type resultType, @@ -2635,9 +2636,9 @@ } else { auto newOp = rewriter.create<PadOp>( padTensorOp->getLoc(), newResultType, padTensorOp.getSource(), - padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), - padTensorOp.getNofold()); + padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(), + getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames())); IRMapping mapper; padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper); @@ -2667,9 +2668,10 @@ auto replacementOp = rewriter.create<PadOp>( padTensorOp.getLoc(), tensorCastOp.getDest().getType(), - padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(), - padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), - padTensorOp.getNofold()); + padTensorOp.getSource(), padTensorOp.getStaticLow(), + padTensorOp.getStaticHigh(), padTensorOp.getLow(), + padTensorOp.getHigh(), padTensorOp.getNofold(), + getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames())); replacementOp.getRegion().takeBody(padTensorOp.getRegion()); rewriter.replaceOp(padTensorOp, replacementOp.getResult()); @@ -2827,7 +2829,8 @@ innerSliceOp.getMixedStrides()); auto newPadOp = rewriter.create<PadOp>( padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(), - 
padOp.getMixedLowPad(), newHighPad, padOp.getNofold()); + padOp.getMixedLowPad(), newHighPad, padOp.getNofold(), + getPrunedAttributeList(padOp, PadOp::getAttributeNames())); rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(), newPadOp.getRegion().begin()); rewriter.replaceOp(padOp, newPadOp.getResult()); @@ -2916,8 +2919,9 @@ auto newResultType = RankedTensorType::get( newOutDims, padTensorOp.getType().getElementType()); auto newOp = rewriter.create<PadOp>( - padTensorOp->getLoc(), newResultType, input, padTensorOp.getLow(), - padTensorOp.getHigh(), staticLow, staticHigh, padTensorOp.getNofold()); + padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh, + padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(), + getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames())); IRMapping mapper; padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper); diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IRMapping.h" +#include "llvm/ADT/StringSet.h" #include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc" @@ -114,3 +115,16 @@ state.addRegion(); return b.create(state); } + +SmallVector<NamedAttribute> +mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) { + llvm::StringSet<> elidedAttrsSet; + elidedAttrsSet.insert(elidedAttrs.begin(), elidedAttrs.end()); + SmallVector<NamedAttribute> attrs; + for (auto attr : op->getAttrs()) { + if (elidedAttrsSet.count(attr.getName())) + continue; + attrs.push_back(attr); + } + return attrs; +}