diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -61,21 +61,26 @@
 }
 
 def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
-    Arguments<(ins AnyStridedMemRef:$view, AffineMapArrayAttr:$reassociation)>,
-    Results<(outs AnyStridedMemRef)> {
-  let summary = "linalg.reshape produces a new view into the operand view";
+    Arguments<(ins AnyStridedMemRefOrRankedTensor:$src,
+                   AffineMapArrayAttr:$reassociation)>,
+    Results<(outs AnyStridedMemRefOrRankedTensor)> {
+  let summary = "linalg.reshape produces a new reshaped tensor (resp. view)";
   let description = [{
-    The `linalg.reshape` op produces a new view whose sizes are a reassociation
-    of the original `view`. Depending on whether or not the reassociated
-    MemRefType is contiguous, the resulting memref may require explicit alloc
-    and copies.
+    The `linalg.reshape` op produces a new tensor (resp. view) whose sizes are a
+    reassociation of the original `src`.
 
     A reassociation is defined as a continuous grouping of dimensions and is
    represented with an affine map array attribute. In the future,
    non-continuous groupings may be allowed (i.e. permutations, reindexings
    etc).
 
-    For now, it is assumed that either:
+    In the tensor case, a reshape always produces a new tensor.
+
+    In the strided memref case, depending on whether or not the reassociated
+    MemRefType is contiguous, the resulting memref may require explicit alloc
+    and copies.
+
+    For now, in the strided memref case, it is assumed that either:
      1. a reassociation produces and consumes contiguous MemRefType or,
      2. the reshape op will be folded into its consumers (by changing the
         shape of the computations).
@@ -92,12 +97,16 @@
 
    ```mlir
    // Dimension collapse (i, j) -> i' and k -> k'
+    %b = linalg.reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+      tensor<?x?x?xf32> into tensor<?x?xf32>
    %1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
      memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
    ```
 
    ```mlir
    // Dimension expansion i -> (i', j') and (k) -> (k')
+    %b = linalg.reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+      tensor<?x?xf32> into tensor<?x?x?xf32>
    %1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
      memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
    ```
@@ -105,22 +114,22 @@
 
  let builders = [
    // Builder for a contracting reshape whose result type is computed from
-    // `view` and `reassociation`.
-    OpBuilder<"Builder *b, OperationState &result, Value view, "
+    // `src` and `reassociation`.
+    OpBuilder<"Builder *b, OperationState &result, Value src, "
              "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
              "ArrayRef<NamedAttribute> attrs = {}">,
    // Builder for a reshape whose result type is passed explicitly. This may
    // be either a contracting or expanding reshape.
-    OpBuilder<"Builder *b, OperationState &result, Type resultType, Value view,"
+    OpBuilder<"Builder *b, OperationState &result, Type resultType, Value src,"
              "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
              "ArrayRef<NamedAttribute> attrs = {}">];
 
  let extraClassDeclaration = [{
    static StringRef getReassociationAttrName() { return "reassociation"; }
-    MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
+    ShapedType getSrcShapedType() { return src().getType().cast<ShapedType>(); }
  }];
 
  let assemblyFormat = [{
-    $view $reassociation attr-dict `:` type($view) `into` type(results)
+    $src $reassociation attr-dict `:` type($src) `into` type(results)
  }];
  let hasFolder = 1;
}
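Note (not part of the patch): a minimal sketch of what the extended op accepts after this change, assuming only the semantics and assembly format above. The function name `@collapse` is hypothetical; the reassociation syntax mirrors the existing roundtrip tests.

```mlir
// Collapsing the (d0, d1) group folds 3x4 into 12; the (d2) group keeps 5.
// The same attribute syntax works for both the tensor and memref forms.
func @collapse(%arg0: tensor<3x4x5xf32>) -> tensor<12x5xf32> {
  %0 = linalg.reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                             affine_map<(d0, d1, d2) -> (d2)>] :
    tensor<3x4x5xf32> into tensor<12x5xf32>
  return %0 : tensor<12x5xf32>
}
```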
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -648,6 +648,9 @@
                 MemRefRankOf<[AnyType], [rank]>.predicate]>,
    AnyStridedMemRef.description # " of rank " # rank>;
 
+def AnyStridedMemRefOrRankedTensor:
+    AnyTypeOf<[AnyStridedMemRef, AnyRankedTensor]>;
+
 // This represents a generic tuple without any constraints on element type.
 def AnyTuple : Type<IsTupleTypePred, "tuple">;
 
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -164,9 +164,9 @@
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reshapeOp = cast<ReshapeOp>(op);
-    MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
+    MemRefType dstType = reshapeOp.getResult().getType().dyn_cast<MemRefType>();
 
-    if (!dstType.hasStaticShape())
+    if (!dstType || !dstType.hasStaticShape())
      return failure();
 
    int64_t offset;
@@ -179,7 +179,7 @@
    edsc::ScopedContext context(rewriter, op->getLoc());
    ReshapeOpOperandAdaptor adaptor(operands);
-    BaseViewConversionHelper baseDesc(adaptor.view());
+    BaseViewConversionHelper baseDesc(adaptor.src());
    BaseViewConversionHelper desc(typeConverter.convertType(dstType));
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());
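Note (not part of the patch): with the `dyn_cast` plus null check above, the LLVM lowering pattern simply fails to match tensor-typed reshapes instead of asserting inside `cast`; only statically shaped memref reshapes are rewritten into descriptor manipulation. A hedged sketch of the two cases, with a hypothetical function name `@mixed`:

```mlir
func @mixed(%m: memref<12x5xf32>, %t: tensor<12x5xf32>) {
  // Matched by the lowering pattern: memref result with a static shape.
  %0 = linalg.reshape %m [affine_map<(d0, d1) -> (d0, d1)>] :
    memref<12x5xf32> into memref<60xf32>
  // Not matched: the result is a tensor, so the dyn_cast yields null and the
  // pattern returns failure(), leaving the op for other conversions.
  %1 = linalg.reshape %t [affine_map<(d0, d1) -> (d0, d1)>] :
    tensor<12x5xf32> into tensor<60xf32>
  return
}
```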
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -495,6 +495,35 @@
      MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
 }
 
+/// Compute the RankedTensorType obtained by applying `reassociation` to
+/// `type`.
+static RankedTensorType
+computeReshapeCollapsedType(RankedTensorType type,
+                            ArrayRef<AffineMap> reassociation) {
+  auto shape = type.getShape();
+  SmallVector<int64_t, 4> newShape;
+  newShape.reserve(reassociation.size());
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's rank.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    auto band = shape.drop_front(currentDim).take_front(dim);
+    int64_t size = 1;
+    if (llvm::any_of(band,
+                     [](int64_t v) { return v == ShapedType::kDynamicSize; }))
+      size = ShapedType::kDynamicSize;
+    else
+      for (unsigned d = 0; d < dim; ++d)
+        size *= shape[currentDim + d];
+    newShape.push_back(size);
+    currentDim += dim;
+  }
+
+  return RankedTensorType::get(newShape, type.getElementType());
+}
+
 /// Helper functions assert Attribute of the proper type in attr and returns
 /// the corresponding vector.
 /// TODO(rridle,ntv) this should be evolved into a generic
@@ -531,30 +560,34 @@
 }
 
 void mlir::linalg::ReshapeOp::build(
-    Builder *b, OperationState &result, Value view,
+    Builder *b, OperationState &result, Value src,
    ArrayRef<ArrayRef<AffineExpr>> reassociation,
    ArrayRef<NamedAttribute> attrs) {
  auto maps = getSymbolLessAffineMaps(reassociation);
-  auto memRefType = view.getType().cast<MemRefType>();
-  auto resultType = computeReshapeCollapsedType(memRefType, maps);
-  build(b, result, resultType, view, attrs);
+  Type resultType;
+  if (auto memRefType = src.getType().dyn_cast<MemRefType>())
+    resultType = computeReshapeCollapsedType(memRefType, maps);
+  else
+    resultType = computeReshapeCollapsedType(
+        src.getType().cast<RankedTensorType>(), maps);
+  build(b, result, resultType, src, attrs);
  result.addAttribute(ReshapeOp::getReassociationAttrName(),
                      b->getAffineMapArrayAttr(maps));
 }
 
 void mlir::linalg::ReshapeOp::build(
-    Builder *b, OperationState &result, Type resultType, Value view,
+    Builder *b, OperationState &result, Type resultType, Value src,
    ArrayRef<ArrayRef<AffineExpr>> reassociation,
    ArrayRef<NamedAttribute> attrs) {
  auto maps = getSymbolLessAffineMaps(reassociation);
-  build(b, result, resultType, view, attrs);
+  build(b, result, resultType, src, attrs);
  result.addAttribute(ReshapeOp::getReassociationAttrName(),
                      b->getAffineMapArrayAttr(maps));
 }
 
 static LogicalResult verify(ReshapeOp op) {
-  MemRefType expandedType = op.getViewType();
-  MemRefType collapsedType = op.getResult().getType().cast<MemRefType>();
+  ShapedType expandedType = op.getSrcShapedType();
+  ShapedType collapsedType = op.getResult().getType().cast<ShapedType>();
  unsigned expandedRank = expandedType.getRank();
  unsigned collapsedRank = collapsedType.getRank();
  bool isCollapse = expandedRank > collapsedRank;
@@ -568,7 +601,7 @@
    return op.emitOpError("expected to collapse or expand dims");
 
  if (collapsedRank != op.reassociation().size())
-    return op.emitOpError("expected rank of the collapsed view(")
+    return op.emitOpError("expected rank of the collapsed type(")
          << collapsedRank << ") to be the number of reassociation maps("
          << op.reassociation().size() << ")";
  auto maps = getAffineMaps(op.reassociation());
@@ -581,10 +614,27 @@
  if (!isReassociationValid(maps, &invalidIdx))
    return op.emitOpError("expected reassociation map #")
          << invalidIdx << " to be valid and contiguous";
-  MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
-  if (collapsedType != expectedType)
-    return op.emitOpError("expected collapsed type to be ")
-          << expectedType << ", but got " << collapsedType;
+  if ((expandedType.isa<MemRefType>() && !collapsedType.isa<MemRefType>()) ||
+      (expandedType.isa<RankedTensorType>() &&
+       !collapsedType.isa<RankedTensorType>()))
+    return op.emitOpError("expected source and result of same shaped type");
+
+  if (auto expandedMemRefType = expandedType.dyn_cast<MemRefType>()) {
+    MemRefType expectedType =
+        computeReshapeCollapsedType(expandedMemRefType, maps);
+    if (collapsedType != expectedType)
+      return op.emitOpError("expected collapsed type to be ")
+            << expectedType << ", but got " << collapsedType;
+  } else {
+    // TODO(ntv): expanding a ? with a non-constant is under-specified. Error
+    // out.
+    auto expandedTensorType = expandedType.dyn_cast<RankedTensorType>();
+    RankedTensorType expectedType =
+        computeReshapeCollapsedType(expandedTensorType, maps);
+    if (collapsedType != expectedType)
+      return op.emitOpError("expected collapsed type to be ")
+            << expectedType << ", but got " << collapsedType;
+  }
 
  return success();
 }
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -485,7 +485,7 @@
 // -----
 
 func @reshape(%arg0: memref<?x?x?xf32>) {
-  // expected-error @+1 {{expected rank of the collapsed view(2) to be the number of reassociation maps(1)}}
+  // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
  %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>] :
    memref<?x?x?xf32> into memref<?x?xf32>
 }
 
@@ -508,6 +508,14 @@
 
 // -----
 
+func @reshape(%arg0: tensor<?x?x?xf32>) {
+  // expected-error @+1 {{expected source and result of same shaped type}}
+  %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (k)>] :
+    tensor<?x?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
 func @reshape(%arg0: memref<?x?x?xf32>) {
  // expected-error @+1 {{expected collapsed type to be 'memref<?x?xf32>', but got 'memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>'}}
  %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (k)>] :
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -505,8 +505,8 @@
 // CHECK-DAG: #[[reshape5D2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
 // CHECK-DAG: #[[reshape5D34:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
 
-func @reshape_static(%arg0: memref<3x4x5xf32>) {
-  // Reshapes that collapse and expand back a contiguous tensor.
+func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>, %arg2: tensor<3x?x5xf32>) {
+  // Reshapes that collapse and expand back a contiguous buffer.
  %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>,
                             affine_map<(i, j, k) -> (k)>] :
    memref<3x4x5xf32> into memref<12x5xf32>
@@ -523,7 +523,7 @@
    memref<3x4x5xf32> into memref<60xf32>
  %r2 = linalg.reshape %2 [affine_map<(i, j, k) -> (i, j, k)>] :
    memref<60xf32> into memref<3x4x5xf32>
-  // Reshapes that expand and collapse back a contiguous tensor with some 1's.
+  // Reshapes that expand and collapse back a contiguous buffer with some 1's.
  %3 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
                             affine_map<(i, j, k, l, m) -> (k)>,
                             affine_map<(i, j, k, l, m) -> (l, m)>] :
@@ -532,6 +532,23 @@
                             affine_map<(i, j, k, l, m) -> (k)>,
                             affine_map<(i, j, k, l, m) -> (l, m)>] :
    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  // Reshapes on tensors.
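Note (not part of the patch): in the tensor overload of `computeReshapeCollapsedType`, a group that contains any dynamic dimension collapses to `ShapedType::kDynamicSize`; the verifier recomputes this type and compares it against the declared result. A sketch of the dynamic case, with a hypothetical function name `@collapse_dynamic`:

```mlir
// The (d0, d1) group contains a ?, so 3x? collapses to ?; the (d2) group
// stays 5. The declared result type must therefore be tensor<?x5xf32>, or the
// verifier emits "expected collapsed type to be ...".
func @collapse_dynamic(%arg0: tensor<3x?x5xf32>) -> tensor<?x5xf32> {
  %0 = linalg.reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                             affine_map<(d0, d1, d2) -> (d2)>] :
    tensor<3x?x5xf32> into tensor<?x5xf32>
  return %0 : tensor<?x5xf32>
}
```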
+  %t0 = linalg.reshape %arg1 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                              affine_map<(i, j, k, l, m) -> (k)>,
+                              affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+  %rt0 = linalg.reshape %t0 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                             affine_map<(i, j, k, l, m) -> (k)>,
+                             affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+  %t1 = linalg.reshape %arg2 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                              affine_map<(i, j, k, l, m) -> (k)>,
+                              affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+  %rt1 = linalg.reshape %t1 [affine_map<(i, j, k, l, m) -> (i)>,
+                             affine_map<(i, j, k, l, m) -> (j, k)>,
+                             affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
  return
 }
 // CHECK-LABEL: func @reshape_static
@@ -551,6 +568,11 @@
 // CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
 // CHECK:   linalg.reshape {{.*}} [#[[reshape5D01]], #[[reshape5D2]], #[[reshape5D34]]]
 // CHECK-SAME:     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+//
+// CHECK:   linalg.reshape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+// CHECK:   linalg.reshape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+// CHECK:   linalg.reshape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+// CHECK:   linalg.reshape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
 
 // -----
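Note (not part of the patch): the verifier TODO about expanding a `?` can be illustrated with a sketch like the following (hypothetical `@expand_dynamic`). Verification only checks that collapsing the result type reproduces the source type, so how a dynamic extent factors across an expanded group is left unspecified until runtime, which is exactly what the TODO flags.

```mlir
func @expand_dynamic(%arg0: tensor<?xf32>) -> tensor<?x4xf32> {
  // Verifies because collapsing the (d0, d1) group of ?x4 yields ?, matching
  // the source; whether the runtime size is divisible by 4 is unchecked.
  %0 = linalg.reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
    tensor<?xf32> into tensor<?x4xf32>
  return %0 : tensor<?x4xf32>
}
```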