diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -45,6 +45,11 @@ let results = (outs AnyTensor:$result); + let assemblyFormat = [{ + custom($sizes, $static_sizes) attr-dict + `:` type($result) + }]; + let verifier = [{ return ::verify(*this); }]; let extraClassDeclaration = [{ @@ -118,7 +123,7 @@ } def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", - [AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> { + [AttrSizedOperandSegments]> { let summary = "tensor pad operation"; let description = [{ `linalg.pad_tensor` is an operation that pads the `source` tensor @@ -181,10 +186,16 @@ I64ArrayAttr:$static_low, I64ArrayAttr:$static_high); - let regions = (region AnyRegion:$region); + let regions = (region SizedRegion<1>:$region); let results = (outs AnyTensor:$result); + let assemblyFormat = [{ + $source `low` `` custom($low, $static_low) + `high` `` custom($high, $static_high) + $region attr-dict `:` type($source) `to` type($result) + }]; + let extraClassDeclaration = [{ static StringRef getStaticLowAttrName() { return "static_low"; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1956,6 +1956,19 @@ ); let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $source `to` `offset` `` `:` + custom($offsets, $static_offsets) + `` `,` `sizes` `` `:` + custom($sizes, $static_sizes) `` `,` `strides` + `` `:` + custom($strides, $static_strides) + attr-dict `:` type($source) `to` type($result) + }]; + + let parser=?; + let printer=?; + let builders = [ // Build a ReinterpretCastOp with mixed static and dynamic entries. 
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source, @@ -2931,6 +2944,14 @@ ); let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $source `` + custom($offsets, $static_offsets) + custom($sizes, $static_sizes) + custom($strides, $static_strides) + attr-dict `:` type($source) `to` type($result) + }]; + let builders = [ // Build a SubViewOp with mixed static and dynamic entries and custom // result type. If the type passed is nullptr, it is inferred. @@ -3053,6 +3074,14 @@ ); let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + $source `` + custom($offsets, $static_offsets) + custom($sizes, $static_sizes) + custom($strides, $static_strides) + attr-dict `:` type($source) `to` type($result) + }]; + let builders = [ // Build a SubTensorOp with mixed static and dynamic entries and inferred // result type. @@ -3115,7 +3144,10 @@ //===----------------------------------------------------------------------===// def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides< - "subtensor_insert", [OffsetSizeAndStrideOpInterface]> { + "subtensor_insert", + [OffsetSizeAndStrideOpInterface, + TypesMatchWith<"expected result type to match dest type", + "dest", "result", "$_self">]> { let summary = "subtensor_insert operation"; let description = [{ The "subtensor_insert" operation insert a tensor `source` into another @@ -3159,6 +3191,16 @@ ); let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + $source `into` $dest `` + custom($offsets, $static_offsets) + custom($sizes, $static_sizes) + custom($strides, $static_strides) + attr-dict `:` type($source) `into` type($dest) + }]; + + let verifier = ?; + let builders = [ // Build a SubTensorInsertOp with mixed static and dynamic entries. 
OpBuilderDAG<(ins "Value":$source, "Value":$dest, diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -36,78 +36,68 @@ #include "mlir/Interfaces/ViewLikeInterface.h.inc" namespace mlir { -/// Print a list with either (1) the static integer value in `arrayAttr` if -/// `isDynamic` evaluates to false or (2) the next value otherwise. -/// This allows idiomatic printing of mixed value and integer attributes in a + +/// Printer hook for custom directive in assemblyFormat. +/// +/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers) +/// +/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS +/// type `I64ArrayAttr`, for use in assemblyFormat. Prints a list with +/// either (1) the static integer value in `integers` if the value is not +/// ShapedType::kDynamicStrideOrOffset or (2) the next value otherwise. This +/// allows idiomatic printing of mixed value and integer attributes in a /// list. E.g. `[%arg0, 7, 42, %arg42]`. -void printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values, - ArrayAttr arrayAttr, - llvm::function_ref isDynamic); +void printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &printer, + Operation *op, + OperandRange values, + ArrayAttr integers); -/// Print part of an op of the form: -/// ``` -/// `[` offset-list `]` -/// `[` size-list `]` -/// [` stride-list `]` -/// ``` -void printOffsetsSizesAndStrides( - OpAsmPrinter &p, OffsetSizeAndStrideOpInterface op, - StringRef offsetPrefix = "", StringRef sizePrefix = " ", - StringRef stridePrefix = " ", - ArrayRef elidedAttrs = - OffsetSizeAndStrideOpInterface::getSpecialAttrNames()); +/// Printer hook for custom directive in assemblyFormat. +/// +/// custom<OperandsOrIntegersSizesList>($values, $integers) +/// +/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS +/// type `I64ArrayAttr`, for use in assemblyFormat. 
Prints a list with +/// either (1) the static integer value in `integers` if the value is not +/// ShapedType::kDynamicSize or (2) the next value otherwise. This +/// allows idiomatic printing of mixed value and integer attributes in a +/// list. E.g. `[%arg0, 7, 42, %arg42]`. +void printOperandsOrIntegersSizesList(OpAsmPrinter &printer, Operation *op, + OperandRange values, ArrayAttr integers); -/// Parse a mixed list with either (1) static integer values or (2) SSA values. -/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` -/// encode the position of SSA values. Add the parsed SSA values to `ssa` -/// in-order. +/// Parser hook for custom directive in assemblyFormat. +/// +/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers) +/// +/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS +/// type `I64ArrayAttr`, for use in assemblyFormat. Parse a mixed list with +/// either (1) static integer values or (2) SSA values. Fill `integers` with +/// the integer ArrayAttr, where ShapedType::kDynamicStrideOrOffset encodes the +/// position of SSA values. Add the parsed SSA values to `values` in-order. // /// E.g. after parsing "[%arg0, 7, 42, %arg42]": /// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" /// 2. `ssa` is filled with "[%arg0, %arg1]". -ParseResult -parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, - StringRef attrName, int64_t dynVal, - SmallVectorImpl &ssa); +ParseResult parseOperandsOrIntegersOffsetsOrStridesList( + OpAsmParser &parser, SmallVectorImpl &values, + ArrayAttr &integers); -/// Parse trailing part of an op of the form: -/// ``` -/// `[` offset-list `]` -/// `[` size-list `]` -/// [` stride-list `]` -/// ``` -/// Each entry in the offset, size and stride list either resolves to an integer -/// constant or an operand of index type. 
-/// Constants are added to the `result` as named integer array attributes with -/// name `OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName()` (resp. -/// `getStaticSizesAttrName()`, `getStaticStridesAttrName()`). +/// Parser hook for custom directive in assemblyFormat. /// -/// Append the number of offset, size and stride operands to `segmentSizes` -/// before adding it to `result` as the named attribute: -/// `OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()`. +/// custom<OperandsOrIntegersSizesList>($values, $integers) /// -/// Offset, size and stride operands resolution occurs after `preResolutionFn` -/// to give a chance to leading operands to resolve first, after parsing the -/// types. -ParseResult parseOffsetsSizesAndStrides( - OpAsmParser &parser, OperationState &result, ArrayRef segmentSizes, - llvm::function_ref - preResolutionFn = nullptr, - llvm::function_ref parseOptionalOffsetPrefix = - nullptr, - llvm::function_ref parseOptionalSizePrefix = - nullptr, - llvm::function_ref parseOptionalStridePrefix = - nullptr); -/// `preResolutionFn`-less version of `parseOffsetsSizesAndStrides`. -ParseResult parseOffsetsSizesAndStrides( - OpAsmParser &parser, OperationState &result, ArrayRef segmentSizes, - llvm::function_ref parseOptionalOffsetPrefix = - nullptr, - llvm::function_ref parseOptionalSizePrefix = - nullptr, - llvm::function_ref parseOptionalStridePrefix = - nullptr); +/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS +/// type `I64ArrayAttr`, for use in assemblyFormat. Parse a mixed list with +/// either (1) static integer values or (2) SSA values. Fill `integers` with +/// the integer ArrayAttr, where ShapedType::kDynamicSize encodes the +/// position of SSA values. Add the parsed SSA values to `values` in-order. +/// +/// E.g. after parsing "[%arg0, 7, 42, %arg42]": +/// 1. `integers` is filled with the i64 ArrayAttr "[`kDynamicSize`, 7, 42, `kDynamicSize`]" +/// 2. `values` is filled with "[%arg0, %arg42]". 
+ParseResult parseOperandsOrIntegersSizesList( + OpAsmParser &parser, SmallVectorImpl &values, + ArrayAttr &integers); /// Verify that a the `values` has as many elements as the number of entries in /// `attr` for which `isDynamic` evaluates to true. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -676,30 +676,6 @@ // InitTensorOp //===----------------------------------------------------------------------===// -static ParseResult parseInitTensorOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType srcInfo; - Type dstType; - SmallVector sizeInfo; - IndexType indexType = parser.getBuilder().getIndexType(); - if (failed(parseListOfOperandsOrIntegers( - parser, result, InitTensorOp::getStaticSizesAttrName(), - ShapedType::kDynamicSize, sizeInfo)) || - failed(parser.parseOptionalAttrDict(result.attributes)) || - failed(parser.parseColonType(dstType)) || - failed(parser.resolveOperands(sizeInfo, indexType, result.operands))) - return failure(); - return parser.addTypeToList(dstType, result.types); -} - -static void print(OpAsmPrinter &p, InitTensorOp op) { - p << op.getOperation()->getName() << ' '; - printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), - ShapedType::isDynamic); - p.printOptionalAttrDict(op.getAttrs(), - InitTensorOp::getStaticSizesAttrName()); - p << " : " << op.getType(); -} static LogicalResult verify(InitTensorOp op) { RankedTensorType resultType = op.getType(); @@ -981,8 +957,6 @@ } auto ®ion = op.region(); - if (!llvm::hasSingleElement(region)) - return op.emitOpError("expected region with 1 block"); unsigned rank = resultType.getRank(); Block &block = region.front(); if (block.getNumArguments() != rank) @@ -1020,67 +994,6 @@ return RankedTensorType::get(resultShape, sourceType.getElementType()); } -static ParseResult parsePadTensorOp(OpAsmParser &parser, - 
OperationState &result) { - OpAsmParser::OperandType baseInfo; - SmallVector operands; - SmallVector types; - if (parser.parseOperand(baseInfo)) - return failure(); - - IndexType indexType = parser.getBuilder().getIndexType(); - SmallVector lowPadding, highPadding; - if (parser.parseKeyword("low") || - parseListOfOperandsOrIntegers(parser, result, - PadTensorOp::getStaticLowAttrName(), - ShapedType::kDynamicSize, lowPadding)) - return failure(); - if (parser.parseKeyword("high") || - parseListOfOperandsOrIntegers(parser, result, - PadTensorOp::getStaticHighAttrName(), - ShapedType::kDynamicSize, highPadding)) - return failure(); - - SmallVector regionOperands; - std::unique_ptr region = std::make_unique(); - SmallVector operandTypes, regionTypes; - if (parser.parseRegion(*region, regionOperands, regionTypes)) - return failure(); - result.addRegion(std::move(region)); - - Type srcType, dstType; - if (parser.parseColonType(srcType) || parser.parseKeywordType("to", dstType)) - return failure(); - - if (parser.addTypeToList(dstType, result.types)) - return failure(); - - SmallVector segmentSizesFinal = {1}; // source tensor - segmentSizesFinal.append({static_cast(lowPadding.size()), - static_cast(highPadding.size())}); - result.addAttribute( - OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), - parser.getBuilder().getI32VectorAttr(segmentSizesFinal)); - return failure( - parser.parseOptionalAttrDict(result.attributes) || - parser.resolveOperand(baseInfo, srcType, result.operands) || - parser.resolveOperands(lowPadding, indexType, result.operands) || - parser.resolveOperands(highPadding, indexType, result.operands)); -} - -static void print(OpAsmPrinter &p, PadTensorOp op) { - p << op->getName().getStringRef() << ' '; - p << op.source(); - p << " low"; - printListOfOperandsOrIntegers(p, op.low(), op.static_low(), - ShapedType::isDynamic); - p << " high"; - printListOfOperandsOrIntegers(p, op.high(), op.static_high(), - ShapedType::isDynamic); - 
p.printRegion(op.region()); - p << " : " << op.source().getType() << " to " << op.getType(); -} - /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if /// it is a Value or into `staticVec` if it is an IntegerAttr. /// In the case of a Value, a copy of the `sentinel` value is also pushed to diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2148,67 +2148,6 @@ build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); } -/// Print a memref_reinterpret_cast op of the form: -/// ``` -/// `memref_reinterpret_cast` ssa-name to -/// offset: `[` offset `]` -/// sizes: `[` size-list `]` -/// strides:`[` stride-list `]` -/// `:` any-memref-type to strided-memref-type -/// ``` -static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) { - int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; - p << op->getName().getStringRef().drop_front(stdDotLen) << ' '; - p << op.source() << " "; - printOffsetsSizesAndStrides( - p, op, /*offsetPrefix=*/"to offset: ", /*sizePrefix=*/", sizes: ", - /*stridePrefix=*/", strides: "); - p << ": " << op.source().getType() << " to " << op.getType(); -} - -/// Parse a memref_reinterpret_cast op of the form: -/// ``` -/// `memref_reinterpret_cast` ssa-name to -/// offset: `[` offset `]` -/// sizes: `[` size-list `]` -/// strides:`[` stride-list `]` -/// `:` any-memref-type to strided-memref-type -/// ``` -static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser, - OperationState &result) { - // Parse `operand` - OpAsmParser::OperandType srcInfo; - if (parser.parseOperand(srcInfo)) - return failure(); - - auto parseOffsetPrefix = [](OpAsmParser &parser) { - return failure(parser.parseKeyword("to") || parser.parseKeyword("offset") || - parser.parseColon()); - }; - auto parseSizePrefix = [](OpAsmParser &parser) { - return 
failure(parser.parseComma() || parser.parseKeyword("sizes") || - parser.parseColon()); - }; - auto parseStridePrefix = [](OpAsmParser &parser) { - return failure(parser.parseComma() || parser.parseKeyword("strides") || - parser.parseColon()); - }; - - Type srcType, dstType; - auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) { - return failure(parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(srcType) || - parser.parseKeywordType("to", dstType) || - parser.resolveOperand(srcInfo, srcType, result.operands)); - }; - if (failed(parseOffsetsSizesAndStrides(parser, result, - /*segmentSizes=*/{1}, // source memref - preResolutionFn, parseOffsetPrefix, - parseSizePrefix, parseStridePrefix))) - return failure(); - return parser.addTypeToList(dstType, result.types); -} - // TODO: ponder whether we want to allow missing trailing sizes/strides that are // completed automatically, like we have for subview and subtensor. static LogicalResult verify(MemRefReinterpretCastOp op) { @@ -2892,45 +2831,6 @@ sourceMemRefType.getMemorySpace()); } -/// Print a subview op of the form: -/// ``` -/// `subview` ssa-name -/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]` -/// `:` strided-memref-type `to` strided-memref-type -/// ``` -static void print(OpAsmPrinter &p, SubViewOp op) { - int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; - p << op->getName().getStringRef().drop_front(stdDotLen) << ' '; - p << op.source(); - printOffsetsSizesAndStrides(p, op); - p << " : " << op.source().getType() << " to " << op.getType(); -} - -/// Parse a subview op of the form: -/// ``` -/// `subview` ssa-name -/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]` -/// `:` strided-memref-type `to` strided-memref-type -/// ``` -static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType srcInfo; - if (parser.parseOperand(srcInfo)) - return failure(); - Type srcType, dstType; - 
auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) { - return failure(parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(srcType) || - parser.parseKeywordType("to", dstType) || - parser.resolveOperand(srcInfo, srcType, result.operands)); - }; - - if (failed(parseOffsetsSizesAndStrides(parser, result, - /*segmentSizes=*/{1}, // source memref - preResolutionFn))) - return failure(); - return parser.addTypeToList(dstType, result.types); -} - // Build a SubViewOp with mixed static and dynamic entries and custom result // type. If the type passed is nullptr, it is inferred. void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, @@ -3466,46 +3366,6 @@ // SubTensorOp //===----------------------------------------------------------------------===// -/// Print a subtensor op of the form: -/// ``` -/// `subtensor` ssa-name -/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]` -/// `:` ranked-tensor-type `to` ranked-tensor-type -/// ``` -static void print(OpAsmPrinter &p, SubTensorOp op) { - int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; - p << op->getName().getStringRef().drop_front(stdDotLen) << ' '; - p << op.source(); - printOffsetsSizesAndStrides(p, op); - p << " : " << op.getSourceType() << " to " << op.getType(); -} - -/// Parse a subtensor op of the form: -/// ``` -/// `subtensor` ssa-name -/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]` -/// `:` ranked-tensor-type `to` ranked-tensor-type -/// ``` -static ParseResult parseSubTensorOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType srcInfo; - if (parser.parseOperand(srcInfo)) - return failure(); - Type srcType, dstType; - auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) { - return failure(parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(srcType) || - parser.parseKeywordType("to", dstType) || - parser.resolveOperand(srcInfo, srcType, 
result.operands)); - }; - - if (failed(parseOffsetsSizesAndStrides(parser, result, - /*segmentSizes=*/{1}, // source tensor - preResolutionFn))) - return failure(); - return parser.addTypeToList(dstType, result.types); -} - /// A subtensor result type can be fully inferred from the source type and the /// static representation of offsets, sizes and strides. Special sentinels /// encode the dynamic case. @@ -3612,49 +3472,6 @@ // SubTensorInsertOp //===----------------------------------------------------------------------===// -/// Print a subtensor_insert op of the form: -/// ``` -/// `subtensor_insert` ssa-name `into` ssa-name -/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]` -/// `:` ranked-tensor-type `into` ranked-tensor-type -/// ``` -static void print(OpAsmPrinter &p, SubTensorInsertOp op) { - int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; - p << op->getName().getStringRef().drop_front(stdDotLen) << ' '; - p << op.source() << " into " << op.dest(); - printOffsetsSizesAndStrides(p, op); - p << " : " << op.getSourceType() << " into " << op.getType(); -} - -/// Parse a subtensor_insert op of the form: -/// ``` -/// `subtensor_insert` ssa-name `into` ssa-name -/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]` -/// `:` ranked-tensor-type `into` ranked-tensor-type -/// ``` -static ParseResult parseSubTensorInsertOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType srcInfo, dstInfo; - if (parser.parseOperand(srcInfo) || parser.parseKeyword("into") || - parser.parseOperand(dstInfo)) - return failure(); - Type srcType, dstType; - auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) { - return failure(parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(srcType) || - parser.parseKeywordType("into", dstType) || - parser.resolveOperand(srcInfo, srcType, result.operands) || - parser.resolveOperand(dstInfo, dstType, result.operands)); - }; - - if 
(failed(parseOffsetsSizesAndStrides( - parser, result, - /*segmentSizes=*/{1, 1}, // source tensor, destination tensor - preResolutionFn))) - return failure(); - return parser.addTypeToList(dstType, result.types); -} - // Build a SubTensorInsertOp with mixed static and dynamic entries. void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result, Value source, Value dest, @@ -3691,13 +3508,6 @@ build(b, result, source, dest, offsetValues, sizeValues, strideValues); } -/// Verifier for SubViewOp. -static LogicalResult verify(SubTensorInsertOp op) { - if (op.getType() != op.dest().getType()) - return op.emitError("expected result type to be ") << op.dest().getType(); - return success(); -} - //===----------------------------------------------------------------------===// // TensorLoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -69,14 +69,18 @@ return success(); } -void mlir::printListOfOperandsOrIntegers( - OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr, - llvm::function_ref isDynamic) { +template +static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values, + ArrayAttr arrayAttr) { p << '['; + if (arrayAttr.empty()) { + p << "]"; + return; + } unsigned idx = 0; llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { int64_t val = a.cast().getInt(); - if (isDynamic(val)) + if (val == dynVal) p << values[idx++]; else p << val; @@ -84,32 +88,31 @@ p << ']'; } -void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p, - OffsetSizeAndStrideOpInterface op, - StringRef offsetPrefix, - StringRef sizePrefix, - StringRef stridePrefix, - ArrayRef elidedAttrs) { - p << offsetPrefix; - printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset); - p << 
sizePrefix; - printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), - ShapedType::isDynamic); - p << stridePrefix; - printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), - ShapedType::isDynamicStrideOrOffset); - p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); +void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p, + Operation *op, + OperandRange values, + ArrayAttr integers) { + return printOperandsOrIntegersListImpl( + p, values, integers); +} + +void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op, + OperandRange values, + ArrayAttr integers) { + return printOperandsOrIntegersListImpl(p, values, + integers); } -ParseResult mlir::parseListOfOperandsOrIntegers( - OpAsmParser &parser, OperationState &result, StringRef attrName, - int64_t dynVal, SmallVectorImpl &ssa) { +template +static ParseResult +parseOperandsOrIntegersImpl(OpAsmParser &parser, + SmallVectorImpl &values, + ArrayAttr &integers) { if (failed(parser.parseLSquare())) return failure(); // 0-D. 
if (succeeded(parser.parseOptionalRSquare())) { - result.addAttribute(attrName, parser.getBuilder().getArrayAttr({})); + integers = parser.getBuilder().getArrayAttr({}); return success(); } @@ -118,7 +121,7 @@ OpAsmParser::OperandType operand; auto res = parser.parseOptionalOperand(operand); if (res.hasValue() && succeeded(res.getValue())) { - ssa.push_back(operand); + values.push_back(operand); attrVals.push_back(dynVal); } else { IntegerAttr attr; @@ -134,59 +137,20 @@ return failure(); break; } - - auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); - result.addAttribute(attrName, arrayAttr); + integers = parser.getBuilder().getI64ArrayAttr(attrVals); return success(); } -ParseResult mlir::parseOffsetsSizesAndStrides( - OpAsmParser &parser, OperationState &result, ArrayRef segmentSizes, - llvm::function_ref parseOptionalOffsetPrefix, - llvm::function_ref parseOptionalSizePrefix, - llvm::function_ref parseOptionalStridePrefix) { - return parseOffsetsSizesAndStrides( - parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix, - parseOptionalSizePrefix, parseOptionalStridePrefix); +ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList( + OpAsmParser &parser, SmallVectorImpl &values, + ArrayAttr &integers) { + return parseOperandsOrIntegersImpl( + parser, values, integers); } -ParseResult mlir::parseOffsetsSizesAndStrides( - OpAsmParser &parser, OperationState &result, ArrayRef segmentSizes, - llvm::function_ref - preResolutionFn, - llvm::function_ref parseOptionalOffsetPrefix, - llvm::function_ref parseOptionalSizePrefix, - llvm::function_ref parseOptionalStridePrefix) { - SmallVector offsetsInfo, sizesInfo, stridesInfo; - auto indexType = parser.getBuilder().getIndexType(); - if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) || - parseListOfOperandsOrIntegers( - parser, result, - OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), - ShapedType::kDynamicStrideOrOffset, offsetsInfo) || - (parseOptionalSizePrefix 
&& parseOptionalSizePrefix(parser)) || - parseListOfOperandsOrIntegers( - parser, result, - OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), - ShapedType::kDynamicSize, sizesInfo) || - (parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) || - parseListOfOperandsOrIntegers( - parser, result, - OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), - ShapedType::kDynamicStrideOrOffset, stridesInfo)) - return failure(); - // Add segment sizes to result - SmallVector segmentSizesFinal(segmentSizes.begin(), - segmentSizes.end()); - segmentSizesFinal.append({static_cast(offsetsInfo.size()), - static_cast(sizesInfo.size()), - static_cast(stridesInfo.size())}); - result.addAttribute( - OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), - parser.getBuilder().getI32VectorAttr(segmentSizesFinal)); - return failure( - (preResolutionFn && preResolutionFn(parser, result)) || - parser.resolveOperands(offsetsInfo, indexType, result.operands) || - parser.resolveOperands(sizesInfo, indexType, result.operands) || - parser.resolveOperands(stridesInfo, indexType, result.operands)); +ParseResult mlir::parseOperandsOrIntegersSizesList( + OpAsmParser &parser, SmallVectorImpl &values, + ArrayAttr &integers) { + return parseOperandsOrIntegersImpl(parser, values, + integers); } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -643,7 +643,7 @@ // ----- func @pad_no_block(%arg0: tensor, %arg1: i32) -> tensor { - // expected-error @+1 {{expected region with 1 block}} + // expected-error @+1 {{op region #0 ('region') failed to verify constraint: region with 1 blocks}} %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] { } : tensor to tensor return %0 : tensor diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ 
b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1581,6 +1581,8 @@ if (value) { body << " p << ' ';\n"; lastWasPunctuation = false; + } else { + lastWasPunctuation = true; } shouldEmitSpace = false; }