diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -14,7 +14,7 @@
 #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
 
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
 
 namespace mlir {
@@ -29,6 +29,41 @@
 class OffsetSizeAndStrideOpInterface;
 
 LogicalResult verify(OffsetSizeAndStrideOpInterface op);
+
+/// Parse trailing part of an op of the form:
+/// ```
+///   `[` offset-list `]`
+///   `[` size-list `]`
+///   `[` stride-list `]`
+/// ```
+/// Each entry in the offset, size and stride list either resolves to an integer
+/// constant or an operand of index type.
+/// Constants are added to the `result` as named integer array attributes with
+/// name `OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName()` (resp.
+/// `getStaticSizesAttrName()`, `getStaticStridesAttrName()`).
+///
+/// Append the number of offset, size and stride operands to a copy of
+/// `segmentSizes` before adding it to `result` as the named attribute:
+/// `OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()`.
+///
+/// When non-null, `parseOptionalOffsetPrefix`, `parseOptionalSizePrefix` and
+/// `parseOptionalStridePrefix` are invoked immediately before the offset, size
+/// and stride list respectively; a failure from any of them aborts parsing.
+///
+/// Offset, size and stride operands resolution occurs after `preResolutionFn`
+/// to give a chance to leading operands to resolve first, after parsing the
+/// types.
+ParseResult parseOffsetsSizesAndStrides(
+    OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
+    llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
+        preResolutionFn = nullptr,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix =
+        nullptr,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix =
+        nullptr,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
+        nullptr);
+
 } // namespace mlir
 
 /// Include the generated interface declarations.
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -272,53 +272,6 @@
   p << ']';
 }
 
-/// Parse a mixed list with either (1) static integer values or (2) SSA values.
-/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
-/// encode the position of SSA values. Add the parsed SSA values to `ssa`
-/// in-order.
-//
-/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
-///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
-///   2. `ssa` is filled with "[%arg0, %arg1]".
-static ParseResult
-parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
-                              StringRef attrName, int64_t dynVal,
-                              SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
-  if (failed(parser.parseLSquare()))
-    return failure();
-  // 0-D.
-  if (succeeded(parser.parseOptionalRSquare())) {
-    result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
-    return success();
-  }
-
-  SmallVector<int64_t, 4> attrVals;
-  while (true) {
-    OpAsmParser::OperandType operand;
-    auto res = parser.parseOptionalOperand(operand);
-    if (res.hasValue() && succeeded(res.getValue())) {
-      ssa.push_back(operand);
-      attrVals.push_back(dynVal);
-    } else {
-      IntegerAttr attr;
-      if (failed(parser.parseAttribute(attr)))
-        return parser.emitError(parser.getNameLoc())
-               << "expected SSA value or integer";
-      attrVals.push_back(attr.getInt());
-    }
-
-    if (succeeded(parser.parseOptionalComma()))
-      continue;
-    if (failed(parser.parseRSquare()))
-      return failure();
-    break;
-  }
-
-  auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
-  result.addAttribute(attrName, arrayAttr);
-  return success();
-}
-
 /// Verify that a particular offset/size/stride static attribute is well-formed.
 static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
     OffsetSizeAndStrideOpInterface op, StringRef name,
@@ -2421,7 +2374,7 @@
 }
 
 /// Build a MemRefReinterpretCastOp with all dynamic entries: `staticOffsets`,
-/// `staticSizes` and `staticStrides` are automatically filled with 
+/// `staticSizes` and `staticStrides` are automatically filled with
 /// source-memref-rank sentinel values that encode dynamic entries.
 void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                                           MemRefType resultType, Value source,
@@ -2463,9 +2416,9 @@
   p << ": " << op.source().getType() << " to " << op.getType();
 }
 
-/// Parse of the form:
+/// Parse a memref_reinterpret_cast op of the form:
 /// ```
-///   `name` ssa-name to
+///   `memref_reinterpret_cast` ssa-name to
 ///       offset: `[` offset `]`
 ///       sizes: `[` size-list `]`
 ///       strides:`[` stride-list `]`
