diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -32,6 +32,88 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
+  let summary = "operation to define a tensor of particular shape";
+
+  let description = [{
+    `linalg.init_tensor` is an operation that materializes a tensor of
+    a given shape. The shape could be dynamic or static.
+  }];
+
+  let arguments =
+    (ins Variadic<Index>:$sizes, I64ArrayAttr:$static_sizes);
+
+  let results = (outs AnyTensor:$result);
+
+  let verifier = [{ return ::verify(*this); }];
+
+  let extraClassDeclaration = [{
+    static StringRef getStaticSizesAttrName() {
+      return "static_sizes";
+    }
+
+    RankedTensorType getType() {
+      return getResult().getType().cast<RankedTensorType>(); }
+
+    // Infer the shape of the result tensor given the static shapes
+    // and element type of the result tensor.
+    static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType);
+
+    // Return true if the size of the tensor is dynamic at `idx`.
+    bool isDynamicSize(unsigned idx) {
+      assert(idx < static_sizes().size() && "size index out of range");
+      APInt v = *(static_sizes().getAsValueRange<IntegerAttr>().begin() + idx);
+      return ShapedType::isDynamic(v.getSExtValue());
+    }
+
+    // Assert that the size of the result tensor is static at `idx`
+    // and return that size.
+    int64_t getStaticSize(unsigned idx) {
+      assert(!isDynamicSize(idx) && "expected static size");
+      APInt v = *(static_sizes().getAsValueRange<IntegerAttr>().begin() + idx);
+      return v.getSExtValue();
+    }
+
+    // Return the argument position that contains the dynamic size of
+    // the tensor at dimension `idx`. Asserts that the shape is
+    // dynamic at that `idx`.
+    unsigned getIndexOfDynamicSize(unsigned idx) {
+      assert(isDynamicSize(idx) && "expected dynamic size");
+      return std::count_if(
+          static_sizes().getValue().begin(),
+          static_sizes().getValue().begin() + idx,
+          [&](Attribute attr) {
+            return ShapedType::isDynamic(attr.cast<IntegerAttr>().getInt());
+          });
+    }
+
+    // Return the Value of the dynamic size of the tensor at dimension
+    // `idx`. Asserts that the shape is dynamic at that `idx`.
+    Value getDynamicSize(unsigned idx) {
+      return getOperand(getIndexOfDynamicSize(idx));
+    }
+  }];
+
+  let builders = [
+    OpBuilderDAG<(ins "ValueRange":$shape, "Type":$elementType),
+    [{
+      SmallVector<int64_t, 4> staticShape(
+        shape.size(), ShapedType::kDynamicSize);
+      build($_builder, $_state,
+            InitTensorOp::inferResultType(staticShape, elementType),
+            shape, $_builder.getI64ArrayAttr(staticShape));
+    }]>,
+    OpBuilderDAG<(ins "ArrayRef<int64_t>":$shape, "Type":$elementType),
+    [{
+      build($_builder, $_state,
+            InitTensorOp::inferResultType(shape, elementType),
+            ValueRange{}, $_builder.getI64ArrayAttr(shape));
+    }]>
+  ];
+
+  let hasCanonicalizer = 1;
+}
+
 def Linalg_RangeOp :
     Linalg_Op<"range", [NoSideEffect]>,
     Arguments<(ins Index:$min, Index:$max, Index:$step)>,
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -35,6 +35,14 @@
 #include "mlir/Interfaces/ViewLikeInterface.h.inc"
 
 namespace mlir {
+/// Print a list with either (1) the static integer value in `arrayAttr` if
+/// `isDynamic` evaluates to false or (2) the next value otherwise.
+/// This allows idiomatic printing of mixed value and integer attributes in a
+/// list. E.g. `[%arg0, 7, 42, %arg42]`.
+void printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
+                                   ArrayAttr arrayAttr,
+                                   llvm::function_ref<bool(int64_t)> isDynamic);
+
 /// Print part of an op of the form:
 /// ```
 /// `[` offset-list `]`
@@ -48,6 +56,19 @@
     ArrayRef<StringRef> elidedAttrs =
         OffsetSizeAndStrideOpInterface::getSpecialAttrNames());
 
