diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -58,6 +58,43 @@
   let verifier = ?;
 }
 
+def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
+    Arguments<(ins AnyStridedMemRef:$view, AffineMapArrayAttr:$reassociation)>,
+    Results<(outs AnyStridedMemRef)> {
+  let summary = "linalg.reshape produces a new strided memref";
+  let description = [{
+    The `linalg.reshape` op produces a strided memref whose sizes are a
+    reassociation of the original `view`. Depending on whether or not the
+    reassociated dimensions are contiguous, the resulting memref may require
+    explicit alloc and copies.
+
+    A reassociation is defined as a contiguous grouping of dimensions and is
+    represented with an affine map array attribute. In the future,
+    non-contiguous groupings may be allowed (i.e. permutations, reindexings
+    etc).
+
+    For now, it is assumed that all reassociations occur on contiguous
+    dimensions or that the reshape op will be folded into its consumers.
+    All other cases are undefined behavior and a reshape op may not lower to
+    LLVM if it cannot be proven statically that it does not require alloc+copy.
+ + Example: + + ```mlir + %1 = linalg.reshape %0 (i, j, k) -> (i, j), (i, j, k) -> (k) : + memref into memref + ``` + }]; + + let builders = [OpBuilder< + "Builder *b, OperationState &result, Value view, " + "ArrayAttr reassociation, ArrayRef attrs = {}">]; + + let extraClassDeclaration = [{ + static StringRef getReassociationAttrName() { return "reassociation"; } + MemRefType getViewType() { return view()->getType().cast(); } + }]; +} + def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>, Arguments<(ins AnyStridedMemRef:$view, Variadic>:$indexings)>, Results<(outs AnyStridedMemRef)> { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -376,6 +376,10 @@ Block::iterator insertPoint; }; +/// Helper functions assert Attribute of the proper type in attr and returns the +/// corresponding vector. +SmallVector getAffineMaps(ArrayAttr attrs); + } // namespace mlir #endif diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -245,6 +245,9 @@ /// Whether the given dimension size indicates a dynamic dimension. 
   static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }
+  static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
+    return dStrideOrOffset == kDynamicStrideOrOffset;
+  }
 };
 
 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -305,6 +305,160 @@
                  parser.addTypeToList(type, result.types));
 }
 
+//===----------------------------------------------------------------------===//
+// ReshapeOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the reassociation specification is valid, false otherwise.
+/// When false, the `invalidIndex` integer pointer is optionally filled with the
+/// index of the offending reassociation map.
+static bool isReassociationValid(ArrayRef<AffineMap> reassociation,
+                                 int *invalidIndex = nullptr) {
+  if (reassociation.empty())
+    return true;
+  unsigned nDims = reassociation[0].getNumDims();
+  unsigned nextExpectedDim = 0;
+  for (auto it : llvm::enumerate(reassociation)) {
+    auto m = it.value();
+    if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
+      if (invalidIndex)
+        *invalidIndex = it.index();
+      return false;
+    }
+    for (auto e : m.getResults()) {
+      auto d = e.dyn_cast<AffineDimExpr>();
+      if (!d || d.getPosition() != nextExpectedDim++) {
+        if (invalidIndex)
+          *invalidIndex = it.index();
+        return false;
+      }
+    }
+  }
+  if (nextExpectedDim != nDims) {
+    if (invalidIndex)
+      *invalidIndex = reassociation.size() - 1;
+    return false;
+  }
+  return true;
+}
+
+/// Detect whether memref dims [dim, dim + extent) can be reshaped without
+/// copies.
+static bool isReshapableDimBand(unsigned dim, unsigned extent,
+                                ArrayRef<int64_t> sizes,
+                                ArrayRef<int64_t> strides) {
+  assert(sizes.size() == strides.size() && "mismatched ranks");
+  // off by 1 indexing to avoid out of bounds
+  //                       V
+  for (auto idx = dim; idx + 1 < dim + extent; ++idx) {
+    // All values involved must be static, symbolic dims would require more
+    // information in the type.
+    if (ShapedType::isDynamic(sizes[dim + 1]) ||
+        ShapedType::isDynamicStrideOrOffset(strides[dim]) ||
+        ShapedType::isDynamicStrideOrOffset(strides[dim + 1]))
+      return false;
+    if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
+      return false;
+  }
+  return true;
+}
+
+static MemRefType computeReshapeResultType(MemRefType type,
+                                           ArrayRef<AffineMap> reassociation) {
+  auto sizes = type.getShape();
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto status = getStridesAndOffset(type, strides, offset);
+  (void)status;
+  assert(succeeded(status) && "expected strided memref");
+
+  SmallVector<int64_t, 4> newSizes, newStrides;
+  newSizes.reserve(reassociation.size());
+  newStrides.reserve(reassociation.size());
+  unsigned currentDim = 0;
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's rank.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+  for (auto m : reassociation) {
+    unsigned dim = m.getNumResults();
+    int64_t size = 1, stride = strides[currentDim];
+    if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
+      size = ShapedType::kDynamicSize;
+      stride = ShapedType::kDynamicStrideOrOffset;
+    } else {
+      for (unsigned d = 0; d < dim; ++d)
+        size *= sizes[currentDim + d];
+    }
+    newSizes.push_back(size);
+    newStrides.push_back(stride);
+    currentDim += dim;
+  }
+  auto layout =
+      makeStridedLinearLayoutMap(newStrides, offset, type.getContext());
+  return MemRefType::get(newSizes, type.getElementType(), {layout});
+}
+
+void mlir::linalg::ReshapeOp::build(Builder *b, OperationState &result,
+                                    Value view, ArrayAttr reassociation,
+                                    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getAffineMaps(reassociation);
+  auto memRefType = view->getType().cast<MemRefType>();
+  auto resultType = computeReshapeResultType(memRefType, maps);
+  build(b, result, resultType, view, attrs);
+  result.addAttribute(ReshapeOp::getReassociationAttrName(), reassociation);
+}
+
+static void print(OpAsmPrinter &p, ReshapeOp op) {
+  p << op.getOperationName() << " " << *op.view() << " " << op.reassociation();
+  p.printOptionalAttrDict(op.getAttrs(),
+                          {ReshapeOp::getReassociationAttrName()});
+  p << " : " << op.getViewType() << " into " << op.getResult().getType();
+}
+
+static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::OperandType view;
+  ArrayAttr reassociation;
+  MemRefType type, resultType;
+  return failure(parser.parseOperand(view) ||
+                 parser.parseAttribute(reassociation,
+                                       ReshapeOp::getReassociationAttrName(),
+                                       result.attributes) ||
+                 parser.parseOptionalAttrDict(result.attributes) ||
+                 parser.parseColonType(type) ||
+                 parser.parseKeywordType("into", resultType) ||
+                 parser.resolveOperand(view, type, result.operands) ||
+                 parser.addTypeToList(resultType, result.types));
}
+
+static LogicalResult verify(ReshapeOp op) {
+  unsigned
sourceRank = op.getViewType().getRank();
+  unsigned resultRank = op.getResult().getType().cast<MemRefType>().getRank();
+  if (sourceRank == 0 || resultRank == 0)
+    return op.emitOpError("expected non-zero memref ranks");
+  if (resultRank != op.reassociation().size())
+    return op.emitOpError("expected rank of the result view(")
+           << resultRank << ") to be the number of reassociation maps("
+           << op.reassociation().size() << ")";
+  auto maps = getAffineMaps(op.reassociation());
+  for (auto it : llvm::enumerate(maps))
+    if (it.value().getNumDims() != sourceRank)
+      return op.emitOpError("expected reassociation map #")
+             << it.index() << " of same rank as source memref(" << sourceRank
+             << "), but got " << it.value().getNumDims();
+  int invalidIdx = 0;
+  if (!isReassociationValid(maps, &invalidIdx))
+    return op.emitOpError("expected reassociation map #")
+           << invalidIdx << " to be valid and contiguous";
+  MemRefType expectedResultType =
+      computeReshapeResultType(op.getViewType(), maps);
+  if (op.getResult().getType() != expectedResultType)
+    return op.emitOpError("expected result to be ")
+           << expectedResultType << ", but got " << op.getResult().getType();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SliceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -396,3 +396,8 @@
   return success();
 }
