diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2217,6 +2217,50 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+def MemRefReinterpretCastOp:
+    BaseOpWithOffsetSizesAndStrides<"memref_reinterpret_cast", [
+      NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>
+    ]> {
+  let summary = "memref reinterpret cast operation";
+  let description = [{
+    Modify offset, sizes and strides of an unranked/ranked memref.
+
+    Example:
+    ```mlir
+    memref_reinterpret_cast %ranked to
+      offset: [0],
+      sizes: [%size0, 10],
+      strides: [1, %stride1]
+    : memref<?x?xf32> to memref<?x10xf32, offset: 0, strides: [1, ?]>
+
+    memref_reinterpret_cast %unranked to
+      offset: [%offset],
+      sizes: [%size0, %size1],
+      strides: [%stride0, %stride1]
+    : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<AnyRankedOrUnrankedMemRef, "", []>:$source,
+    Variadic<Index>:$offsets,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides,
+    I64ArrayAttr:$static_offsets,
+    I64ArrayAttr:$static_sizes,
+    I64ArrayAttr:$static_strides
+  );
+  let results = (outs AnyMemRef:$result);
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    // The result of the op is always a ranked memref.
+    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -261,6 +261,128 @@
                                         [](APInt a, APInt b) { return a + b; });
 }
 
+//===----------------------------------------------------------------------===//
+// BaseOpWithOffsetSizesAndStridesOp
+//===----------------------------------------------------------------------===//
+
+/// Print a list with either (1) the static integer value in `arrayAttr` if
+/// `isDynamic` evaluates to false or (2) the next value otherwise.
+/// This allows idiomatic printing of mixed value and integer attributes in a
+/// list. E.g. `[%arg0, 7, 42, %arg42]`.
+static void
+printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
+                              ArrayAttr arrayAttr,
+                              llvm::function_ref<bool(int64_t)> isDynamic) {
+  p << '[';
+  unsigned idx = 0;
+  llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
+    int64_t val = a.cast<IntegerAttr>().getInt();
+    if (isDynamic(val))
+      p << values[idx++];
+    else
+      p << val;
+  });
+  p << ']';
+}
+
+/// Parse a mixed list with either (1) static integer values or (2) SSA values.
+/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
+/// encodes the position of SSA values. Add the parsed SSA values to `ssa`
+/// in-order.
+///
+/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
+///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
+///   2. `ssa` is filled with "[%arg0, %arg42]".
+static ParseResult
+parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
+                              StringRef attrName, int64_t dynVal,
+                              SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
+  if (failed(parser.parseLSquare()))
+    return failure();
+  // 0-D.
+  if (succeeded(parser.parseOptionalRSquare())) {
+    result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
+    return success();
+  }
+
+  SmallVector<int64_t, 4> attrVals;
+  while (true) {
+    OpAsmParser::OperandType operand;
+    auto res = parser.parseOptionalOperand(operand);
+    if (res.hasValue() && succeeded(res.getValue())) {
+      ssa.push_back(operand);
+      attrVals.push_back(dynVal);
+    } else {
+      Attribute attr;
+      NamedAttrList placeholder;
+      if (failed(parser.parseAttribute(attr, "_", placeholder)) ||
+          !attr.isa<IntegerAttr>())
+        return parser.emitError(parser.getNameLoc())
+               << "expected SSA value or integer";
+      attrVals.push_back(attr.cast<IntegerAttr>().getInt());
+    }
+
+    if (succeeded(parser.parseOptionalComma()))
+      continue;
+    if (failed(parser.parseRSquare()))
+      return failure();
+    break;
+  }
+
+  auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
+  result.addAttribute(attrName, arrayAttr);
+  return success();
+}
+
+/// Verify that a particular offset/size/stride static attribute is well-formed.
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
+    OpType op, StringRef name, unsigned expectedNumElements, StringRef attrName,
+    ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
+    ValueRange values) {
+  // Check static and dynamic offsets/sizes/strides breakdown.
+  if (attr.size() != expectedNumElements)
+    return op.emitError("expected ")
+           << expectedNumElements << " " << name << " values";
+  unsigned expectedNumDynamicEntries =
+      llvm::count_if(attr.getValue(), [&](Attribute attr) {
+        return isDynamic(attr.cast<IntegerAttr>().getInt());
+      });
+  if (values.size() != expectedNumDynamicEntries)
+    return op.emitError("expected ")
+           << expectedNumDynamicEntries << " dynamic " << name << " values";
+  return success();
+}
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+        return a.cast<IntegerAttr>().getInt();
+      }));
+}
+
+/// Verify static attributes offsets/sizes/strides.
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
+  unsigned srcRank = op.getSourceRank();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "offset", srcRank, op.getStaticOffsetsAttrName(),
+          op.static_offsets(), ShapedType::isDynamicStrideOrOffset,
+          op.offsets())))
+    return failure();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "size", srcRank, op.getStaticSizesAttrName(), op.static_sizes(),
+          ShapedType::isDynamic, op.sizes())))
+    return failure();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "stride", srcRank, op.getStaticStridesAttrName(),
+          op.static_strides(), ShapedType::isDynamicStrideOrOffset,
+          op.strides())))
+    return failure();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
@@ -2145,6 +2267,173 @@
   return impl::foldCastOp(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+/// Print of the form:
+/// ```
+///   `name` ssa-name to
+///       offset: `[` offset `]`
+///       sizes: `[` size-list `]`
+///       strides: `[` stride-list `]`
+///   `:` any-memref-type to strided-memref-type
+/// ```
+static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) {
+  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+  p << op.getOperationName().drop_front(stdDotLen) << " " << op.source()
+    << " to offset: ";
+  printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
+                                ShapedType::isDynamicStrideOrOffset);
+  p << ", sizes: ";
+  printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+                                ShapedType::isDynamic);
+  p << ", strides: ";
+  printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
+                                ShapedType::isDynamicStrideOrOffset);
+  p.printOptionalAttrDict(
+      op.getAttrs(),
+      /*elidedAttrs=*/{MemRefReinterpretCastOp::getOperandSegmentSizeAttr(),
+                       MemRefReinterpretCastOp::getStaticOffsetsAttrName(),
+                       MemRefReinterpretCastOp::getStaticSizesAttrName(),
+                       MemRefReinterpretCastOp::getStaticStridesAttrName()});
+  p << " : " << op.source().getType() << " to " << op.getType();
+}
+
+/// Parse of the form:
+/// ```
+///   `name` ssa-name to
+///       offset: `[` offset `]`
+///       sizes: `[` size-list `]`
+///       strides: `[` stride-list `]`
+///   `:` any-memref-type to strided-memref-type
+/// ```
+static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
+                                                OperationState &result) {
+  // Parse `operand`.
+  OpAsmParser::OperandType operand;
+  if (parser.parseOperand(operand))
+    return failure();
+
+  // Parse offset.
+  SmallVector<OpAsmParser::OperandType, 1> offset;
+  if (parser.parseKeyword("to") || parser.parseKeyword("offset") ||
+      parser.parseColon() ||
+      parseListOfOperandsOrIntegers(
+          parser, result, MemRefReinterpretCastOp::getStaticOffsetsAttrName(),
+          ShapedType::kDynamicStrideOrOffset, offset) ||
+      parser.parseComma())
+    return failure();
+
+  // Parse `sizes`.
+  SmallVector<OpAsmParser::OperandType, 4> sizes;
+  if (parser.parseKeyword("sizes") || parser.parseColon() ||
+      parseListOfOperandsOrIntegers(
+          parser, result, MemRefReinterpretCastOp::getStaticSizesAttrName(),
+          ShapedType::kDynamicSize, sizes) ||
+      parser.parseComma())
+    return failure();
+
+  // Parse `strides`.
+  SmallVector<OpAsmParser::OperandType, 4> strides;
+  if (parser.parseKeyword("strides") || parser.parseColon() ||
+      parseListOfOperandsOrIntegers(
+          parser, result, MemRefReinterpretCastOp::getStaticStridesAttrName(),
+          ShapedType::kDynamicStrideOrOffset, strides))
+    return failure();
+
+  // Handle segment sizes.
+  auto b = parser.getBuilder();
+  SmallVector<int, 4> segmentSizes = {1, static_cast<int>(offset.size()),
+                                      static_cast<int>(sizes.size()),
+                                      static_cast<int>(strides.size())};
+  result.addAttribute(MemRefReinterpretCastOp::getOperandSegmentSizeAttr(),
+                      b.getI32VectorAttr(segmentSizes));
+
+  // Parse types and resolve.
+  Type indexType = b.getIndexType();
+  Type operandType, resultType;
+  return failure(
+      (parser.parseOptionalAttrDict(result.attributes) ||
+       parser.parseColonType(operandType) || parser.parseKeyword("to") ||
+       parser.parseType(resultType) ||
+       parser.resolveOperand(operand, operandType, result.operands) ||
+       parser.resolveOperands(offset, indexType, result.operands) ||
+       parser.resolveOperands(sizes, indexType, result.operands) ||
+       parser.resolveOperands(strides, indexType, result.operands) ||
+       parser.addTypeToList(resultType, result.types)));
+}
+
+Value MemRefReinterpretCastOp::getViewSource() { return source(); }
+
+static LogicalResult verify(MemRefReinterpretCastOp op) {
+  // The source and result memrefs should be in the same memory space.
+  auto srcType = op.source().getType();
+  MemRefType resultType = op.getType();
+  if (srcType.cast<BaseMemRefType>().getMemorySpace() !=
+      resultType.getMemorySpace())
+    return op.emitError("different memory spaces specified for source type ")
+           << srcType << " and result memref type " << resultType;
+  if (srcType.cast<BaseMemRefType>().getElementType() !=
+      resultType.getElementType())
+    return op.emitError("different element types specified for source type ")
+           << srcType << " and result memref type " << resultType;
+
+  // Verify that dynamic and static offset/sizes/strides arguments/attributes
+  // are consistent.
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "offset", 1, op.getStaticOffsetsAttrName(), op.static_offsets(),
+          ShapedType::isDynamicStrideOrOffset, op.offsets())))
+    return failure();
+  unsigned resultRank = op.getResultRank();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "size", resultRank, op.getStaticSizesAttrName(),
+          op.static_sizes(), ShapedType::isDynamic, op.sizes())))
+    return failure();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "stride", resultRank, op.getStaticStridesAttrName(),
+          op.static_strides(), ShapedType::isDynamicStrideOrOffset,
+          op.strides())))
+    return failure();
+
+  // Extract result offset and strides.
+  int64_t resultOffset;
+  SmallVector<int64_t, 4> resultStrides;
+  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+    return op.emitError("expected result type to have strided layout but ")
+           << "found " << resultType;
+
+  // Match offset in result memref type and in static_offsets attribute.
+  int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
+  if (resultOffset != expectedOffset)
+    return op.emitError("expected result type with offset = ")
+           << expectedOffset << " instead of " << resultOffset;
+
+  // Match sizes in result memref type and in static_sizes attribute.
+  for (auto &en :
+       llvm::enumerate(llvm::zip(resultType.getShape(),
+                                 extractFromI64ArrayAttr(op.static_sizes())))) {
+    int64_t resultSize = std::get<0>(en.value());
+    int64_t expectedSize = std::get<1>(en.value());
+    if (resultSize != expectedSize)
+      return op.emitError("expected result type with size = ")
+             << expectedSize << " instead of " << resultSize
+             << " in dim = " << en.index();
+  }
+
+  // Match strides in result memref type and in static_strides attribute.
+  for (auto &en : llvm::enumerate(llvm::zip(
+           resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
+    int64_t resultStride = std::get<0>(en.value());
+    int64_t expectedStride = std::get<1>(en.value());
+    if (resultStride != expectedStride)
+      return op.emitError("expected result type with stride = ")
+             << expectedStride << " instead of " << resultStride
+             << " in dim = " << en.index();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
@@ -2542,75 +2831,6 @@
 // SubViewOp
 //===----------------------------------------------------------------------===//
 
-/// Print a list with either (1) the static integer value in `arrayAttr` if
-/// `isDynamic` evaluates to false or (2) the next value otherwise.
-/// This allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-static void printSubViewListOfOperandsOrIntegers(
-    OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
-    llvm::function_ref<bool(int64_t)> isDynamic) {
-  p << "[";
-  unsigned idx = 0;
-  llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
-    int64_t val = a.cast<IntegerAttr>().getInt();
-    if (isDynamic(val))
-      p << values[idx++];
-    else
-      p << val;
-  });
-  p << "] ";
-}
-
-/// Parse a mixed list with either (1) static integer values or (2) SSA values.
-/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
-/// encode the position of SSA values. Add the parsed SSA values to `ssa`
-/// in-order.
-//
-/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
-///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
-///   2. `ssa` is filled with "[%arg0, %arg1]".
-static ParseResult
-parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
-                              StringRef attrName, int64_t dynVal,
-                              SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
-  if (failed(parser.parseLSquare()))
-    return failure();
-  // 0-D.
-  if (succeeded(parser.parseOptionalRSquare())) {
-    result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
-    return success();
-  }
-
-  SmallVector<int64_t, 4> attrVals;
-  while (true) {
-    OpAsmParser::OperandType operand;
-    auto res = parser.parseOptionalOperand(operand);
-    if (res.hasValue() && succeeded(res.getValue())) {
-      ssa.push_back(operand);
-      attrVals.push_back(dynVal);
-    } else {
-      Attribute attr;
-      NamedAttrList placeholder;
-      if (failed(parser.parseAttribute(attr, "_", placeholder)) ||
-          !attr.isa<IntegerAttr>())
-        return parser.emitError(parser.getNameLoc())
-               << "expected SSA value or integer";
-      attrVals.push_back(attr.cast<IntegerAttr>().getInt());
-    }
-
-    if (succeeded(parser.parseOptionalComma()))
-      continue;
-    if (failed(parser.parseRSquare()))
-      return failure();
-    else
-      break;
-  }
-
-  auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
-  result.addAttribute(attrName, arrayAttr);
-  return success();
-}
-
 namespace {
 /// Helpers to write more idiomatic operations.
 namespace saturated_arith {
@@ -2698,12 +2918,15 @@
   p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
   p << op.source();
   printExtraOperands(p, op);
-  printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
-                                       ShapedType::isDynamicStrideOrOffset);
-  printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
-                                       ShapedType::isDynamic);
-  printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
-                                       ShapedType::isDynamicStrideOrOffset);
+  printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
+                                ShapedType::isDynamicStrideOrOffset);
+  p << ' ';
+  printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+                                ShapedType::isDynamic);
+  p << ' ';
+  printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
+                                ShapedType::isDynamicStrideOrOffset);
+  p << ' ';
   p.printOptionalAttrDict(op.getAttrs(),
                           /*elidedAttrs=*/{OpType::getSpecialAttrNames()});
   p << " : " << op.getSourceType() << " " << resultTypeKeyword << " "