+/// Parse a mixed list with either (1) static integer values or (2) SSA values.
+/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
+/// encodes the position of SSA values. Add the parsed SSA values to `ssa`
+/// in-order.
+///
+/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
+/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
+/// 2. `ssa` is filled with "[%arg0, %arg42]".
+ParseResult
+parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
+                              StringRef attrName, int64_t dynVal,
+                              SmallVectorImpl<OpAsmParser::OperandType> &ssa);
+
 /// Parse trailing part of an op of the form:
 /// ```
 /// `[` offset-list `]`
@@ -87,6 +108,13 @@
     llvm::function_ref<ParseResult(OpAsmParser &)>
         parseOptionalStridePrefix = nullptr);
 
+/// Verify that `attr` has `expectedNumElements` entries and that `values` has
+/// as many elements as the entries of `attr` for which `isDynamic` evaluates
+/// to true. Diagnostics are emitted on `op` and refer to the list as `name`.
+LogicalResult verifyListOfOperandsOrIntegers(
+    Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
+    ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic);
+
 } // namespace mlir
 
 #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -550,6 +550,160 @@
 
 static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
 
+//===----------------------------------------------------------------------===//
+// InitTensorOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseInitTensorOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  Type dstType;
+  SmallVector<OpAsmParser::OperandType, 4> sizeInfo;
+  IndexType indexType = parser.getBuilder().getIndexType();
+  if (failed(parseListOfOperandsOrIntegers(
+          parser, result, InitTensorOp::getStaticSizesAttrName(),
+          ShapedType::kDynamicSize, sizeInfo)) ||
+      failed(parser.parseOptionalAttrDict(result.attributes)) ||
+      failed(parser.parseColonType(dstType)) ||
+      failed(parser.resolveOperands(sizeInfo, indexType, result.operands)))
+    return failure();
+  return parser.addTypeToList(dstType, result.types);
+}
+
+static void print(OpAsmPrinter &p, InitTensorOp op) {
+  p << op.getOperation()->getName() << ' ';
+  printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+                                ShapedType::isDynamic);
+  p.printOptionalAttrDict(op.getAttrs(),
+                          InitTensorOp::getStaticSizesAttrName());
+  p << " : " << op.getType();
+}
+
+static LogicalResult verify(InitTensorOp op) {
+  RankedTensorType resultType = op.getType();
+  SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
+      op.static_sizes(),
+      [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));
+
+  if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(),
+                                            op.static_sizes(), op.sizes(),
+                                            ShapedType::isDynamic)))
+    return failure();
+
+  // A negative size other than `ShapedType::kDynamicSize` does not describe a
+  // valid tensor dimension; reject it before `inferResultType` builds a type.
+  for (int64_t size : staticSizes)
+    if (size < 0 && !ShapedType::isDynamic(size))
+      return op.emitError("expected non-negative static sizes, got ") << size;
+
+  Type expectedType =
+      InitTensorOp::inferResultType(staticSizes, resultType.getElementType());
+  if (resultType != expectedType) {
+    return op.emitError("specified type ")
+           << resultType << " does not match the inferred type "
+           << expectedType;
+  }
+  return success();
+}
+
+Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
+                                   Type elementType) {
+  return RankedTensorType::get(staticSizes, elementType);
+}
+
+namespace {
+/// Change the type of the result of a `linalg.init_tensor` by making the result
+/// type statically sized along dimensions that in the original operation were
+/// defined as dynamic, but whose size is defined by a `constant` op. For
+/// example
+///
+///  %c5 = constant 5 : index
+///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
+///
+///  to
+///
+///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
+///
+/// followed by a cast of the new result back to the original type, so that
+/// existing uses keep seeing `tensor<?x?xf32>`.
+struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
+  using OpRewritePattern<InitTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InitTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value, 4> dynamicSizes;
+    SmallVector<int64_t, 4> staticSizes;
+    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
+      // If the size is already static, nothing to do.
+      if (!op.isDynamicSize(i)) {
+        staticSizes.push_back(op.getStaticSize(i));
+        continue;
+      }
+
+      // If the size is dynamic but defined using a `constant` op, get the
+      // constant value to find the static size to use. Negative constants are
+      // not valid sizes (-1 would even alias `ShapedType::kDynamicSize` while
+      // dropping the operand), so such sizes are kept dynamic.
+      unsigned operandNum = op.getIndexOfDynamicSize(i);
+      Value sizeOperand = op.getOperand(operandNum);
+      auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>();
+      if (constantIndexOp && constantIndexOp.getValue() >= 0) {
+        staticSizes.push_back(constantIndexOp.getValue());
+        continue;
+      }
+
+      // Fallback case. Keep the size dynamic.
+      dynamicSizes.push_back(sizeOperand);
+      staticSizes.push_back(ShapedType::kDynamicSize);
+    }
+    RankedTensorType newType =
+        RankedTensorType::get(staticSizes, op.getType().getElementType());
+    if (newType == op.getType())
+      return failure();
+    auto newOp =
+        rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
+                                      rewriter.getI64ArrayAttr(staticSizes));
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, op.getType(), newOp);
+    return success();
+  }
+};
+
+/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
+/// with
+/// - A constant value if the size is static along the dimension.
+/// - The dynamic value that defines the size of the result of
+///   `linalg.init_tensor` op.
+struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
+    if (!initTensorOp)
+      return failure();
+    auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
+    if (!dimIndex)
+      return failure();
+    int64_t index = dimIndex.getValue();
+    // Leave out-of-range indices alone: querying the sizes with them would
+    // read past the end of the `static_sizes` attribute.
+    if (index < 0 || index >= initTensorOp.getType().getRank())
+      return failure();
+    if (!initTensorOp.isDynamicSize(index)) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(
+          dimOp, initTensorOp.getStaticSize(index));
+    } else {
+      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
+    }
+    return success();
+  }
+};
+} // namespace
+
+void InitTensorOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1472,11 +1472,29 @@
     return success();
   }
 };
