diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -60,9 +60,31 @@
   let assemblyFormat = "$min `:` $max `:` $step attr-dict `:` type(results)";
 }
 
-def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
-    Arguments<(ins AnyStridedMemRef:$view, AffineMapArrayAttr:$reassociation)>,
-    Results<(outs AnyStridedMemRef)> {
+class Linalg_ReshapeLikeOp<string mnemonic> :
+    Linalg_Op<mnemonic, [NoSideEffect]> {
+  let builders = [
+    // Builder for a contracting reshape whose result type is computed from
+    // `src` and `reassociation`.
+    OpBuilder<"Builder *b, OperationState &result, Value src, "
+              "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+              "ArrayRef<NamedAttribute> attrs = {}">,
+    // Builder for a reshape whose result type is passed explicitly. This may be
+    // either a contracting or expanding reshape.
+    OpBuilder<"Builder *b, OperationState &result, Type resultType, Value src,"
+              "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+              "ArrayRef<NamedAttribute> attrs = {}">];
+
+  code commonExtraClassDeclaration = [{
+    static StringRef getReassociationAttrName() { return "reassociation"; }
+  }];
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type(results)
+  }];
+}
+
+def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape">,
+    Arguments<(ins AnyStridedMemRef:$src, AffineMapArrayAttr:$reassociation)>,
+    Results<(outs AnyStridedMemRef:$result)> {
   let summary = "linalg.reshape produces a new view into the operand view";
   let description = [{
     The `linalg.reshape` op produces a new view whose sizes are a reassociation
@@ -102,27 +124,55 @@
       memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
     ```
   }];
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
+    MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
+  }];
+  let hasFolder = 1;
+}
 
-  let builders = [
-    // Builder for a contracting reshape whose result type is computed from
-    // `view` and `reassociation`.
-    OpBuilder<"Builder *b, OperationState &result, Value view, "
-              "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
-              "ArrayRef<NamedAttribute> attrs = {}">,
-    // Builder for a reshape whose result type is passed explicitly. This may be
-    // either a contracting or expanding reshape.
-    OpBuilder<"Builder *b, OperationState &result, Type resultType, Value view,"
-              "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
-              "ArrayRef<NamedAttribute> attrs = {}">];
+def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
+    Arguments<(ins AnyTensor:$src,
+                   AffineMapArrayAttr:$reassociation)>,
+    Results<(outs AnyTensor:$result)> {
+  let summary = "linalg.tensor_reshape produces a new reshaped tensor.";
+  let description = [{
+    The `linalg.tensor_reshape` op produces a new tensor whose sizes are a
+    reassociation of the original `src`.
 
-  let extraClassDeclaration = [{
-    static StringRef getReassociationAttrName() { return "reassociation"; }
-    MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
+    A reassociation is defined as a contiguous grouping of dimensions and is
+    represented with an affine map array attribute. In the future,
+    non-contiguous groupings may be allowed (e.g. permutations, reindexings,
+    etc.).
+
+    A reshape may either collapse or expand dimensions, depending on the
+    relationship between source and target tensor ranks. The verification rule
+    is that the reassociation maps are applied to the tensor with the larger
+    rank to obtain the tensor with the smaller rank. In the case of a dimension
+    expansion, the reassociation maps can be interpreted as inverse maps.
+
+    Examples:
+
+    ```mlir
+    // Dimension collapse (i, j) -> i' and k -> k'
+    %b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+      tensor<?x?x?xf32> into tensor<?x?xf32>
+    ```
+
+    ```mlir
+    // Dimension expansion i -> (i', j') and (k) -> (k')
+    %b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+      tensor<?x?xf32> into tensor<?x?x?xf32>
+    ```
   }];
-  let assemblyFormat = [{
-    $view $reassociation attr-dict `:` type($view) `into` type(results)
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    RankedTensorType getSrcType() {
+      return src().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getResultType() {
+      return result().getType().cast<RankedTensorType>();
+    }
   }];
-  let hasFolder = 1;
 }
 
 def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -164,7 +164,7 @@
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto reshapeOp = cast<ReshapeOp>(op);
-    MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
+    MemRefType dstType = reshapeOp.getResultType();
 
     if (!dstType.hasStaticShape())
       return failure();
@@ -179,7 +179,7 @@
     edsc::ScopedContext context(rewriter, op->getLoc());
     ReshapeOpOperandAdaptor adaptor(operands);
-    BaseViewConversionHelper baseDesc(adaptor.view());
+    BaseViewConversionHelper baseDesc(adaptor.src());
     BaseViewConversionHelper desc(typeConverter.convertType(dstType));
     desc.setAllocatedPtr(baseDesc.allocatedPtr());
     desc.setAlignedPtr(baseDesc.alignedPtr());
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -531,30 +531,33 @@
 }
 
 void mlir::linalg::ReshapeOp::build(
-    Builder *b, OperationState &result, Value view,
+    Builder *b, OperationState &result, Value src,
     ArrayRef<ArrayRef<AffineExpr>> reassociation,
     ArrayRef<NamedAttribute> attrs) {
   auto maps = getSymbolLessAffineMaps(reassociation);
-  auto memRefType = view.getType().cast<MemRefType>();
+  auto memRefType = src.getType().cast<MemRefType>();
   auto resultType = computeReshapeCollapsedType(memRefType, maps);
-  build(b, result, resultType, view, attrs);
+  build(b, result, resultType, src, attrs);
   result.addAttribute(ReshapeOp::getReassociationAttrName(),
                       b->getAffineMapArrayAttr(maps));
 }
 
 void mlir::linalg::ReshapeOp::build(
-    Builder *b, OperationState &result, Type resultType, Value view,
+    Builder *b, OperationState &result, Type resultType, Value src,
     ArrayRef<ArrayRef<AffineExpr>> reassociation,
     ArrayRef<NamedAttribute> attrs) {
   auto maps = getSymbolLessAffineMaps(reassociation);
-  build(b, result, resultType, view, attrs);
+  build(b, result, resultType, src, attrs);
   result.addAttribute(ReshapeOp::getReassociationAttrName(),
                       b->getAffineMapArrayAttr(maps));
 }
 
-static LogicalResult verify(ReshapeOp op) {
-  MemRefType expandedType = op.getViewType();
-  MemRefType collapsedType = op.getResult().getType().cast<MemRefType>();
+// Common verifier for reshape-like types. Fills `expandedType` and
+// `collapsedType` with the proper `src` or `result` type.
+template <typename Op, typename T>
+LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType, T &collapsedType) {
+  expandedType = op.getSrcType();
+  collapsedType = op.getResultType();
   unsigned expandedRank = expandedType.getRank();
   unsigned collapsedRank = collapsedType.getRank();
   bool isCollapse = expandedRank > collapsedRank;
@@ -568,7 +571,7 @@
     return op.emitOpError("expected to collapse or expand dims");
 
   if (collapsedRank != op.reassociation().size())
-    return op.emitOpError("expected rank of the collapsed view(")
+    return op.emitOpError("expected rank of the collapsed type(")
           << collapsedRank << ") to be the number of reassociation maps("
           << op.reassociation().size() << ")";
   auto maps = getAffineMaps(op.reassociation());
@@ -581,6 +584,14 @@
   if (!isReassociationValid(maps, &invalidIdx))
     return op.emitOpError("expected reassociation map #")
           << invalidIdx << " to be valid and contiguous";
+  return success();
+}
+
+static LogicalResult verify(ReshapeOp op) {
+  MemRefType expandedType, collapsedType;
+  if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+    return failure();
+  auto maps = getAffineMaps(op.reassociation());
   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
   if (collapsedType != expectedType)
     return op.emitOpError("expected collapsed type to be ")
@@ -589,6 +600,75 @@
 }
 
 //===----------------------------------------------------------------------===//
+// TensorReshapeOp
+//===----------------------------------------------------------------------===//
+
+/// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
+static RankedTensorType
+computeTensorReshapeCollapsedType(RankedTensorType type,
+                                  ArrayRef<AffineMap> reassociation) {
+  auto shape = type.getShape();
+  SmallVector<int64_t, 4> newShape;
+  newShape.reserve(reassociation.size());
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's rank.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    auto band = shape.drop_front(currentDim).take_front(dim);
+    int64_t size = 1;
+    if (llvm::is_contained(band, ShapedType::kDynamicSize))
+      size = ShapedType::kDynamicSize;
+    else
+      for (unsigned d = 0; d < dim; ++d)
+        size *= shape[currentDim + d];
+    newShape.push_back(size);
+    currentDim += dim;
+  }
+
+  return RankedTensorType::get(newShape, type.getElementType());
+}
+
+void mlir::linalg::TensorReshapeOp::build(
+    Builder *b, OperationState &result, Value src,
+    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getSymbolLessAffineMaps(reassociation);
+  auto resultType = computeTensorReshapeCollapsedType(
+      src.getType().cast<RankedTensorType>(), maps);
+  build(b, result, resultType, src, attrs);
+  result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
+                      b->getAffineMapArrayAttr(maps));
+}
+
+void mlir::linalg::TensorReshapeOp::build(
+    Builder *b, OperationState &result, Type resultType, Value src,
+    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getSymbolLessAffineMaps(reassociation);
+  build(b, result, resultType, src, attrs);
+  result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
+                      b->getAffineMapArrayAttr(maps));
+}
+
+static LogicalResult verify(TensorReshapeOp op) {
+  RankedTensorType expandedType, collapsedType;
+  if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+    return failure();
+  auto maps = getAffineMaps(op.reassociation());
+  // TODO(ntv): expanding a ? with a non-constant is under-specified. Error
+  // out.
+  RankedTensorType expectedType =
+      computeTensorReshapeCollapsedType(expandedType, maps);
+  if (collapsedType != expectedType)
+    return op.emitOpError("expected collapsed type to be ")
+           << expectedType << ", but got " << collapsedType;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // SliceOp
 //===----------------------------------------------------------------------===//
 void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -485,7 +485,7 @@
 // -----
 
 func @reshape(%arg0: memref<?x?x?xf32>) {
-  // expected-error @+1 {{expected rank of the collapsed view(2) to be the number of reassociation maps(1)}}
+  // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
   %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>] :
     memref<?x?x?xf32> into memref<?x?xf32>
 }
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -505,8 +505,8 @@
 // CHECK-DAG: #[[reshape5D2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
 // CHECK-DAG: #[[reshape5D34:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
 
-func @reshape_static(%arg0: memref<3x4x5xf32>) {
-  // Reshapes that collapse and expand back a contiguous tensor.
+func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>, %arg2: tensor<3x?x5xf32>) {
+  // Reshapes that collapse and expand back a contiguous buffer.
   %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>,
                              affine_map<(i, j, k) -> (k)>] :
     memref<3x4x5xf32> into memref<12x5xf32>
@@ -523,7 +523,7 @@
     memref<3x4x5xf32> into memref<60xf32>
   %r2 = linalg.reshape %2 [affine_map<(i, j, k) -> (i, j, k)>] :
     memref<60xf32> into memref<3x4x5xf32>
-  // Reshapes that expand and collapse back a contiguous tensor with some 1's.
+  // Reshapes that expand and collapse back a contiguous buffer with some 1's.
   %3 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
                              affine_map<(i, j, k, l, m) -> (k)>,
                              affine_map<(i, j, k, l, m) -> (l, m)>] :
@@ -532,6 +532,23 @@
                             affine_map<(i, j, k, l, m) -> (k)>,
                             affine_map<(i, j, k, l, m) -> (l, m)>] :
     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  // Reshapes on tensors.
+  %t0 = linalg.tensor_reshape %arg1 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                                     affine_map<(i, j, k, l, m) -> (k)>,
+                                     affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+  %rt0 = linalg.tensor_reshape %t0 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                                    affine_map<(i, j, k, l, m) -> (k)>,
+                                    affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+  %t1 = linalg.tensor_reshape %arg2 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                                     affine_map<(i, j, k, l, m) -> (k)>,
+                                     affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+  %rt1 = linalg.tensor_reshape %t1 [affine_map<(i, j, k, l, m) -> (i)>,
+                                    affine_map<(i, j, k, l, m) -> (j, k)>,
+                                    affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
   return
 }
 // CHECK-LABEL: func @reshape_static
@@ -551,6 +568,11 @@
 // CHECK: linalg.reshape {{.*}} [#[[reshape5D01]], #[[reshape5D2]], #[[reshape5D34]]]
 // CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
 // CHECK: linalg.reshape {{.*}} [#[[reshape5D01]], #[[reshape5D2]], #[[reshape5D34]]]
 // CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+//
+// CHECK: linalg.tensor_reshape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+// CHECK: linalg.tensor_reshape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+// CHECK: linalg.tensor_reshape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+// CHECK: linalg.tensor_reshape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
 
 // -----
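
Usage note: the round trip below is a minimal sketch of the new op, assembled from the assembly format and the roundtrip tests added above; the function name `@tensor_reshape_roundtrip` is illustrative only. It collapses a static 3-D tensor to 2-D and expands it back, with the reassociation maps applied to the higher-rank type in both directions.

```mlir
func @tensor_reshape_roundtrip(%arg0: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
  // Collapse (i, j) into a single dimension: 3x4 -> 12.
  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k) -> (i, j)>,
                                    affine_map<(i, j, k) -> (k)>] :
    tensor<3x4x5xf32> into tensor<12x5xf32>
  // Expand back: the same maps, interpreted as inverse maps, recover the
  // original rank-3 type.
  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k) -> (i, j)>,
                                 affine_map<(i, j, k) -> (k)>] :
    tensor<12x5xf32> into tensor<3x4x5xf32>
  return %1 : tensor<3x4x5xf32>
}
```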