diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
@@ -17,7 +17,7 @@
 namespace intrinsics {
 
 using linalg_fill = OperationBuilder<linalg::FillOp>;
-using linalg_reshape = OperationBuilder<linalg::ReshapeOp>;
+using linalg_reshape = ValueBuilder<linalg::ReshapeOp>;
 using linalg_yield = OperationBuilder<linalg::YieldOp>;
 
 } // namespace intrinsics
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -100,9 +100,17 @@
     ```
   }];
 
-  let builders = [OpBuilder<
-    "Builder *b, OperationState &result, Value view, "
-    "ArrayAttr reassociation, ArrayRef<NamedAttribute> attrs = {}">];
+  let builders = [
+    // Builder for a contracting reshape whose result type is computed from
+    // `view` and `reassociation`.
+    OpBuilder<"Builder *b, OperationState &result, Value view, "
+              "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+              "ArrayRef<NamedAttribute> attrs = {}">,
+    // Builder for a reshape whose result type is passed explicitly. This may be
+    // either a contracting or expanding reshape.
+    OpBuilder<"Builder *b, OperationState &result, Type resultType, "
+              "Value view, ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+              "ArrayRef<NamedAttribute> attrs = {}">];
 
   let extraClassDeclaration = [{
     static StringRef getReassociationAttrName() { return "reassociation"; }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -465,14 +465,59 @@
       [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }, attrs);
 }
 
-void mlir::linalg::ReshapeOp::build(Builder *b, OperationState &result,
-                                    Value view, ArrayAttr reassociation,
-                                    ArrayRef<NamedAttribute> attrs) {
-  auto maps = getAffineMaps(reassociation);
+/// Returns one past the largest position of any `AffineExprTy` (either
+/// AffineDimExpr or AffineSymbolExpr) appearing in `exprArrays`, or 0 if none
+/// appears; i.e. the number of dims/symbols needed to bind all of them.
+template <typename AffineExprTy>
+static unsigned getNumPosOfType(ArrayRef<ArrayRef<AffineExpr>> exprArrays) {
+  unsigned numPos = 0;
+  for (ArrayRef<AffineExpr> exprs : exprArrays) {
+    for (AffineExpr expr : exprs) {
+      expr.walk([&numPos](AffineExpr e) {
+        if (auto d = e.dyn_cast<AffineExprTy>())
+          numPos = std::max(numPos, d.getPosition() + 1);
+      });
+    }
+  }
+  return numPos;
+}
+
+/// Converts each group of `reassociation` into a symbol-less AffineMap. All
+/// maps share the same dim count: one past the largest dim position used.
+static SmallVector<AffineMap, 4>
+getSymbolLessAffineMaps(ArrayRef<ArrayRef<AffineExpr>> reassociation) {
+  // Counting positions (instead of taking the max position) lets this also
+  // reject `s0`, whose position 0 is otherwise indistinguishable from none.
+  assert(getNumPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
+         "Expected symbol-less expressions");
+  unsigned numDims = getNumPosOfType<AffineDimExpr>(reassociation);
+  SmallVector<AffineMap, 4> maps;
+  maps.reserve(reassociation.size());
+  for (ArrayRef<AffineExpr> exprs : reassociation)
+    maps.push_back(AffineMap::get(numDims, /*symbolCount=*/0, exprs));
+  return maps;
+}
+
+void mlir::linalg::ReshapeOp::build(
+    Builder *b, OperationState &result, Value view,
+    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getSymbolLessAffineMaps(reassociation);
   auto memRefType = view.getType().cast<MemRefType>();
   auto resultType = computeReshapeCollapsedType(memRefType, maps);
   build(b, result, resultType, view, attrs);
-  result.addAttribute(ReshapeOp::getReassociationAttrName(), reassociation);
+  result.addAttribute(ReshapeOp::getReassociationAttrName(),
+                      b->getAffineMapArrayAttr(maps));
+}
+
+void mlir::linalg::ReshapeOp::build(
+    Builder *b, OperationState &result, Type resultType, Value view,
+    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getSymbolLessAffineMaps(reassociation);
+  build(b, result, resultType, view, attrs);
+  result.addAttribute(ReshapeOp::getReassociationAttrName(),
+                      b->getAffineMapArrayAttr(maps));
 }
 
 static void print(OpAsmPrinter &p, ReshapeOp op) {
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/EDSC/Builders.h"
@@ -962,6 +963,32 @@
   f.erase();
 }
 
+// clang-format off
+// CHECK-LABEL: func @linalg_metadata_ops
+// CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<4x8x16xf32> into memref<32x16xf32>
+// CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<32x16xf32> into memref<4x8x16xf32>
+// clang-format on
+TEST_FUNC(linalg_metadata_ops) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto memrefType = MemRefType::get({4, 8, 16}, f32Type, {}, 0);
+  auto f = makeFunction("linalg_metadata_ops", {}, {memrefType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  AffineExpr i, j, k;
+  bindDims(&globalContext(), i, j, k);
+  ValueHandle v(f.getArgument(0));
+  auto reshaped = linalg_reshape(v, ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
+  linalg_reshape(memrefType, reshaped,
+                 ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
+
+  f.print(llvm::outs());
+  f.erase();
+}
+
 int main() {
   RUN_TESTS();
   return 0;