+
+/// Fold dim of a cast into the dim of the source of the tensor
+///
cast. +template +struct DimOfCastOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DimOp dimOp, + PatternRewriter &rewriter) const override { + auto castOp = dimOp.memrefOrTensor().getDefiningOp(); + if (!castOp) + return failure(); + Value newSource = castOp.getOperand(); + rewriter.replaceOpWithNewOp(dimOp, newSource, dimOp.index()); + return success(); + } +}; + } // end anonymous namespace. void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert>(context); } // --------------------------------------------------------------------------- diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -17,54 +17,43 @@ /// Include the definitions of the loop-like interfaces. #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" -static LogicalResult verifyOpWithOffsetSizesAndStridesPart( - OffsetSizeAndStrideOpInterface op, StringRef name, - unsigned expectedNumElements, StringRef attrName, ArrayAttr attr, - llvm::function_ref isDynamic, ValueRange values) { +LogicalResult mlir::verifyListOfOperandsOrIntegers( + Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr, + ValueRange values, llvm::function_ref isDynamic) { /// Check static and dynamic offsets/sizes/strides breakdown. 
if (attr.size() != expectedNumElements) - return op.emitError("expected ") + return op->emitError("expected ") << expectedNumElements << " " << name << " values"; unsigned expectedNumDynamicEntries = llvm::count_if(attr.getValue(), [&](Attribute attr) { return isDynamic(attr.cast().getInt()); }); if (values.size() != expectedNumDynamicEntries) - return op.emitError("expected ") + return op->emitError("expected ") << expectedNumDynamicEntries << " dynamic " << name << " values"; return success(); } LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) { std::array ranks = op.getArrayAttrRanks(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "offset", ranks[0], - OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), - op.static_offsets(), ShapedType::isDynamicStrideOrOffset, - op.offsets()))) + if (failed(verifyListOfOperandsOrIntegers( + op, "offset", ranks[0], op.static_offsets(), op.offsets(), + ShapedType::isDynamicStrideOrOffset))) return failure(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "size", ranks[1], - OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), - op.static_sizes(), ShapedType::isDynamic, op.sizes()))) + if (failed(verifyListOfOperandsOrIntegers(op, "size", ranks[1], + op.static_sizes(), op.sizes(), + ShapedType::isDynamic))) return failure(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "stride", ranks[2], - OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), - op.static_strides(), ShapedType::isDynamicStrideOrOffset, - op.strides()))) + if (failed(verifyListOfOperandsOrIntegers( + op, "stride", ranks[2], op.static_strides(), op.strides(), + ShapedType::isDynamicStrideOrOffset))) return failure(); return success(); } -/// Print a list with either (1) the static integer value in `arrayAttr` if -/// `isDynamic` evaluates to false or (2) the next value otherwise. -/// This allows idiomatic printing of mixed value and integer attributes in a -/// list. E.g. 
`[%arg0, 7, 42, %arg42]`. -static void -printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values, - ArrayAttr arrayAttr, - llvm::function_ref isDynamic) { +void mlir::printListOfOperandsOrIntegers( + OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr, + llvm::function_ref isDynamic) { p << '['; unsigned idx = 0; llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { @@ -95,18 +84,9 @@ p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); } -/// Parse a mixed list with either (1) static integer values or (2) SSA values. -/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` -/// encode the position of SSA values. Add the parsed SSA values to `ssa` -/// in-order. -// -/// E.g. after parsing "[%arg0, 7, 42, %arg42]": -/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" -/// 2. `ssa` is filled with "[%arg0, %arg1]". -static ParseResult -parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, - StringRef attrName, int64_t dynVal, - SmallVectorImpl &ssa) { +ParseResult mlir::parseListOfOperandsOrIntegers( + OpAsmParser &parser, OperationState &result, StringRef attrName, + int64_t dynVal, SmallVectorImpl &ssa) { if (failed(parser.parseLSquare())) return failure(); // 0-D. 
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -351,3 +351,28 @@
          outs(%b : memref)
   return
 }