+
+SmallVector<AffineMap, 4> mlir::getAffineMaps(ArrayAttr attrs) {
+  return functional::map(
+      [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }, attrs);
+}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -437,3 +437,42 @@
 
 // expected-error @+1 {{expected valid keyword}}
 !invalid_type = type !linalg<"?">
+
+// -----
+
+func @reshape(%arg0: memref<f32>) {
+  //
expected-error @+1 {{expected non-zero memref ranks}} + %0 = linalg.reshape %arg0 [()->(0)] : memref into memref +} + +// ----- + +func @reshape(%arg0: memref) { + // expected-error @+1 {{expected rank of the result view(2) to be the number of reassociation maps(1)}} + %0 = linalg.reshape %arg0 [(i, j, k) -> (i, j)] : + memref into memref +} + +// ----- + +func @reshape(%arg0: memref) { + // expected-error @+1 {{expected reassociation map #0 of same rank as source memref(3), but got 1}} + %0 = linalg.reshape %arg0 [(i) -> (i), (i, j, k) -> (k)] : + memref into memref +} + +// ----- + +func @reshape(%arg0: memref) { + // expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}} + %0 = linalg.reshape %arg0 [(i, j, k) -> (i, j), (i, j, k) -> (k, j)] : + memref into memref +} + +// ----- + +func @reshape(%arg0: memref) { + // expected-error @+1 {{expected result to be 'memref (d0 * s0 + d1)>', but got 'memref'}} + %0 = linalg.reshape %arg0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] : + memref into memref +} diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -13,6 +13,10 @@ // CHECK-DAG: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1) // CHECK-DAG: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0) +// CHECK-DAG: #[[strided2D_off0:.*]] = (d0, d1)[s0] -> (d0 * s0 + d1) +// CHECK-DAG: #[[reshapeD0:.*]] = (d0, d1, d2) -> (d0, d1) +// CHECK-DAG: #[[reshapeD1:.*]] = (d0, d1, d2) -> (d2) + func @range(%arg0: index, %arg1: index, %arg2: index) { %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range return @@ -172,3 +176,12 @@ // CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // CHECK: linalg.yield %{{.*}} : f32 // CHECK: } {foo = 1 : i64}: memref, #[[strided2D]]>, memref + +func @reshape(%arg0: memref) -> memref { + %0 = linalg.reshape %arg0 [(i, j, k) -> (i, j), (i, j, k) -> 
(k)] : + memref into memref + return %0: memref +} +// CHECK-LABEL: func @reshape +// CHECK: linalg.reshape {{.*}} [#[[reshapeD0]], #[[reshapeD1]]] +// CHECK-SAME: memref into memref