@@ -2473,62 +2426,37 @@
 /// ```
 static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
                                                 OperationState &result) {
-  // Parse `operand` and `offset`.
-  OpAsmParser::OperandType operand;
-  if (parser.parseOperand(operand))
-    return failure();
-
-  // Parse offset.
-  SmallVector<OpAsmParser::OperandType, 1> offset;
-  if (parser.parseKeyword("to") || parser.parseKeyword("offset") ||
-      parser.parseColon() ||
-      parseListOfOperandsOrIntegers(
-          parser, result,
-          OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
-          ShapedType::kDynamicStrideOrOffset, offset) ||
-      parser.parseComma())
+  // Parse `operand`
+  OpAsmParser::OperandType srcInfo;
+  if (parser.parseOperand(srcInfo))
     return failure();
 
-  // Parse `sizes`.
-  SmallVector<OpAsmParser::OperandType, 4> sizes;
-  if (parser.parseKeyword("sizes") || parser.parseColon() ||
-      parseListOfOperandsOrIntegers(
-          parser, result,
-          OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
-          ShapedType::kDynamicSize, sizes) ||
-      parser.parseComma())
-    return failure();
+  auto parseOffsetPrefix = [](OpAsmParser &parser) {
+    return failure(parser.parseKeyword("to") || parser.parseKeyword("offset") ||
+                   parser.parseColon());
+  };
+  auto parseSizePrefix = [](OpAsmParser &parser) {
+    return failure(parser.parseComma() || parser.parseKeyword("sizes") ||
+                   parser.parseColon());
+  };
+  auto parseStridePrefix = [](OpAsmParser &parser) {
+    return failure(parser.parseComma() || parser.parseKeyword("strides") ||
+                   parser.parseColon());
+  };
 
-  // Parse `strides`.
-  SmallVector<OpAsmParser::OperandType, 4> strides;
-  if (parser.parseKeyword("strides") || parser.parseColon() ||
-      parseListOfOperandsOrIntegers(
-          parser, result,
-          OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
-          ShapedType::kDynamicStrideOrOffset, strides))
+  Type srcType, dstType;
+  auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
+    return failure(parser.parseOptionalAttrDict(result.attributes) ||
+                   parser.parseColonType(srcType) ||
+                   parser.parseKeywordType("to", dstType) ||
+                   parser.resolveOperand(srcInfo, srcType, result.operands));
+  };
+  SmallVector<int, 4> segmentSizes{1}; // source memref
+  if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
+                                         preResolutionFn, parseOffsetPrefix,
+                                         parseSizePrefix, parseStridePrefix)))
     return failure();
-
-  // Handle segment sizes.
-  auto b = parser.getBuilder();
-  SmallVector<int, 4> segmentSizes = {1, static_cast<int>(offset.size()),
-                                      static_cast<int>(sizes.size()),
-                                      static_cast<int>(strides.size())};
-  result.addAttribute(MemRefReinterpretCastOp::getOperandSegmentSizeAttr(),
-
-                      b.getI32VectorAttr(segmentSizes));
-
-  // Parse types and resolve.
-  Type indexType = b.getIndexType();
-  Type operandType, resultType;
-  return failure(
-      (parser.parseOptionalAttrDict(result.attributes) ||
-       parser.parseColonType(operandType) || parser.parseKeyword("to") ||
-       parser.parseType(resultType) ||
-       parser.resolveOperand(operand, operandType, result.operands) ||
-       parser.resolveOperands(offset, indexType, result.operands) ||
-       parser.resolveOperands(sizes, indexType, result.operands) ||
-       parser.resolveOperands(strides, indexType, result.operands) ||
-       parser.addTypeToList(resultType, result.types)));
+  return parser.addTypeToList(dstType, result.types);
 }
 
 static LogicalResult verify(MemRefReinterpretCastOp op) {
@@ -3228,67 +3156,28 @@
   return printOpWithOffsetsSizesAndStrides<SubViewOp>(p, op);
 }
 
-/// Parse of the form:
+/// Parse a subview op of the form:
 /// ```
-///   `name` ssa-name (extra-operands)?
-///       `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
-///     `:` strided-memref-type `resultTypeKeyword strided-memref-type
+///   `subview` ssa-name
+///       `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+///     `:` strided-memref-type `to` strided-memref-type
 /// ```