+
+// -----
+
+func @init_tensor_static_dim() -> (index) {
+  %c2 = constant 2 : index
+  %c6 = constant 6 : index
+  %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
+  %1 = dim %0, %c2 : tensor<4x5x?xf32>
+  return %1 : index
+}
+// CHECK-LABEL: func @init_tensor_static_dim
+//   CHECK-DAG:   %[[C6:.+]] = constant 6 : index
+//       CHECK:   return %[[C6]]
+
+// -----
+
+func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
+  %c2 = constant 2 : index
+  %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
+  %1 = dim %0, %c2 : tensor<4x5x?xf32>
+  return %1 : index
+}
+// CHECK-LABEL: func @init_tensor_dynamic_dim
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
+//       CHECK:   return %[[ARG0]]
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt -split-input-file %s | FileCheck %s
-// | mlir-opt | FileCheck %s
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
 // TODO: Re-enable LLVM lowering test after IndexedGenericOp is lowered.
 //
@@ -698,3 +697,42 @@
 // CHECK-LABEL: func @memref_reshape_zero_dim
 //       CHECK:   linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
 //       CHECK:   linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>
+
+// -----
+
+func @init_tensor(%arg0 : index, %arg1 : index)
+{
+  %0 = linalg.init_tensor [3, 42] : tensor<3x42xf32>
+  %1 = linalg.init_tensor [4, %arg0, %arg1, 5] : tensor<4x?x?x5xf32>
+  return
+}
+// CHECK-LABEL: func @init_tensor
+//       CHECK:   linalg.init_tensor [3, 42] : tensor<3x42xf32>
+//       CHECK:   linalg.init_tensor [4, %{{.*}}, %{{.*}}, 5] : tensor<4x?x?x5xf32>
+
+// -----
+
+func @init_tensor_err(%arg0 : index, %arg1 : index)
+{
+  // expected-error @+1 {{specified type 'tensor<4x?x?x5xf32>' does not match the inferred type 'tensor<4x5x?x?xf32>'}}
+  %1 = linalg.init_tensor [4, 5, %arg0, %arg1] : tensor<4x?x?x5xf32>
+  return
+}
+
+// -----
+
+func @init_tensor_err(%arg0 : index)
+{
+  // expected-error @+1 {{expected 4 sizes values}}
+  %1 = linalg.init_tensor [4, 5, %arg0] : tensor<4x?x?x5xf32>
+  return
+}
+
+// -----
+
+func @init_tensor_err(%arg0 : index)
+{
+  // expected-error @+1 {{expected 2 dynamic sizes values}}
+  %1 = "linalg.init_tensor"(%arg0) {static_sizes = [4, -1, -1, 5]} : (index) -> tensor<4x?x?x5xf32>
+  return
+}