diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -232,15 +232,6 @@ SmallVector getOrCreateRanges(OpBuilder &b, Location loc) { return mlir::getOrCreateRanges(*this, b, loc); } - - static ArrayRef getSpecialAttrNames() { - static SmallVector names{ - OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), - OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), - OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), - getOperandSegmentSizeAttr()}; - return names; - } }]; } diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -29,6 +29,24 @@ class OffsetSizeAndStrideOpInterface; LogicalResult verify(OffsetSizeAndStrideOpInterface op); +} // namespace mlir + +/// Include the generated interface declarations. +#include "mlir/Interfaces/ViewLikeInterface.h.inc" + +namespace mlir { +/// Print part of an op of the form: +/// ``` +/// `[` offset-list `]` +/// `[` size-list `]` +/// [` stride-list `]` +/// ``` +void printOffsetsSizesAndStrides( + OpAsmPrinter &p, OffsetSizeAndStrideOpInterface op, + StringRef offsetPrefix = "", StringRef sizePrefix = " ", + StringRef stridePrefix = " ", + ArrayRef elidedAttrs = + OffsetSizeAndStrideOpInterface::getSpecialAttrNames()); /// Parse trailing part of an op of the form: /// ``` @@ -59,10 +77,16 @@ nullptr, llvm::function_ref parseOptionalStridePrefix = nullptr); +/// `preResolutionFn`-less version of `parseOffsetsSizesAndStrides`. +ParseResult parseOffsetsSizesAndStrides( + OpAsmParser &parser, OperationState &result, ArrayRef segmentSizes, + llvm::function_ref parseOptionalOffsetPrefix = + nullptr, + llvm::function_ref parseOptionalSizePrefix = + nullptr, + llvm::function_ref parseOptionalStridePrefix = + nullptr); } // namespace mlir -/// Include the generated interface declarations. -#include "mlir/Interfaces/ViewLikeInterface.h.inc" - #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -357,6 +357,14 @@ static StringRef getStaticStridesAttrName() { return "static_strides"; } + static ArrayRef getSpecialAttrNames() { + static SmallVector names{ + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), + OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()}; + return names; + } }]; let verify = [{ diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -793,12 +793,6 @@ return fusableDependences; } -static bool isZero(Value v) { - if (auto cst = v.getDefiningOp()) - return cst.getValue() == 0; - return false; -} - /// Tile the fused loops in the root operation, by setting the tile sizes for /// all other loops to zero (those will be tiled later). static Optional tileRootOperation( diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -248,49 +248,6 @@ [](APInt a, APInt b) { return a + b; }); } -//===----------------------------------------------------------------------===// -// BaseOpWithOffsetSizesAndStridesOp -//===----------------------------------------------------------------------===// - -/// Print a list with either (1) the static integer value in `arrayAttr` if -/// `isDynamic` evaluates to false or (2) the next value otherwise. -/// This allows idiomatic printing of mixed value and integer attributes in a -/// list. E.g. `[%arg0, 7, 42, %arg42]`. -static void -printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values, - ArrayAttr arrayAttr, - llvm::function_ref isDynamic) { - p << '['; - unsigned idx = 0; - llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { - int64_t val = a.cast().getInt(); - if (isDynamic(val)) - p << values[idx++]; - else - p << val; - }); - p << ']'; -} - -/// Verify that a particular offset/size/stride static attribute is well-formed. -static LogicalResult verifyOpWithOffsetSizesAndStridesPart( - OffsetSizeAndStrideOpInterface op, StringRef name, - unsigned expectedNumElements, StringRef attrName, ArrayAttr attr, - llvm::function_ref isDynamic, ValueRange values) { - /// Check static and dynamic offsets/sizes/strides breakdown. - if (attr.size() != expectedNumElements) - return op.emitError("expected ") - << expectedNumElements << " " << name << " values"; - unsigned expectedNumDynamicEntries = - llvm::count_if(attr.getValue(), [&](Attribute attr) { - return isDynamic(attr.cast().getInt()); - }); - if (values.size() != expectedNumDynamicEntries) - return op.emitError("expected ") - << expectedNumDynamicEntries << " dynamic " << name << " values"; - return success(); -} - /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. static SmallVector extractFromI64ArrayAttr(Attribute attr) { return llvm::to_vector<4>( @@ -2390,9 +2347,9 @@ staticStridesVector, offset, sizes, strides, attrs); } -/// Print of the form: +/// Print a memref_reinterpret_cast op of the form: /// ``` -/// `name` ssa-name to +/// `memref_reinterpret_cast` ssa-name to /// offset: `[` offset `]` /// sizes: `[` size-list `]` /// strides:`[` stride-list `]` @@ -2400,19 +2357,11 @@ /// ``` static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) { int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; - p << op.getOperationName().drop_front(stdDotLen) << " " << op.source() - << " to offset: "; - printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset); - p << ", sizes: "; - printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), - ShapedType::isDynamic); - p << ", strides: "; - printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), - ShapedType::isDynamicStrideOrOffset); - p.printOptionalAttrDict( - op.getAttrs(), - /*elidedAttrs=*/MemRefReinterpretCastOp::getSpecialAttrNames()); + p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' '; + p << op.source() << " "; + printOffsetsSizesAndStrides( + p, op, /*offsetPrefix=*/"to offset: ", /*sizePrefix=*/", sizes: ", + /*stridePrefix=*/", strides: "); p << ": " << op.source().getType() << " to " << op.getType(); } @@ -2451,8 +2400,8 @@ parser.parseKeywordType("to", dstType) || parser.resolveOperand(srcInfo, srcType, result.operands)); }; - SmallVector segmentSizes{1}; // source memref - if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes, + if (failed(parseOffsetsSizesAndStrides(parser, result, + /*segmentSizes=*/{1}, // source memref preResolutionFn, parseOffsetPrefix, parseSizePrefix, parseStridePrefix))) return failure(); @@ -3122,38 +3071,18 @@ sourceMemRefType.getMemorySpace()); } -/// Print SubViewOp in the form: +/// Print a subview op of the form: /// ``` -/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]` +/// `subview` ssa-name +/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]` /// `:` strided-memref-type `to` strided-memref-type /// ``` -template -static void printOpWithOffsetsSizesAndStrides( - OpAsmPrinter &p, OpType op, - llvm::function_ref printExtraOperands = - [](OpAsmPrinter &p, OpType op) {}, - StringRef resultTypeKeyword = "to") { +static void print(OpAsmPrinter &p, SubViewOp op) { int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' '; p << op.source(); - printExtraOperands(p, op); - printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset); - p << ' '; - printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), - ShapedType::isDynamic); - p << ' '; - printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), - ShapedType::isDynamicStrideOrOffset); - p << ' '; - p.printOptionalAttrDict(op.getAttrs(), - /*elidedAttrs=*/{OpType::getSpecialAttrNames()}); - p << " : " << op.getSourceType() << " " << resultTypeKeyword << " " - << op.getType(); -} - -static void print(OpAsmPrinter &p, SubViewOp op) { - return printOpWithOffsetsSizesAndStrides(p, op); + printOffsetsSizesAndStrides(p, op); + p << " : " << op.getSourceType() << " to " << op.getType(); } /// Parse a subview op of the form: @@ -3173,8 +3102,9 @@ parser.parseKeywordType("to", dstType) || parser.resolveOperand(srcInfo, srcType, result.operands)); }; - SmallVector segmentSizes{1}; // source memref - if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes, + + if (failed(parseOffsetsSizesAndStrides(parser, result, + /*segmentSizes=*/{1}, // source memref preResolutionFn))) return failure(); return parser.addTypeToList(dstType, result.types); @@ -3750,8 +3680,18 @@ // SubTensorOp //===----------------------------------------------------------------------===// +/// Print a subtensor op of the form: +/// ``` +/// `subtensor` ssa-name +/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]` +/// `:` ranked-tensor-type `to` ranked-tensor-type +/// ``` static void print(OpAsmPrinter &p, SubTensorOp op) { - return printOpWithOffsetsSizesAndStrides(p, op); + int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' '; + p << op.source(); + printOffsetsSizesAndStrides(p, op); + p << " : " << op.getSourceType() << " to " << op.getType(); } /// Parse a subtensor op of the form: @@ -3772,8 +3712,9 @@ parser.parseKeywordType("to", dstType) || parser.resolveOperand(srcInfo, srcType, result.operands)); }; - SmallVector segmentSizes{1}; // source tensor - if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes, + + if (failed(parseOffsetsSizesAndStrides(parser, result, + /*segmentSizes=*/{1}, // source tensor preResolutionFn))) return failure(); return parser.addTypeToList(dstType, result.types); @@ -3853,11 +3794,18 @@ // SubTensorInsertOp //===----------------------------------------------------------------------===// +/// Print a subtensor_insert op of the form: +/// ``` +/// `subtensor_insert` ssa-name `into` ssa-name +/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]` +/// `:` ranked-tensor-type `into` ranked-tensor-type +/// ``` static void print(OpAsmPrinter &p, SubTensorInsertOp op) { - return printOpWithOffsetsSizesAndStrides( - p, op, - [](OpAsmPrinter &p, SubTensorInsertOp op) { p << " into " << op.dest(); }, - /*resultTypeKeyword=*/"into"); + int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' '; + p << op.source() << " into " << op.dest(); + printOffsetsSizesAndStrides(p, op); + p << " : " << op.getSourceType() << " into " << op.getType(); } /// Parse a subtensor_insert op of the form: @@ -3880,9 +3828,11 @@ parser.resolveOperand(srcInfo, srcType, result.operands) || parser.resolveOperand(dstInfo, dstType, result.operands)); }; - SmallVector segmentSizes{1, 1}; // source tensor, destination tensor - if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes, - preResolutionFn))) + + if (failed(parseOffsetsSizesAndStrides( + parser, result, + /*segmentSizes=*/{1, 1}, // source tensor, destination tensor + preResolutionFn))) return failure(); return parser.addTypeToList(dstType, result.types); } diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -57,6 +57,44 @@ return success(); } +/// Print a list with either (1) the static integer value in `arrayAttr` if +/// `isDynamic` evaluates to false or (2) the next value otherwise. +/// This allows idiomatic printing of mixed value and integer attributes in a +/// list. E.g. `[%arg0, 7, 42, %arg42]`. +static void +printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values, + ArrayAttr arrayAttr, + llvm::function_ref isDynamic) { + p << '['; + unsigned idx = 0; + llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { + int64_t val = a.cast().getInt(); + if (isDynamic(val)) + p << values[idx++]; + else + p << val; + }); + p << ']'; +} + +void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p, + OffsetSizeAndStrideOpInterface op, + StringRef offsetPrefix, + StringRef sizePrefix, + StringRef stridePrefix, + ArrayRef elidedAttrs) { + p << offsetPrefix; + printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), + ShapedType::isDynamicStrideOrOffset); + p << sizePrefix; + printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), + ShapedType::isDynamic); + p << stridePrefix; + printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), + ShapedType::isDynamicStrideOrOffset); + p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); +} + /// Parse a mixed list with either (1) static integer values or (2) SSA values. /// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` /// encode the position of SSA values. Add the parsed SSA values to `ssa` @@ -105,9 +143,17 @@ } ParseResult mlir::parseOffsetsSizesAndStrides( - OpAsmParser &parser, - OperationState &result, - ArrayRef segmentSizes, + OpAsmParser &parser, OperationState &result, ArrayRef segmentSizes, + llvm::function_ref parseOptionalOffsetPrefix, + llvm::function_ref parseOptionalSizePrefix, + llvm::function_ref parseOptionalStridePrefix) { + return parseOffsetsSizesAndStrides( + parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix, + parseOptionalSizePrefix, parseOptionalStridePrefix); +} + +ParseResult mlir::parseOffsetsSizesAndStrides( + OpAsmParser &parser, OperationState &result, ArrayRef segmentSizes, llvm::function_ref preResolutionFn, llvm::function_ref parseOptionalOffsetPrefix, @@ -132,14 +178,14 @@ ShapedType::kDynamicStrideOrOffset, stridesInfo)) return failure(); // Add segment sizes to result - SmallVector segmentSizesFinal(segmentSizes.begin(), segmentSizes.end()); + SmallVector segmentSizesFinal(segmentSizes.begin(), + segmentSizes.end()); segmentSizesFinal.append({static_cast(offsetsInfo.size()), - static_cast(sizesInfo.size()), - static_cast(stridesInfo.size())}); - auto b = parser.getBuilder(); + static_cast(sizesInfo.size()), + static_cast(stridesInfo.size())}); result.addAttribute( OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), - b.getI32VectorAttr(segmentSizesFinal)); + parser.getBuilder().getI32VectorAttr(segmentSizesFinal)); return failure( (preResolutionFn && preResolutionFn(parser, result)) || parser.resolveOperands(offsetsInfo, indexType, result.operands) ||