diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -201,6 +201,202 @@ let hasCanonicalizer = 1; } +// Base class for ops with static/dynamic offset, sizes and strides +// attributes/arguments. +class BaseOpWithOffsetSizesAndStrides traits = []> : + Std_Op { + code extraBaseClassDeclaration = [{ + /// Returns the number of dynamic offset operands. + int64_t getNumOffsets() { return llvm::size(offsets()); } + + /// Returns the number of dynamic size operands. + int64_t getNumSizes() { return llvm::size(sizes()); } + + /// Returns the number of dynamic stride operands. + int64_t getNumStrides() { return llvm::size(strides()); } + + /// Returns the dynamic sizes for this operation if specified. + operand_range getDynamicSizes() { return sizes(); } + + /// Returns in `staticStrides` the static value of the stride + /// operands. Returns failure() if the static value of the stride + /// operands could not be retrieved. + LogicalResult getStaticStrides(SmallVectorImpl &staticStrides) { + if (!strides().empty()) + return failure(); + staticStrides.reserve(static_strides().size()); + for (auto s : static_strides().getAsValueRange()) + staticStrides.push_back(s.getZExtValue()); + return success(); + } + + /// Return the list of Range (i.e. offset, size, stride). Each + /// Range entry contains either the dynamic value or a ConstantIndexOp + /// constructed with `b` at location `loc`. + SmallVector getOrCreateRanges(OpBuilder &b, Location loc); + + /// Return the offsets as Values.
Each Value is either the dynamic + /// value specified in the op or a ConstantIndexOp constructed + /// with `b` at location `loc` + SmallVector getOrCreateOffsets(OpBuilder &b, Location loc) { + unsigned dynamicIdx = 1; + return llvm::to_vector<4>(llvm::map_range( + static_offsets().cast(), [&](Attribute a) -> Value { + int64_t staticOffset = a.cast().getInt(); + if (ShapedType::isDynamicStrideOrOffset(staticOffset)) + return getOperand(dynamicIdx++); + else + return b.create( + loc, b.getIndexType(), b.getIndexAttr(staticOffset)); + })); + } + + /// Return the sizes as Values. Each Value is either the dynamic + /// value specified in the op or a ConstantIndexOp constructed + /// with `b` at location `loc` + SmallVector getOrCreateSizes(OpBuilder &b, Location loc) { + unsigned dynamicIdx = 1 + offsets().size(); + return llvm::to_vector<4>(llvm::map_range( + static_sizes().cast(), [&](Attribute a) -> Value { + int64_t staticSize = a.cast().getInt(); + if (ShapedType::isDynamic(staticSize)) + return getOperand(dynamicIdx++); + else + return b.create( + loc, b.getIndexType(), b.getIndexAttr(staticSize)); + })); + } + + /// Return the strides as Values. Each Value is either the dynamic + /// value specified in the op or a ConstantIndexOp constructed with + /// `b` at location `loc` + SmallVector getOrCreateStrides(OpBuilder &b, Location loc) { + unsigned dynamicIdx = 1 + offsets().size() + sizes().size(); + return llvm::to_vector<4>(llvm::map_range( + static_strides().cast(), [&](Attribute a) -> Value { + int64_t staticStride = a.cast().getInt(); + if (ShapedType::isDynamicStrideOrOffset(staticStride)) + return getOperand(dynamicIdx++); + else + return b.create( + loc, b.getIndexType(), b.getIndexAttr(staticStride)); + })); + } + + /// Return the rank of the source ShapedType. + unsigned getSourceRank() { + return source().getType().cast().getRank(); + } + + /// Return the rank of the result ShapedType.
+ unsigned getResultRank() { return getType().getRank(); } + + /// Return true if the offset `idx` is dynamic. + bool isDynamicOffset(unsigned idx) { + APInt v = *(static_offsets().getAsValueRange().begin() + idx); + return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); + } + /// Return true if the size `idx` is dynamic. + bool isDynamicSize(unsigned idx) { + APInt v = *(static_sizes().getAsValueRange().begin() + idx); + return ShapedType::isDynamic(v.getSExtValue()); + } + + /// Return true if the stride `idx` is dynamic. + bool isDynamicStride(unsigned idx) { + APInt v = *(static_strides().getAsValueRange().begin() + idx); + return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); + } + + /// Assert the offset `idx` is a static constant and return its value. + int64_t getStaticOffset(unsigned idx) { + assert(!isDynamicOffset(idx) && "expected static offset"); + APInt v = *(static_offsets().getAsValueRange().begin() + idx); + return v.getSExtValue(); + } + /// Assert the size `idx` is a static constant and return its value. + int64_t getStaticSize(unsigned idx) { + assert(!isDynamicSize(idx) && "expected static size"); + APInt v = *(static_sizes().getAsValueRange().begin() + idx); + return v.getSExtValue(); + } + /// Assert the stride `idx` is a static constant and return its value. + int64_t getStaticStride(unsigned idx) { + assert(!isDynamicStride(idx) && "expected static stride"); + APInt v = *(static_strides().getAsValueRange().begin() + idx); + return v.getSExtValue(); + } + /// Count the entries in `attr` before `idx` for which `isDynamic` holds. + unsigned getNumDynamicEntriesUpToIdx(ArrayAttr attr, + llvm::function_ref isDynamic, unsigned idx) { + return std::count_if( + attr.getValue().begin(), attr.getValue().begin() + idx, + [&](Attribute attr) { + return isDynamic(attr.cast().getInt()); + }); + } + /// Assert the offset `idx` is dynamic and return the position of the + /// corresponding operand.
+ unsigned getIndexOfDynamicOffset(unsigned idx) { + assert(isDynamicOffset(idx) && "expected dynamic offset"); + auto numDynamic = + getNumDynamicEntriesUpToIdx(static_offsets().cast(), + ShapedType::isDynamicStrideOrOffset, idx); + return 1 + numDynamic; + } + /// Assert the size `idx` is dynamic and return the position of the + /// corresponding operand. + unsigned getIndexOfDynamicSize(unsigned idx) { + assert(isDynamicSize(idx) && "expected dynamic size"); + auto numDynamic = getNumDynamicEntriesUpToIdx( + static_sizes().cast(), ShapedType::isDynamic, idx); + return 1 + offsets().size() + numDynamic; + } + /// Assert the stride `idx` is dynamic and return the position of the + /// corresponding operand. + unsigned getIndexOfDynamicStride(unsigned idx) { + assert(isDynamicStride(idx) && "expected dynamic stride"); + auto numDynamic = + getNumDynamicEntriesUpToIdx(static_strides().cast(), + ShapedType::isDynamicStrideOrOffset, idx); + return 1 + offsets().size() + sizes().size() + numDynamic; + } + + /// Assert the offset `idx` is dynamic and return its value. + Value getDynamicOffset(unsigned idx) { + return getOperand(getIndexOfDynamicOffset(idx)); + } + /// Assert the size `idx` is dynamic and return its value. + Value getDynamicSize(unsigned idx) { + return getOperand(getIndexOfDynamicSize(idx)); + } + /// Assert the stride `idx` is dynamic and return its value.
+ Value getDynamicStride(unsigned idx) { + return getOperand(getIndexOfDynamicStride(idx)); + } + + static StringRef getStaticOffsetsAttrName() { + return "static_offsets"; + } + static StringRef getStaticSizesAttrName() { + return "static_sizes"; + } + static StringRef getStaticStridesAttrName() { + return "static_strides"; + } + static ArrayRef getSpecialAttrNames() { + static SmallVector names{ + getStaticOffsetsAttrName(), + getStaticSizesAttrName(), + getStaticStridesAttrName(), + getOperandSegmentSizeAttr()}; + return names; + } + }]; +} + + //===----------------------------------------------------------------------===// // AbsFOp //===----------------------------------------------------------------------===// @@ -2020,6 +2216,51 @@ }]; } +//===----------------------------------------------------------------------===// +// MemRefReinterpretCastOp +//===----------------------------------------------------------------------===// + +def MemRefReinterpretCastOp: + BaseOpWithOffsetSizesAndStrides<"memref_reinterpret_cast", [ + NoSideEffect, ViewLikeOpInterface + ]> { + let summary = "memref reinterpret cast operation"; + let description = [{ + Modify offset, sizes and strides of an unranked/ranked memref. + + Example: + ```mlir + memref_reinterpret_cast %ranked to + offset: [0], + sizes: [%size0, 10], + strides: [1, %stride1] + : memref to memref + + memref_reinterpret_cast %unranked to + offset: [%offset], + sizes: [%size0, %size1], + strides: [%stride0, %stride1] + : memref<*xf32> to memref + ``` + }]; + + let arguments = (ins + Arg:$source, + Variadic:$offsets, + Variadic:$sizes, + Variadic:$strides, + I64ArrayAttr:$static_offsets, + I64ArrayAttr:$static_sizes, + I64ArrayAttr:$static_strides + ); + let results = (outs AnyMemRef:$result); + let extraClassDeclaration = extraBaseClassDeclaration # [{ + // The result of the op is always a ranked memref.
+ MemRefType getType() { return getResult().getType().cast(); } + Value getViewSource() { return source(); } + }]; +} + //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// @@ -2710,212 +2951,6 @@ // SubViewOp //===----------------------------------------------------------------------===// -class BaseOpWithOffsetSizesAndStrides traits = []> : - Std_Op { - let builders = [ - // Build a SubViewOp with mixed static and dynamic entries. - OpBuilder< - "Value source, ArrayRef staticOffsets, " - "ArrayRef staticSizes, ArrayRef staticStrides, " - "ValueRange offsets, ValueRange sizes, ValueRange strides, " - "ArrayRef attrs = {}">, - // Build a SubViewOp with all dynamic entries. - OpBuilder< - "Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, " - "ArrayRef attrs = {}"> - ]; - - code extraBaseClassDeclaration = [{ - /// Returns the number of dynamic offset operands. - int64_t getNumOffsets() { return llvm::size(offsets()); } - - /// Returns the number of dynamic size operands. - int64_t getNumSizes() { return llvm::size(sizes()); } - - /// Returns the number of dynamic stride operands. - int64_t getNumStrides() { return llvm::size(strides()); } - - /// Returns the dynamic sizes for this subview operation if specified. - operand_range getDynamicSizes() { return sizes(); } - - /// Returns in `staticStrides` the static value of the stride - /// operands. Returns failure() if the static value of the stride - /// operands could not be retrieved. - LogicalResult getStaticStrides(SmallVectorImpl &staticStrides) { - if (!strides().empty()) - return failure(); - staticStrides.reserve(static_strides().size()); - for (auto s : static_strides().getAsValueRange()) - staticStrides.push_back(s.getZExtValue()); - return success(); - } - - /// Return the list of Range (i.e. offset, size, stride).
Each - /// Range entry contains either the dynamic value or a ConstantIndexOp - /// constructed with `b` at location `loc`. - SmallVector getOrCreateRanges(OpBuilder &b, Location loc); - - /// Return the offsets as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed - /// with `b` at location `loc` - SmallVector getOrCreateOffsets(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1; - return llvm::to_vector<4>(llvm::map_range( - static_offsets().cast(), [&](Attribute a) -> Value { - int64_t staticOffset = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticOffset)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticOffset)); - })); - } - - /// Return the sizes as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed - /// with `b` at location `loc` - SmallVector getOrCreateSizes(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1 + offsets().size(); - return llvm::to_vector<4>(llvm::map_range( - static_sizes().cast(), [&](Attribute a) -> Value { - int64_t staticSize = a.cast().getInt(); - if (ShapedType::isDynamic(staticSize)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticSize)); - })); - } - - /// Return the strides as Values.
Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed with - /// `b` at location `loc` - SmallVector getOrCreateStrides(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1 + offsets().size() + sizes().size(); - return llvm::to_vector<4>(llvm::map_range( - static_strides().cast(), [&](Attribute a) -> Value { - int64_t staticStride = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticStride)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticStride)); - })); - } - - /// Return the rank of the source ShapedType. - unsigned getSourceRank() { - return source().getType().cast().getRank(); - } - - /// Return the rank of the result ShapedType. - unsigned getResultRank() { return getType().getRank(); } - - /// Return true if the offset `idx` is a static constant. - bool isDynamicOffset(unsigned idx) { - APInt v = *(static_offsets().getAsValueRange().begin() + idx); - return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); - } - /// Return true if the size `idx` is a static constant. - bool isDynamicSize(unsigned idx) { - APInt v = *(static_sizes().getAsValueRange().begin() + idx); - return ShapedType::isDynamic(v.getSExtValue()); - } - - /// Return true if the stride `idx` is a static constant. - bool isDynamicStride(unsigned idx) { - APInt v = *(static_strides().getAsValueRange().begin() + idx); - return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); - } - - /// Assert the offset `idx` is a static constant and return its value. - int64_t getStaticOffset(unsigned idx) { - assert(!isDynamicOffset(idx) && "expected static offset"); - APInt v = *(static_offsets().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - /// Assert the size `idx` is a static constant and return its value.
- int64_t getStaticSize(unsigned idx) { - assert(!isDynamicSize(idx) && "expected static size"); - APInt v = *(static_sizes().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - /// Assert the stride `idx` is a static constant and return its value. - int64_t getStaticStride(unsigned idx) { - assert(!isDynamicStride(idx) && "expected static stride"); - APInt v = *(static_strides().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - - unsigned getNumDynamicEntriesUpToIdx(ArrayAttr attr, - llvm::function_ref isDynamic, unsigned idx) { - return std::count_if( - attr.getValue().begin(), attr.getValue().begin() + idx, - [&](Attribute attr) { - return isDynamic(attr.cast().getInt()); - }); - } - /// Assert the offset `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicOffset(unsigned idx) { - assert(isDynamicOffset(idx) && "expected static offset"); - auto numDynamic = - getNumDynamicEntriesUpToIdx(static_offsets().cast(), - ShapedType::isDynamicStrideOrOffset, idx); - return 1 + numDynamic; - } - /// Assert the size `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicSize(unsigned idx) { - assert(isDynamicSize(idx) && "expected static size"); - auto numDynamic = getNumDynamicEntriesUpToIdx( - static_sizes().cast(), ShapedType::isDynamic, idx); - return 1 + offsets().size() + numDynamic; - } - /// Assert the stride `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicStride(unsigned idx) { - assert(isDynamicStride(idx) && "expected static stride"); - auto numDynamic = - getNumDynamicEntriesUpToIdx(static_strides().cast(), - ShapedType::isDynamicStrideOrOffset, idx); - return 1 + offsets().size() + sizes().size() + numDynamic; - } - - /// Assert the offset `idx` is dynamic and return its value.
- Value getDynamicOffset(unsigned idx) { - return getOperand(getIndexOfDynamicOffset(idx)); - } - /// Assert the size `idx` is dynamic and return its value. - Value getDynamicSize(unsigned idx) { - return getOperand(getIndexOfDynamicSize(idx)); - } - /// Assert the stride `idx` is dynamic and return its value. - Value getDynamicStride(unsigned idx) { - return getOperand(getIndexOfDynamicStride(idx)); - } - - static StringRef getStaticOffsetsAttrName() { - return "static_offsets"; - } - static StringRef getStaticSizesAttrName() { - return "static_sizes"; - } - static StringRef getStaticStridesAttrName() { - return "static_strides"; - } - static ArrayRef getSpecialAttrNames() { - static SmallVector names{ - getStaticOffsetsAttrName(), - getStaticSizesAttrName(), - getStaticStridesAttrName(), - getOperandSegmentSizeAttr()}; - return names; - } - }]; -} - def SubViewOp : BaseOpWithOffsetSizesAndStrides< "subview", [DeclareOpInterfaceMethods] > { let summary = "memref subview operation"; diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -261,6 +261,126 @@ [](APInt a, APInt b) { return a + b; }); } +//===----------------------------------------------------------------------===// +// BaseOpWithOffsetSizesAndStridesOp +//===----------------------------------------------------------------------===// + +/// Print a list with either (1) the static integer value in `arrayAttr` if +/// `isDynamic` evaluates to false or (2) the next value otherwise. +/// This allows idiomatic printing of mixed value and integer attributes in a +/// list. E.g. `[%arg0, 7, 42, %arg42]`.
+static void +printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values, + ArrayAttr arrayAttr, + llvm::function_ref isDynamic) { + p << '['; + unsigned idx = 0; + llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { + int64_t val = a.cast().getInt(); + if (isDynamic(val)) + p << values[idx++]; + else + p << val; + }); + p << ']'; +} + +/// Parse a mixed list with either (1) static integer values or (2) SSA values. +/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` +/// encodes the position of SSA values. Add the parsed SSA values to `ssa` +/// in-order. +/// +/// E.g. after parsing "[%arg0, 7, 42, %arg42]": +/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" +/// 2. `ssa` is filled with "[%arg0, %arg42]". +static ParseResult +parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, + StringRef attrName, int64_t dynVal, + SmallVectorImpl &ssa) { + if (failed(parser.parseLSquare())) + return failure(); + // 0-D. + if (succeeded(parser.parseOptionalRSquare())) { + result.addAttribute(attrName, parser.getBuilder().getArrayAttr({})); + return success(); + } + + SmallVector attrVals; + while (true) { + OpAsmParser::OperandType operand; + auto res = parser.parseOptionalOperand(operand); + if (res.hasValue() && succeeded(res.getValue())) { + ssa.push_back(operand); + attrVals.push_back(dynVal); + } else { + IntegerAttr attr; + if (failed(parser.parseAttribute(attr))) + return parser.emitError(parser.getCurrentLocation()) + << "expected SSA value or integer"; + attrVals.push_back(attr.getInt()); + } + + if (succeeded(parser.parseOptionalComma())) + continue; + if (failed(parser.parseRSquare())) + return failure(); + break; + } + + auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); + result.addAttribute(attrName, arrayAttr); + return success(); +} + +/// Verify that a particular offset/size/stride static attribute is well-formed.
+template +static LogicalResult verifyOpWithOffsetSizesAndStridesPart( + OpType op, StringRef name, unsigned expectedNumElements, StringRef attrName, + ArrayAttr attr, llvm::function_ref isDynamic, + ValueRange values) { + // Check the static entry count, then the dynamic entry/operand count. + if (attr.size() != expectedNumElements) + return op.emitError("expected ") + << expectedNumElements << " " << name << " values"; + unsigned expectedNumDynamicEntries = + llvm::count_if(attr.getValue(), [&](Attribute attr) { + return isDynamic(attr.cast().getInt()); + }); + if (values.size() != expectedNumDynamicEntries) + return op.emitError("expected ") + << expectedNumDynamicEntries << " dynamic " << name << " values"; + return success(); +} + +/// Extract int64_t values from `attr`, assumed to be an ArrayAttr of IntegerAttr. +static SmallVector extractFromI64ArrayAttr(Attribute attr) { + return llvm::to_vector<4>( + llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { + return a.cast().getInt(); + })); +} + +/// Verify static offsets/sizes/strides attributes against the source rank.
+template +static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) { + unsigned srcRank = op.getSourceRank(); + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "offset", srcRank, op.getStaticOffsetsAttrName(), + op.static_offsets(), ShapedType::isDynamicStrideOrOffset, + op.offsets()))) + return failure(); + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "size", srcRank, op.getStaticSizesAttrName(), op.static_sizes(), + ShapedType::isDynamic, op.sizes()))) + return failure(); + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "stride", srcRank, op.getStaticStridesAttrName(), + op.static_strides(), ShapedType::isDynamicStrideOrOffset, + op.strides()))) + return failure(); + return success(); +} + //===----------------------------------------------------------------------===// // AllocOp / AllocaOp //===----------------------------------------------------------------------===// @@ -2145,6 +2265,169 @@ return impl::foldCastOp(*this); } +//===----------------------------------------------------------------------===// +// MemRefReinterpretCastOp +//===----------------------------------------------------------------------===// + +/// Print of the form: +/// ``` +/// `name` ssa-name to +/// offset: `[` offset `]` `,` +/// sizes: `[` size-list `]` `,` +/// strides: `[` stride-list `]` +/// `:` any-memref-type to strided-memref-type +/// ``` +static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) { + int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op.getOperationName().drop_front(stdDotLen) << " " << op.source() + << " to offset: "; + printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), + ShapedType::isDynamicStrideOrOffset); + p << ", sizes: "; + printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), + ShapedType::isDynamic); + p << ", strides: "; + printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), + ShapedType::isDynamicStrideOrOffset); + p.printOptionalAttrDict(
+ op.getAttrs(), + /*elidedAttrs=*/{MemRefReinterpretCastOp::getOperandSegmentSizeAttr(), + MemRefReinterpretCastOp::getStaticOffsetsAttrName(), + MemRefReinterpretCastOp::getStaticSizesAttrName(), + MemRefReinterpretCastOp::getStaticStridesAttrName()}); + p << ": " << op.source().getType() << " to " << op.getType(); +} + +/// Parse of the form: +/// ``` +/// `name` ssa-name to +/// offset: `[` offset `]` `,` +/// sizes: `[` size-list `]` `,` +/// strides: `[` stride-list `]` +/// `:` any-memref-type to strided-memref-type +/// ``` +static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser, + OperationState &result) { + // Parse `operand`. + OpAsmParser::OperandType operand; + if (parser.parseOperand(operand)) + return failure(); + + // Parse `to` and `offset`. + SmallVector offset; + if (parser.parseKeyword("to") || parser.parseKeyword("offset") || + parser.parseColon() || + parseListOfOperandsOrIntegers( + parser, result, MemRefReinterpretCastOp::getStaticOffsetsAttrName(), + ShapedType::kDynamicStrideOrOffset, offset) || + parser.parseComma()) + return failure(); + + // Parse `sizes`. + SmallVector sizes; + if (parser.parseKeyword("sizes") || parser.parseColon() || + parseListOfOperandsOrIntegers( + parser, result, MemRefReinterpretCastOp::getStaticSizesAttrName(), + ShapedType::kDynamicSize, sizes) || + parser.parseComma()) + return failure(); + + // Parse `strides`. + SmallVector strides; + if (parser.parseKeyword("strides") || parser.parseColon() || + parseListOfOperandsOrIntegers( + parser, result, MemRefReinterpretCastOp::getStaticStridesAttrName(), + ShapedType::kDynamicStrideOrOffset, strides)) + return failure(); + + // Handle segment sizes. + auto b = parser.getBuilder(); + SmallVector segmentSizes = {1, static_cast(offset.size()), + static_cast(sizes.size()), + static_cast(strides.size())}; + result.addAttribute(MemRefReinterpretCastOp::getOperandSegmentSizeAttr(), + + b.getI32VectorAttr(segmentSizes)); + + // Parse types and resolve.
+ Type indexType = b.getIndexType(); + Type operandType, resultType; + return failure( + (parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(operandType) || parser.parseKeyword("to") || + parser.parseType(resultType) || + parser.resolveOperand(operand, operandType, result.operands) || + parser.resolveOperands(offset, indexType, result.operands) || + parser.resolveOperands(sizes, indexType, result.operands) || + parser.resolveOperands(strides, indexType, result.operands) || + parser.addTypeToList(resultType, result.types))); +} + +static LogicalResult verify(MemRefReinterpretCastOp op) { + // Source and result must agree on memory space and element type. + auto srcType = op.source().getType().cast(); + auto resultType = op.getType().cast(); + if (srcType.getMemorySpace() != resultType.getMemorySpace()) + return op.emitError("different memory spaces specified for source type ") + << srcType << " and result memref type " << resultType; + if (srcType.getElementType() != resultType.getElementType()) + return op.emitError("different element types specified for source type ") + << srcType << " and result memref type " << resultType; + + // Verify that dynamic and static offset/sizes/strides arguments/attributes + // are consistent. + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "offset", 1, op.getStaticOffsetsAttrName(), op.static_offsets(), + ShapedType::isDynamicStrideOrOffset, op.offsets()))) + return failure(); + unsigned resultRank = op.getResultRank(); + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "size", resultRank, op.getStaticSizesAttrName(), + op.static_sizes(), ShapedType::isDynamic, op.sizes()))) + return failure(); + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "stride", resultRank, op.getStaticStridesAttrName(), + op.static_strides(), ShapedType::isDynamicStrideOrOffset, + op.strides()))) + return failure(); + + // Extract result offset and strides.
+ int64_t resultOffset; + SmallVector resultStrides; + if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) + return op.emitError("expected result type to have a strided layout"); + + // Match offset in result memref type and in static_offsets attribute. + int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front(); + if (resultOffset != expectedOffset) + return op.emitError("expected result type with offset = ") + << expectedOffset << " instead of " << resultOffset; + + // Match sizes in result memref type and in static_sizes attribute. + for (auto &en : + llvm::enumerate(llvm::zip(resultType.getShape(), + extractFromI64ArrayAttr(op.static_sizes())))) { + int64_t resultSize = std::get<0>(en.value()); + int64_t expectedSize = std::get<1>(en.value()); + if (resultSize != expectedSize) + return op.emitError("expected result type with size = ") + << expectedSize << " instead of " << resultSize + << " in dim = " << en.index(); + } + + // Match strides in result memref type and in static_strides attribute. + for (auto &en : llvm::enumerate(llvm::zip( + resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { + int64_t resultStride = std::get<0>(en.value()); + int64_t expectedStride = std::get<1>(en.value()); + if (resultStride != expectedStride) + return op.emitError("expected result type with stride = ") + << expectedStride << " instead of " << resultStride + << " in dim = " << en.index(); + } + return success(); +} + //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// @@ -2542,75 +2825,6 @@ // SubViewOp //===----------------------------------------------------------------------===// -/// Print a list with either (1) the static integer value in `arrayAttr` if -/// `isDynamic` evaluates to false or (2) the next value otherwise. -/// This allows idiomatic printing of mixed value and integer attributes in a -/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-static void printSubViewListOfOperandsOrIntegers( - OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr, - llvm::function_ref isDynamic) { - p << "["; - unsigned idx = 0; - llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { - int64_t val = a.cast().getInt(); - if (isDynamic(val)) - p << values[idx++]; - else - p << val; - }); - p << "] "; -} - -/// Parse a mixed list with either (1) static integer values or (2) SSA values. -/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` -/// encode the position of SSA values. Add the parsed SSA values to `ssa` -/// in-order. -// -/// E.g. after parsing "[%arg0, 7, 42, %arg42]": -/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" -/// 2. `ssa` is filled with "[%arg0, %arg1]". -static ParseResult -parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, - StringRef attrName, int64_t dynVal, - SmallVectorImpl &ssa) { - if (failed(parser.parseLSquare())) - return failure(); - // 0-D. - if (succeeded(parser.parseOptionalRSquare())) { - result.addAttribute(attrName, parser.getBuilder().getArrayAttr({})); - return success(); - } - - SmallVector attrVals; - while (true) { - OpAsmParser::OperandType operand; - auto res = parser.parseOptionalOperand(operand); - if (res.hasValue() && succeeded(res.getValue())) { - ssa.push_back(operand); - attrVals.push_back(dynVal); - } else { - Attribute attr; - NamedAttrList placeholder; - if (failed(parser.parseAttribute(attr, "_", placeholder)) || - !attr.isa()) - return parser.emitError(parser.getNameLoc()) - << "expected SSA value or integer"; - attrVals.push_back(attr.cast().getInt()); - } - - if (succeeded(parser.parseOptionalComma())) - continue; - if (failed(parser.parseRSquare())) - return failure(); - else - break; - } - - auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); - result.addAttribute(attrName, arrayAttr); - return success(); -} - namespace { /// Helpers to write more idiomatic operations.
namespace saturated_arith { @@ -2698,12 +2912,15 @@ p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' '; p << op.source(); printExtraOperands(p, op); - printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset); - printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), - ShapedType::isDynamic); - printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), - ShapedType::isDynamicStrideOrOffset); + printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), + ShapedType::isDynamicStrideOrOffset); + p << ' '; + printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), + ShapedType::isDynamic); + p << ' '; + printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), + ShapedType::isDynamicStrideOrOffset); + p << ' '; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{OpType::getSpecialAttrNames()}); p << " : " << op.getSourceType() << " " << resultTypeKeyword << " " @@ -2843,33 +3060,6 @@ /// For ViewLikeOpInterface. Value SubViewOp::getViewSource() { return source(); } -/// Verify that a particular offset/size/stride static attribute is well-formed. -template -static LogicalResult verifyOpWithOffsetSizesAndStridesPart( - OpType op, StringRef name, StringRef attrName, ArrayAttr attr, - llvm::function_ref isDynamic, ValueRange values) { - /// Check static and dynamic offsets/sizes/strides breakdown. - if (attr.size() != op.getSourceRank()) - return op.emitError("expected ") - << op.getSourceRank() << " " << name << " values"; - unsigned expectedNumDynamicEntries = - llvm::count_if(attr.getValue(), [&](Attribute attr) { - return isDynamic(attr.cast().getInt()); - }); - if (values.size() != expectedNumDynamicEntries) - return op.emitError("expected ") - << expectedNumDynamicEntries << " dynamic " << name << " values"; - return success(); -} - -/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr. 
-static SmallVector extractFromI64ArrayAttr(Attribute attr) { - return llvm::to_vector<4>( - llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { - return a.cast().getInt(); - })); -} - llvm::Optional> mlir::computeRankReductionMask(ArrayRef originalShape, ArrayRef reducedShape) { @@ -3005,24 +3195,6 @@ } } -template -static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) { - // Verify static attributes offsets/sizes/strides. - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset, op.offsets()))) - return failure(); - - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "size", op.getStaticSizesAttrName(), op.static_sizes(), - ShapedType::isDynamic, op.sizes()))) - return failure(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "stride", op.getStaticStridesAttrName(), op.static_strides(), - ShapedType::isDynamicStrideOrOffset, op.strides()))) - return failure(); - return success(); -} /// Verifier for SubViewOp. 
static LogicalResult verify(SubViewOp op) { diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -102,3 +102,83 @@ // expected-error @+1 {{output type 'memref (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref (d0 * s1 + s0 + d1)>>'}} transpose %v (i, j) -> (j, i) : memref(off + M * i + j)>> to memref(off + M * i + j)>> } + +// ----- + +// CHECK-LABEL: func @memref_reinterpret_cast_too_many_offsets +func @memref_reinterpret_cast_too_many_offsets(%in: memref) { + // expected-error @+1 {{expected 1 offset values}} + %out = memref_reinterpret_cast %in to + offset: [0, 0], sizes: [10, 10], strides: [10, 1] + : memref to memref<10x10xf32, offset: 0, strides: [10, 1]> + return +} + +// ----- + +// CHECK-LABEL: func @memref_reinterpret_cast_incompatible_element_types +func @memref_reinterpret_cast_incompatible_element_types(%in: memref<*xf32>) { + // expected-error @+1 {{different element types specified}} + %out = memref_reinterpret_cast %in to + offset: [0], sizes: [10], strides: [1] + : memref<*xf32> to memref<10xi32, offset: 0, strides: [1]> + return +} + +// ----- + +// CHECK-LABEL: func @memref_reinterpret_cast_incompatible_memory_space +func @memref_reinterpret_cast_incompatible_memory_space(%in: memref<*xf32>) { + // Element types match so that only the memory-space mismatch is diagnosed. + // expected-error @+1 {{different memory spaces specified}} + %out = memref_reinterpret_cast %in to + offset: [0], sizes: [10], strides: [1] + : memref<*xf32> to memref<10xf32, offset: 0, strides: [1], 2> + return +} + +// ----- + +// CHECK-LABEL: func @memref_reinterpret_cast_offset_mismatch +func @memref_reinterpret_cast_offset_mismatch(%in: memref) { + // expected-error @+1 {{expected result type with offset = 0 instead of 1}} + %out = memref_reinterpret_cast %in to + offset: [1], sizes: [10], strides: [1] + : memref to memref<10xf32> + return +} + +// ----- + +// CHECK-LABEL: func
@memref_reinterpret_cast_size_mismatch +func @memref_reinterpret_cast_size_mismatch(%in: memref<*xf32>) { + // expected-error @+1 {{expected result type with size = 10 instead of 1 in dim = 0}} + %out = memref_reinterpret_cast %in to + offset: [0], sizes: [10], strides: [1] + : memref<*xf32> to memref<1xf32, offset: 0, strides: [1]> + return +} + +// ----- + +// CHECK-LABEL: func @memref_reinterpret_cast_stride_mismatch +func @memref_reinterpret_cast_stride_mismatch(%in: memref) { + // expected-error @+1 {{expected result type with stride = 2 instead of 1 in dim = 0}} + %out = memref_reinterpret_cast %in to + offset: [0], sizes: [10], strides: [2] + : memref to memref<10xf32> + return +} + +// ----- + +// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_size_mismatch +func @memref_reinterpret_cast_dynamic_size_mismatch(%in: memref) { + %c0 = constant 0 : index + %c10 = constant 10 : index + // expected-error @+1 {{expected result type with size = 10 instead of -1 in dim = 0}} + %out = memref_reinterpret_cast %in to + offset: [%c0], sizes: [10, %c10], strides: [%c10, 1] + : memref to memref + return +} diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir --- a/mlir/test/Dialect/Standard/ops.mlir +++ b/mlir/test/Dialect/Standard/ops.mlir @@ -54,3 +54,14 @@ %result = atan2 %arg0, %arg1 : f32 return %result : f32 } + +// CHECK-LABEL: func @memref_reinterpret_cast +func @memref_reinterpret_cast(%in: memref) + -> memref<10x?xf32, offset: ?, strides: [?, 1]> { + %c0 = constant 0 : index + %c10 = constant 10 : index + %out = memref_reinterpret_cast %in to + offset: [%c0], sizes: [10, %c10], strides: [%c10, 1] + : memref to memref<10x?xf32, offset: ?, strides: [?, 1]> + return %out : memref<10x?xf32, offset: ?, strides: [?, 1]> +}