diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -3045,7 +3045,7 @@ let extraClassDeclaration = extraBaseClassDeclaration # [{ /// Returns the type of the base memref operand. - MemRefType getSourceMemRefType() { + MemRefType getSourceType() { return source().getType().cast(); } @@ -3101,9 +3101,22 @@ ); let results = (outs AnyRankedTensor:$result); + let builders = [ + // Build a SubTensorOp with mixed static and dynamic entries. + OpBuilder< + "Value source, ArrayRef staticOffsets, " + "ArrayRef staticSizes, ArrayRef staticStrides, " + "ValueRange offsets, ValueRange sizes, ValueRange strides, " + "ArrayRef attrs = {}">, + // Build a SubTensorOp with all dynamic entries. + OpBuilder< + "Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, " + "ArrayRef attrs = {}"> + ]; + let extraClassDeclaration = extraBaseClassDeclaration # [{ /// Returns the type of the base tensor operand. - RankedTensorType getSourceRankedTensorType() { + RankedTensorType getSourceType() { return source().getType().cast(); } @@ -3124,6 +3137,77 @@ let hasCanonicalizer = 1; } +//===----------------------------------------------------------------------===// +// SubTensorInsertOp +//===----------------------------------------------------------------------===// + +def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<"subtensor_insert"> { + let summary = "subtensor_insert operation"; + let description = [{ + The "subtensor_insert" operation inserts a tensor `source` into another + tensor `dest` as specified by the operation's offsets, sizes and strides + arguments. + + It returns a copy of `dest` with the proper subtensor updated with the value + of `source`. + + The subtensor_insert operation supports the following arguments: + + * source: the tensor that is inserted.
+ * dest: the tensor into which the source tensor is inserted. + * offsets: tensor-rank number of dynamic offsets or static integer + attributes into the `dest` tensor at which the `source` tensor is + inserted. + * sizes: tensor-rank number of dynamic sizes or static integer attributes + which specify the sizes of the updated subtensor of `dest`. + * strides: tensor-rank number of dynamic strides or static integer + attributes which specify the strides into the `dest` tensor in each + dimension. + + After buffer-allocation, the "subtensor_insert" op is expected to be erased + and replaced by an in-place buffer update. + }]; + + let arguments = (ins + AnyRankedTensor:$source, + AnyRankedTensor:$dest, + Variadic:$offsets, + Variadic:$sizes, + Variadic:$strides, + I64ArrayAttr:$static_offsets, + I64ArrayAttr:$static_sizes, + I64ArrayAttr:$static_strides + ); + let results = (outs AnyRankedTensor:$result); + + let builders = [ + // Build a SubTensorInsertOp with mixed static and dynamic entries. + OpBuilder< + "Value source, Value dest, ArrayRef staticOffsets, " + "ArrayRef staticSizes, ArrayRef staticStrides, " + "ValueRange offsets, ValueRange sizes, ValueRange strides, " + "ArrayRef attrs = {}">, + // Build a SubTensorInsertOp with all dynamic entries. + OpBuilder< + "Value source, Value dest, ValueRange offsets, ValueRange sizes, " + "ValueRange strides, ArrayRef attrs = {}"> + ]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + /// Returns the type of the inserted `source` tensor operand. + RankedTensorType getSourceType() { + return source().getType().cast(); + } + + /// The result of a subtensor_insert is always a tensor.
+ RankedTensorType getType() { + return getResult().getType().cast(); + } + }]; + + // let hasCanonicalizer = 1; +} + //===----------------------------------------------------------------------===// // TanhOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -23,6 +23,7 @@ #include "mlir/IR/Value.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" @@ -2639,10 +2640,15 @@ /// `:` strided-memref-type `to` strided-memref-type /// ``` template -static void printOpWithOffsetsSizesAndStrides(OpAsmPrinter &p, OpType op) { +static void printOpWithOffsetsSizesAndStrides( + OpAsmPrinter &p, OpType op, + llvm::function_ref printExtraOperands = + [](OpAsmPrinter &p, OpType op) {}, + StringLiteral resultTypeKeyword = "to") { int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' '; - p << op.getOperand(0); + p << op.source(); + printExtraOperands(p, op); printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), ShapedType::isDynamicStrideOrOffset); printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), @@ -2651,27 +2657,35 @@ ShapedType::isDynamicStrideOrOffset); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{OpType::getSpecialAttrNames()}); - p << " : " << op.getOperand(0).getType() << " to " << op.getType(); + p << " : " << op.getSourceType() << " " << resultTypeKeyword << " " + << op.getType(); } static void print(OpAsmPrinter &p, SubViewOp op) { return printOpWithOffsetsSizesAndStrides(p, op); } -/// Parse SubViewOp of the form: +/// Parse an op with offsets, sizes and strides of the form:
/// ``` -/// `name` ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]` -/// `:` strided-memref-type `to` strided-memref-type +/// `name` ssa-name (extra-operands)? +/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]` +/// `:` strided-memref-type `resultTypeKeyword` strided-memref-type /// ``` template -static ParseResult parseOpWithOffsetsSizesAndStrides(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType srcInfo; +static ParseResult parseOpWithOffsetsSizesAndStrides( + OpAsmParser &parser, OperationState &result, + std::function + parseExtraOperand = nullptr, + StringLiteral resultTypeKeyword = "to") { + OpAsmParser::OperandType srcInfo, dstInfo; SmallVector offsetsInfo, sizesInfo, stridesInfo; auto indexType = parser.getBuilder().getIndexType(); Type srcType, dstType; if (parser.parseOperand(srcInfo)) return failure(); + if (parseExtraOperand && parseExtraOperand(parser, dstInfo)) + return failure(); if (parseListOfOperandsOrIntegers( parser, result, OpType::getStaticOffsetsAttrName(), ShapedType::kDynamicStrideOrOffset, offsetsInfo) || @@ -2683,21 +2697,27 @@ ShapedType::kDynamicStrideOrOffset, stridesInfo)) return failure(); + // Handle segment sizes.
auto b = parser.getBuilder(); - SmallVector segmentSizes{1, static_cast(offsetsInfo.size()), - static_cast(sizesInfo.size()), - static_cast(stridesInfo.size())}; + SmallVector segmentSizes = {1, static_cast(offsetsInfo.size()), + static_cast(sizesInfo.size()), + static_cast(stridesInfo.size())}; + // A parsed extra operand (e.g. `dest`) needs its own entry in segmentSizes. + if (parseExtraOperand) + segmentSizes.insert(segmentSizes.begin(), 1); result.addAttribute(OpType::getOperandSegmentSizeAttr(), b.getI32VectorAttr(segmentSizes)); return failure( parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(srcType) || + parser.parseKeywordType(resultTypeKeyword.str().c_str(), dstType) || parser.resolveOperand(srcInfo, srcType, result.operands) || + (parseExtraOperand && + parser.resolveOperand(dstInfo, dstType, result.operands)) || parser.resolveOperands(offsetsInfo, indexType, result.operands) || parser.resolveOperands(sizesInfo, indexType, result.operands) || parser.resolveOperands(stridesInfo, indexType, result.operands) || - parser.parseKeywordType("to", dstType) || parser.addTypeToList(dstType, result.types)); } @@ -2790,7 +2810,7 @@ /// Verifier for SubViewOp. static LogicalResult verify(SubViewOp op) { - MemRefType baseType = op.getSourceMemRefType(); + MemRefType baseType = op.getSourceType(); MemRefType subViewType = op.getType(); // The base memref and the view memref should be in the same memory space. @@ -3169,8 +3189,7 @@ // Verify result type against inferred type.
auto expectedType = SubTensorOp::inferResultType( - op.getSourceRankedTensorType(), - extractFromI64ArrayAttr(op.static_offsets()), + op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()), extractFromI64ArrayAttr(op.static_sizes()), extractFromI64ArrayAttr(op.static_strides())); if (op.getType() != expectedType) @@ -3186,6 +3205,80 @@ context); } +//===----------------------------------------------------------------------===// +// SubTensorInsertOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, SubTensorInsertOp op) { + return printOpWithOffsetsSizesAndStrides( + p, op, + [](OpAsmPrinter &p, SubTensorInsertOp op) { p << " into " << op.dest(); }, + /*resultTypeKeyword=*/"into"); +} + +static ParseResult parseSubTensorInsertOp(OpAsmParser &parser, + OperationState &result) { + return parseOpWithOffsetsSizesAndStrides( + parser, result, + [](OpAsmParser &parser, OpAsmParser::OperandType &dstInfo) { + return failure(parser.parseKeyword("into") || + parser.parseOperand(dstInfo)); + }, + /*resultTypeKeyword=*/"into"); +} + +void mlir::SubTensorInsertOp::build( + OpBuilder &b, OperationState &result, Value source, Value dest, + ArrayRef staticOffsets, ArrayRef staticSizes, + ArrayRef staticStrides, ValueRange offsets, ValueRange sizes, + ValueRange strides, ArrayRef attrs) { + build(b, result, dest.getType(), source, dest, offsets, sizes, strides, + b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes), + b.getI64ArrayAttr(staticStrides)); + result.addAttributes(attrs); +} + +/// Build a SubTensorInsertOp with all dynamic entries: `staticOffsets`, `staticSizes` +/// and `staticStrides` are automatically filled with source-tensor-rank +/// sentinel values that encode dynamic entries.
+void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result, + Value source, Value dest, + ValueRange offsets, ValueRange sizes, + ValueRange strides, + ArrayRef attrs) { + auto sourceRankedTensorType = source.getType().cast(); + unsigned rank = sourceRankedTensorType.getRank(); + SmallVector staticOffsetsVector; + staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset); + SmallVector staticSizesVector; + staticSizesVector.assign(rank, ShapedType::kDynamicSize); + SmallVector staticStridesVector; + staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset); + build(b, result, source, dest, staticOffsetsVector, staticSizesVector, + staticStridesVector, offsets, sizes, strides, attrs); +} + +SmallVector SubTensorInsertOp::getOrCreateRanges(OpBuilder &b, + Location loc) { + return ::getOrCreateRangesImpl(*this, b, loc); +} + +/// Verifier for SubTensorInsertOp. +static LogicalResult verify(SubTensorInsertOp op) { + if (failed(verifyOpWithOffsetSizesAndStrides(op))) + return failure(); + if (op.getType() != op.dest().getType()) + return op.emitError("expected result type to be ") << op.dest().getType(); + return success(); +} + +// void SubTensorInsertOp::getCanonicalizationPatterns( +// OwningRewritePatternList &results, MLIRContext *context) { +// results.insert< +// OpWithOffsetSizesAndStridesConstantArgumentFolder>( +// context); +// } + //===----------------------------------------------------------------------===// // TensorCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -872,7 +872,6 @@ return } - // CHECK-LABEL: func @subtensor({{.*}}) { func @subtensor(%t: tensor<8x16x4xf32>, %idx : index) { %c0 = constant 0 : index @@ -890,3 +889,16 @@ return } + +// CHECK-LABEL: func @subtensor_insert({{.*}}) { +func @subtensor_insert(%t: tensor<8x16x4xf32>, %t2:
tensor<16x32x8xf32>, %idx : index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + + // CHECK: subtensor_insert + // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32> + %1 = subtensor_insert %t into %t2[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1] + : tensor<8x16x4xf32> into tensor<16x32x8xf32> + + return +}