diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -34,6 +34,9 @@ class PoolingMinOp; class PoolingSumOp; +using ReassociationIndicies = SmallVector; +using ReassociationExprs = SmallVector; + /// Returns the name mangled library call name to disambiguate between different /// overloads at the C level. The name mangling scheme is basic and uses MLIR /// type names: diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -64,16 +64,32 @@ class Linalg_ReshapeLikeOp : Linalg_Op { let builders = [ - // Builder for a contracting reshape whose result type is computed from + // Builders for a contracting reshape whose result type is computed from // `src` and `reassociation`. OpBuilder<"OpBuilder &b, OperationState &result, Value src, " - "ArrayRef> reassociation, " - "ArrayRef attrs = {}">, - // Builder for a reshape whose result type is passed explicitly. This may be - // either a contracting or expanding reshape. - OpBuilder<"OpBuilder &b, OperationState &result, Type resultType, Value src," - "ArrayRef> reassociation, " - "ArrayRef attrs = {}">]; + "ArrayRef reassociation, " + "ArrayRef attrs = {}">, + OpBuilder<"OpBuilder &b, OperationState &result, Value src, " + "ArrayRef reassociation, " + "ArrayRef attrs = {}", [{ + auto reassociationMaps = + convertReassociationIndiciesToMaps(b, reassociation); + build(b, result, src, reassociationMaps, attrs); + }]>, + + // Builders for a reshape whose result type is passed explicitly and used + // as-is. This may be either a contracting or expanding reshape.
+ OpBuilder<"OpBuilder &b, OperationState &result, Type resultType, " + "Value src, ArrayRef reassociation, " + "ArrayRef attrs = {}">, + OpBuilder<"OpBuilder &b, OperationState &result, Type resultType, " + "Value src, ArrayRef reassociation, " + "ArrayRef attrs = {}", [{ + auto reassociationMaps = + convertReassociationIndiciesToMaps(b, reassociation); + build(b, result, resultType, src, reassociationMaps, attrs); + }]> + ]; code commonExtraClassDeclaration = [{ static StringRef getReassociationAttrName() { return "reassociation"; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -476,9 +476,9 @@ } template -unsigned getMaxPosOfType(ArrayRef> exprArrays) { +unsigned getMaxPosOfType(ArrayRef exprArrays) { unsigned pos = 0; - for (auto exprs : exprArrays) { + for (const auto &exprs : exprArrays) { for (auto expr : exprs) { expr.walk([&pos](AffineExpr e) { if (auto d = e.dyn_cast()) @@ -490,23 +490,39 @@ } static SmallVector -getSymbolLessAffineMaps(ArrayRef> reassociation) { +getSymbolLessAffineMaps(ArrayRef reassociation) { unsigned maxDim = getMaxPosOfType(reassociation); assert(getMaxPosOfType(reassociation) == 0 && "Expected symbol-less expressions"); SmallVector maps; maps.reserve(reassociation.size()); - for (auto exprs : reassociation) { - assert(exprs.size() != 0); + for (const auto &exprs : reassociation) { + assert(!exprs.empty()); maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); } return maps; } -void mlir::linalg::ReshapeOp::build( - OpBuilder &b, OperationState &result, Value src, - ArrayRef> reassociation, - ArrayRef attrs) { +/// Converts each group of reassociation indices into the matching list of +/// AffineDimExprs, e.g. {{0, 1}, {2}} becomes {{d0, d1}, {d2}}. +static SmallVector, 2> +convertReassociationIndiciesToMaps( + OpBuilder &b, ArrayRef reassociationIndicies) { + SmallVector, 2> reassociationMaps; + for (const auto &indicies : reassociationIndicies) { + SmallVector reassociationMap;
+ reassociationMap.reserve(indicies.size()); + for (int64_t index : indicies) + reassociationMap.push_back(b.getAffineDimExpr(index)); + reassociationMaps.push_back(std::move(reassociationMap)); + } + return reassociationMaps; +} + +void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result, + Value src, + ArrayRef reassociation, + ArrayRef attrs) { auto maps = getSymbolLessAffineMaps(reassociation); auto memRefType = src.getType().cast(); auto resultType = computeReshapeCollapsedType(memRefType, maps); @@ -515,10 +531,10 @@ b.getAffineMapArrayAttr(maps)); } -void mlir::linalg::ReshapeOp::build( - OpBuilder &b, OperationState &result, Type resultType, Value src, - ArrayRef> reassociation, - ArrayRef attrs) { +void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result, + Type resultType, Value src, + ArrayRef reassociation, + ArrayRef attrs) { auto maps = getSymbolLessAffineMaps(reassociation); build(b, result, resultType, src, attrs); result.addAttribute(ReshapeOp::getReassociationAttrName(), @@ -622,7 +638,7 @@ void mlir::linalg::TensorReshapeOp::build( OpBuilder &b, OperationState &result, Value src, - ArrayRef> reassociation, + ArrayRef reassociation, ArrayRef attrs) { auto maps = getSymbolLessAffineMaps(reassociation); auto resultType = computeTensorReshapeCollapsedType( @@ -634,7 +650,7 @@ void mlir::linalg::TensorReshapeOp::build( OpBuilder &b, OperationState &result, Type resultType, Value src, - ArrayRef> reassociation, + ArrayRef reassociation, ArrayRef attrs) { auto maps = getSymbolLessAffineMaps(reassociation); build(b, result, resultType, src, attrs); diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -941,6 +941,8 @@ // CHECK: linalg.reshape {{.*}} [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>] : memref<32x16xf32> into memref<4x8x16xf32> // clang-format on
TEST_FUNC(linalg_metadata_ops) { + using linalg::ReassociationExprs; + + auto f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get({4, 8, 16}, f32Type, {}, 0); auto f = makeFunction("linalg_metadata_ops", {}, {memrefType}); @@ -950,9 +952,10 @@ AffineExpr i, j, k; bindDims(&globalContext(), i, j, k); Value v(f.getArgument(0)); - auto reshaped = linalg_reshape(v, ArrayRef>{{i, j}, k}); - linalg_reshape(memrefType, reshaped, - ArrayRef>{{i, j}, k}); + SmallVector maps = {ReassociationExprs({i, j}), + ReassociationExprs({k})}; + auto reshaped = linalg_reshape(v, maps); + linalg_reshape(memrefType, reshaped, maps); f.print(llvm::outs()); f.erase();