-template <typename OpType>
-static ParseResult parseOpWithOffsetsSizesAndStrides(
-    OpAsmParser &parser, OperationState &result,
-    std::function<ParseResult(OpAsmParser &p,
-                              OpAsmParser::OperandType &dstInfo)>
-        parseExtraOperand = nullptr,
-    StringRef resultTypeKeyword = "to") {
-  OpAsmParser::OperandType srcInfo, dstInfo;
-  SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
-  auto indexType = parser.getBuilder().getIndexType();
-  Type srcType, dstType;
+static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::OperandType srcInfo;
   if (parser.parseOperand(srcInfo))
     return failure();
-  if (parseExtraOperand && parseExtraOperand(parser, dstInfo))
-    return failure();
-  if (parseListOfOperandsOrIntegers(
-          parser, result,
-          OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
-          ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
-      parseListOfOperandsOrIntegers(
-          parser, result,
-          OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
-          ShapedType::kDynamicSize, sizesInfo) ||
-      parseListOfOperandsOrIntegers(
-          parser, result,
-          OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
-          ShapedType::kDynamicStrideOrOffset, stridesInfo))
+  Type srcType, dstType;
+  auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
+    return failure(parser.parseOptionalAttrDict(result.attributes) ||
+                   parser.parseColonType(srcType) ||
+                   parser.parseKeywordType("to", dstType) ||
+                   parser.resolveOperand(srcInfo, srcType, result.operands));
+  };
+  SmallVector<int, 4> segmentSizes{1}; // source memref
+  if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
+                                         preResolutionFn)))
     return failure();
-
-  // Handle segment sizes.
-  auto b = parser.getBuilder();
-  SmallVector<int, 4> segmentSizes = {1, static_cast<int>(offsetsInfo.size()),
-                                      static_cast<int>(sizesInfo.size()),
-                                      static_cast<int>(stridesInfo.size())};
-  // If we parse an extra operand it needs to appear in the segmentSizes
-  if (parseExtraOperand)
-    segmentSizes.insert(segmentSizes.begin(), 1);
-  result.addAttribute(OpType::getOperandSegmentSizeAttr(),
-                      b.getI32VectorAttr(segmentSizes));
-
-  return failure(
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(srcType) ||
-      parser.parseKeywordType(resultTypeKeyword.str().c_str(), dstType) ||
-      parser.resolveOperand(srcInfo, srcType, result.operands) ||
-      (parseExtraOperand &&
-       parser.resolveOperand(dstInfo, dstType, result.operands)) ||
-      parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
-      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
-      parser.resolveOperands(stridesInfo, indexType, result.operands) ||
-      parser.addTypeToList(dstType, result.types));
-}
-
-static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
-  return parseOpWithOffsetsSizesAndStrides<SubViewOp>(parser, result);
+  return parser.addTypeToList(dstType, result.types);
 }
 
 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
@@ -3307,7 +3196,7 @@
 }
 
 /// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes`
-/// and `staticStrides` are automatically filled with source-memref-rank 
+/// and `staticStrides` are automatically filled with source-memref-rank
 /// sentinel values that encode dynamic entries.
 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                             ValueRange offsets, ValueRange sizes,
@@ -3865,9 +3754,29 @@
   return printOpWithOffsetsSizesAndStrides<SubTensorOp>(p, op);
 }
 
+/// Parse a subtensor op of the form:
+/// ```
+///   `subtensor` ssa-name
+///       `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+///     `:` ranked-tensor-type `to` ranked-tensor-type
+/// ```
 static ParseResult parseSubTensorOp(OpAsmParser &parser,
                                     OperationState &result) {
-  return parseOpWithOffsetsSizesAndStrides<SubTensorOp>(parser, result);
+  OpAsmParser::OperandType srcInfo;
+  if (parser.parseOperand(srcInfo))
+    return failure();
+  Type srcType, dstType;
+  auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
+    return failure(parser.parseOptionalAttrDict(result.attributes) ||
+                   parser.parseColonType(srcType) ||
+                   parser.parseKeywordType("to", dstType) ||
+                   parser.resolveOperand(srcInfo, srcType, result.operands));
+  };
+  SmallVector<int, 4> segmentSizes{1}; // source tensor
+  if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
+                                         preResolutionFn)))
+    return failure();
+  return parser.addTypeToList(dstType, result.types);
 }
 
 /// A subtensor result type can be fully inferred from the source type and the
@@ -3951,15 +3860,31 @@
       /*resultTypeKeyword=*/"into");
 }
 