@@ -2843,33 +3066,6 @@
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return source(); }
 
-/// Verify that a particular offset/size/stride static attribute is well-formed.
-template <typename OpType>
-static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
-    OpType op, StringRef name, StringRef attrName, ArrayAttr attr,
-    llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
-  /// Check static and dynamic offsets/sizes/strides breakdown.
-  if (attr.size() != op.getSourceRank())
-    return op.emitError("expected ")
-           << op.getSourceRank() << " " << name << " values";
-  unsigned expectedNumDynamicEntries =
-      llvm::count_if(attr.getValue(), [&](Attribute attr) {
-        return isDynamic(attr.cast<IntegerAttr>().getInt());
-      });
-  if (values.size() != expectedNumDynamicEntries)
-    return op.emitError("expected ")
-           << expectedNumDynamicEntries << " dynamic " << name << " values";
-  return success();
-}
-
-/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
-        return a.cast<IntegerAttr>().getInt();
-      }));
-}
-
 llvm::Optional<SmallVector<bool, 4>>
 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                                ArrayRef<int64_t> reducedShape) {
@@ -3006,24 +3202,6 @@
   llvm_unreachable("unexpected subview verification result");
 }
 
-template <typename OpType>
-static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
-  // Verify static attributes offsets/sizes/strides.
-  if (failed(verifyOpWithOffsetSizesAndStridesPart(
-          op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
-          ShapedType::isDynamicStrideOrOffset, op.offsets())))
-    return failure();
-
-  if (failed(verifyOpWithOffsetSizesAndStridesPart(
-          op, "size", op.getStaticSizesAttrName(), op.static_sizes(),
-          ShapedType::isDynamic, op.sizes())))
-    return failure();
-  if (failed(verifyOpWithOffsetSizesAndStridesPart(
-          op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
-          ShapedType::isDynamicStrideOrOffset, op.strides())))
-    return failure();
-  return success();
-}
 
 /// Verifier for SubViewOp.
 static LogicalResult verify(SubViewOp op) {
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -102,3 +102,82 @@
   // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
   transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_too_many_offsets
+func @memref_reinterpret_cast_too_many_offsets(%in: memref<?xf32>) {
+  // expected-error @+1 {{expected 1 offset values}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0, 0], sizes: [10, 10], strides: [10, 1]
+           : memref<?xf32> to memref<10x10xf32, offset: 0, strides: [10, 1]>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_incompatible_element_types
+func @memref_reinterpret_cast_incompatible_element_types(%in: memref<*xf32>) {
+  // expected-error @+1 {{different element types specified}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0], sizes: [10], strides: [1]
+         : memref<*xf32> to memref<10xi32, offset: 0, strides: [1]>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_incompatible_memory_space
+func @memref_reinterpret_cast_incompatible_memory_space(%in: memref<*xf32>) {
+  // expected-error @+1 {{different memory spaces specified}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0], sizes: [10], strides: [1]
+         : memref<*xf32> to memref<10xf32, offset: 0, strides: [1], 2>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_offset_mismatch
+func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
+  // expected-error @+1 {{expected result type with offset = 1 instead of 0}}
+  %out = memref_reinterpret_cast %in to
+           offset: [1], sizes: [10], strides: [1]
+         : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_size_mismatch
+func @memref_reinterpret_cast_size_mismatch(%in: memref<*xf32>) {
+  // expected-error @+1 {{expected result type with size = 10 instead of 1 in dim = 0}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0], sizes: [10], strides: [1]
+         : memref<*xf32> to memref<1xf32, offset: 0, strides: [1]>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_stride_mismatch
+func @memref_reinterpret_cast_stride_mismatch(%in: memref<?xf32>) {
+  // expected-error @+1 {{expected result type with stride = 2 instead of 1 in dim = 0}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0], sizes: [10], strides: [2]
+         : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_size_mismatch
+func @memref_reinterpret_cast_dynamic_size_mismatch(%in: memref<?xf32>) {
+  %c0 = constant 0 : index
+  %c10 = constant 10 : index
+  // expected-error @+1 {{expected result type with size = 10 instead of -1 in dim = 0}}
+  %out = memref_reinterpret_cast %in to
+           offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
+           : memref<?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+  return
+}
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -54,3 +54,14 @@
   %result = atan2 %arg0, %arg1 : f32
   return %result : f32
 }
+
+// CHECK-LABEL: func @memref_reinterpret_cast
+func @memref_reinterpret_cast(%in: memref<?xf32>)
+    -> memref<10x?xf32, offset: ?, strides: [?, 1]> {
+  %c0 = constant 0 : index
+  %c10 = constant 10 : index
+  %out = memref_reinterpret_cast %in to
+           offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
+           : memref<?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]>
+  return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
+}