+/// Parse a subtensor_insert op of the form:
+/// ```
+///   `subtensor_insert` ssa-name `into` ssa-name
+///       `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+///     `:` ranked-tensor-type `into` ranked-tensor-type
+/// ```
 static ParseResult parseSubTensorInsertOp(OpAsmParser &parser,
                                           OperationState &result) {
-  return parseOpWithOffsetsSizesAndStrides<SubTensorInsertOp>(
-      parser, result,
-      [](OpAsmParser &parser, OpAsmParser::OperandType &dstInfo) {
-        return failure(parser.parseKeyword("into") ||
-                       parser.parseOperand(dstInfo));
-      },
-      "into");
+  OpAsmParser::OperandType srcInfo, dstInfo;
+  if (parser.parseOperand(srcInfo) || parser.parseKeyword("into") ||
+      parser.parseOperand(dstInfo))
+    return failure();
+  Type srcType, dstType;
+  auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
+    return failure(parser.parseOptionalAttrDict(result.attributes) ||
+                   parser.parseColonType(srcType) ||
+                   parser.parseKeywordType("into", dstType) ||
+                   parser.resolveOperand(srcInfo, srcType, result.operands) ||
+                   parser.resolveOperand(dstInfo, dstType, result.operands));
+  };
+  SmallVector<int, 4> segmentSizes{1, 1}; // source tensor, destination tensor
+  if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
+                                         preResolutionFn)))
+    return failure();
+  return parser.addTypeToList(dstType, result.types);
 }
 
 void mlir::SubTensorInsertOp::build(
@@ -3974,7 +3899,7 @@
 }
 
 /// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes`
-/// and `staticStrides` are automatically filled with source-memref-rank 
+/// and `staticStrides` are automatically filled with source-memref-rank
 /// sentinel values that encode dynamic entries.
 void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
                                     Value source, Value dest,
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -8,8 +8,6 @@
 
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
-#include "mlir/IR/StandardTypes.h"
-
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -58,3 +56,93 @@
     return failure();
   return success();
 }
+
+/// Parse a mixed list with either (1) static integer values or (2) SSA values.
+/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
+/// encodes the position of SSA values. Add the parsed SSA values to `ssa`
+/// in-order.
+///
+/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
+///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
+///   2. `ssa` is filled with "[%arg0, %arg42]".
+static ParseResult
+parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
+                              StringRef attrName, int64_t dynVal,
+                              SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
+  if (failed(parser.parseLSquare()))
+    return failure();
+  // 0-D.
+  if (succeeded(parser.parseOptionalRSquare())) {
+    result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
+    return success();
+  }
+
+  SmallVector<int64_t, 4> attrVals;
+  while (true) {
+    OpAsmParser::OperandType operand;
+    auto res = parser.parseOptionalOperand(operand);
+    if (res.hasValue() && succeeded(res.getValue())) {
+      ssa.push_back(operand);
+      attrVals.push_back(dynVal);
+    } else {
+      IntegerAttr attr;
+      if (failed(parser.parseAttribute(attr)))
+        return parser.emitError(parser.getNameLoc())
+               << "expected SSA value or integer";
+      attrVals.push_back(attr.getInt());
+    }
+
+    if (succeeded(parser.parseOptionalComma()))
+      continue;
+    if (failed(parser.parseRSquare()))
+      return failure();
+    break;
+  }
+
+  auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
+  result.addAttribute(attrName, arrayAttr);
+  return success();
+}
+
+ParseResult mlir::parseOffsetsSizesAndStrides(
+    OpAsmParser &parser,
+    OperationState &result,
+    ArrayRef<int> segmentSizes,
+    llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
+        preResolutionFn,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
+    llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
+  SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
+  auto indexType = parser.getBuilder().getIndexType();
+  if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) ||
+      parseListOfOperandsOrIntegers(
+          parser, result,
+          OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
+          ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
+      (parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) ||
+      parseListOfOperandsOrIntegers(
+          parser, result,
+          OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
+          ShapedType::kDynamicSize, sizesInfo) ||
+      (parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) ||
+      parseListOfOperandsOrIntegers(
+          parser, result,
+          OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
+          ShapedType::kDynamicStrideOrOffset, stridesInfo))
+    return failure();
+  // Leading segment sizes, then the number of offset, size and stride operands.
+  SmallVector<int, 4> segmentSizesFinal = llvm::to_vector<4>(segmentSizes);
+  segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()),
+                            static_cast<int>(sizesInfo.size()),
+                            static_cast<int>(stridesInfo.size())});
+  auto b = parser.getBuilder();
+  result.addAttribute(
+      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
+      b.getI32VectorAttr(segmentSizesFinal));
+  return failure(
+      (preResolutionFn && preResolutionFn(parser, result)) ||
+      parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
+      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
+      parser.resolveOperands(stridesInfo, indexType, result.operands));
+}