diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -94,7 +94,18 @@ static_cast(symbolOperands.size())})); if (alignment) $_state.addAttribute(getAlignmentAttrStrName(), alignment); - }]>]; + }]>, + OpBuilder<(ins "ArrayRef":$sizes, "Type":$elementType, + CArg<"Attribute", "{}">:$memorySpace), [{ + SmallVector staticShape; + SmallVector dynamicSizes; + dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape); + MemRefLayoutAttrInterface layout; + MemRefType memrefType = MemRefType::get(staticShape, elementType, layout, + memorySpace); + return build($_builder, $_state, memrefType, dynamicSizes); + }]> + ]; let extraClassDeclaration = [{ static StringRef getAlignmentAttrStrName() { return "alignment"; } diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -93,4 +93,45 @@ ::mlir::transform::TransformState &state); }]; } + +def MemRefMakeLoopIndependentOp + : Op { + let description = [{ + Rewrite the targeted ops such that their index-typed operands no longer + depend on any loop induction variable of the `num_loops` enclosing `scf.for` + loops. I.e., compute an upper bound that is independent of any such loop IV + for every memref dimension. The transformed op could then be hoisted from + the `num_loops` enclosing loops. To preserve the original semantics, place a + `memref.subview` inside the loop. + + Currently supported operations are: + - memref.alloca: Replaced with a new memref.alloca with upper bound sizes, + followed by a memref.subview. 
+ + #### Return modes + + This operation fails if at least one induction variable could not be + eliminated. In case the targeted op is already independent of induction + variables, this transform succeeds and returns the unmodified target op. + + Otherwise, the returned handle points to a subset of the produced ops: + - memref.alloca: The returned handle points to the memref.subview op. + + This transform op consumes the target handle and produces a result handle. + }]; + + let arguments = (ins PDL_Operation:$target, I64Attr:$num_loops); + let results = (outs PDL_Operation:$transformed); + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // MEMREF_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h @@ -17,8 +17,11 @@ #include "mlir/Support/LogicalResult.h" namespace mlir { +class OpBuilder; class RewritePatternSet; class RewriterBase; +class Value; +class ValueRange; namespace arith { class WideIntEmulationConverter; @@ -26,6 +29,8 @@ namespace memref { class AllocOp; +class AllocaOp; + //===----------------------------------------------------------------------===// // Patterns //===----------------------------------------------------------------------===// @@ -121,6 +126,60 @@ /// ``` void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns); +/// Build a new memref::AllocaOp whose dynamic sizes are independent of all +/// given independencies. If the op is already independent of all +/// independencies, the same AllocaOp result is returned. 
+/// +/// Failure indicates that no suitable upper bound for the dynamic sizes could be +/// found. +FailureOr buildIndependentOp(OpBuilder &b, AllocaOp allocaOp, + ValueRange independencies); + +/// Build a new memref::AllocaOp whose dynamic sizes are independent of all +/// given independencies. If the op is already independent of all +/// independencies, the same AllocaOp result is returned. +/// +/// The original AllocaOp is replaced with the new one, wrapped in a SubviewOp. +/// The result type of the replacement is different from the original allocation +/// type: it has the same shape, but a different layout map. This function +/// updates all users that do not have a memref result or memref region block +/// argument, and some frequently used memref dialect ops (such as +/// memref.subview). It does not update other uses such as the init_arg of an +/// scf.for op. Such uses are wrapped in unrealized_conversion_cast. +/// +/// Failure indicates that no suitable upper bound for the dynamic sizes could be +/// found. +/// +/// Example (make independent of %iv): +/// ``` +/// scf.for %iv = %c0 to %sz step %c1 { +/// %0 = memref.alloca(%iv) : memref +/// %1 = memref.subview %0[0][5][1] : ... +/// linalg.generic outs(%1 : ...) ... +/// %2 = scf.for ... iter_arg(%arg0 = %0) ... +/// ... +/// } +/// ``` +/// +/// The above IR is rewritten to: +/// +/// ``` +/// scf.for %iv = %c0 to %sz step %c1 { +/// %0 = memref.alloca(%sz) : memref +/// %0_subview = memref.subview %0[0][%iv][1] +/// : memref to memref +/// %1 = memref.subview %0_subview[0][5][1] : ... +/// linalg.generic outs(%1 : ...) ... +/// %cast = unrealized_conversion_cast %0_subview +/// : memref to memref +/// %2 = scf.for ... iter_arg(%arg0 = %cast) ... +/// ... 
+/// } +/// ``` +FailureOr replaceWithIndependentOp(RewriterBase &rewriter, + memref::AllocaOp allocaOp, + ValueRange independencies); + } // namespace memref } // namespace mlir diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -97,6 +98,49 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// MemRefMakeLoopIndependentOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne( + Operation *target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { + // Gather IVs. + SmallVector ivs; + Operation *nextOp = target; + for (uint64_t i = 0; i < getNumLoops(); ++i) { + nextOp = nextOp->getParentOfType(); + if (!nextOp) { + DiagnosedSilenceableFailure diag = emitSilenceableError() + << "could not find " << i + << "-th enclosing loop"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + ivs.push_back(cast(nextOp).getInductionVar()); + } + + // Rewrite IR. 
+ IRRewriter rewriter(target->getContext()); + FailureOr replacement = failure(); + if (auto allocaOp = dyn_cast(target)) { + replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs); + } else { + DiagnosedSilenceableFailure diag = emitSilenceableError() + << "unsupported target op"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + if (failed(replacement)) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "could not make target op loop-independent"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + results.push_back(replacement->getDefiningOp()); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -6,6 +6,7 @@ EmulateWideInt.cpp ExtractAddressComputations.cpp FoldMemRefAliasOps.cpp + IndependenceTransforms.cpp MultiBuffer.cpp NormalizeMemRefs.cpp ResolveShapedTypeResultDims.cpp @@ -19,6 +20,7 @@ LINK_LIBS PUBLIC MLIRAffineDialect + MLIRAffineTransforms MLIRAffineUtils MLIRArithDialect MLIRArithTransforms @@ -33,6 +35,7 @@ MLIRPass MLIRTensorDialect MLIRTransforms + MLIRValueBoundsOpInterface MLIRVectorDialect ) diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -0,0 +1,176 @@ +//===- IndependenceTransforms.cpp - Make ops independent of values --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" + +using namespace mlir; +using namespace mlir::memref; + +/// Make the given OpFoldResult independent of all independencies. +static FailureOr makeIndependent(OpBuilder &b, Location loc, + OpFoldResult ofr, + ValueRange independencies) { + if (ofr.is()) + return ofr; + Value value = ofr.get(); + AffineMap boundMap; + ValueDimList mapOperands; + if (failed(ValueBoundsConstraintSet::computeIndependentBound( + boundMap, mapOperands, presburger::BoundType::UB, value, + /*dim=*/std::nullopt, independencies, /*closedUB=*/true))) + return failure(); + return affine::materializeComputedBound(b, loc, boundMap, mapOperands); +} + +FailureOr memref::buildIndependentOp(OpBuilder &b, + memref::AllocaOp allocaOp, + ValueRange independencies) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(allocaOp); + Location loc = allocaOp.getLoc(); + + SmallVector newSizes; + for (OpFoldResult ofr : allocaOp.getMixedSizes()) { + auto ub = makeIndependent(b, loc, ofr, independencies); + if (failed(ub)) + return failure(); + newSizes.push_back(*ub); + } + + // Return existing memref::AllocaOp if nothing has changed. + if (llvm::equal(allocaOp.getMixedSizes(), newSizes)) + return allocaOp.getResult(); + + // Create a new memref::AllocaOp. + Value newAllocaOp = + b.create(loc, newSizes, allocaOp.getType().getElementType()); + + // Create a memref::SubViewOp. 
+ SmallVector offsets(newSizes.size(), b.getIndexAttr(0)); + SmallVector strides(newSizes.size(), b.getIndexAttr(1)); + return b + .create(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(), + strides) + .getResult(); +} + +/// Given an original op and a new, modified op with the same number of results, +/// whose memref return types may differ, replace all uses of the original op +/// with the new op and propagate the new memref types through the IR. +/// +/// Example: +/// %from = memref.alloca(%sz) : memref +/// %to = memref.subview ... : ... to memref> +/// memref.store %cst, %from[%c0] : memref +/// +/// In the above example, all uses of %from are replaced with %to. This can be +/// done directly for ops such as memref.store. For ops that have memref results +/// (e.g., memref.subview), the result type may depend on the operand type, so +/// we cannot just replace all uses. There is special handling for common memref +/// ops. For all other ops, unrealized_conversion_cast is inserted. +static void replaceAndPropagateMemRefType(RewriterBase &rewriter, + Operation *from, Operation *to) { + assert(from->getNumResults() == to->getNumResults() && + "expected same number of results"); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(to); + + // Wrap new results in unrealized_conversion_cast and replace all uses of the + // original op. + SmallVector unrealizedConversions; + for (const auto &it : + llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) { + unrealizedConversions.push_back(rewriter.create( + to->getLoc(), std::get<0>(it.value()).getType(), + std::get<1>(it.value()))); + rewriter.replaceAllUsesWith(from->getResult(it.index()), + unrealizedConversions.back()->getResult(0)); + } + + // Push unrealized_conversion_cast ops further down in the IR. I.e., try to + // wrap results instead of operands in a cast. 
+ for (int i = 0; i < unrealizedConversions.size(); ++i) { + UnrealizedConversionCastOp conversion = unrealizedConversions[i]; + assert(conversion->getNumOperands() == 1 && + conversion->getNumResults() == 1 && + "expected single operand and single result"); + SmallVector users = llvm::to_vector(conversion->getUsers()); + for (Operation *user : users) { + // Handle common memref dialect ops that produce new memrefs and must + // be recreated with the new result type. + if (auto subviewOp = dyn_cast(user)) { + rewriter.setInsertionPoint(subviewOp); + auto newResultType = + SubViewOp::inferRankReducedResultType( + subviewOp.getType().getShape(), subviewOp.getSourceType(), + subviewOp.getMixedOffsets(), subviewOp.getMixedSizes(), + subviewOp.getMixedStrides()) + .cast(); + Value newSubview = rewriter.create( + subviewOp.getLoc(), newResultType, conversion.getOperand(0), + subviewOp.getMixedOffsets(), subviewOp.getMixedSizes(), + subviewOp.getMixedStrides()); + auto conversionOp = rewriter.create( + subviewOp.getLoc(), subviewOp.getType(), newSubview); + rewriter.replaceAllUsesWith(subviewOp.getResult(), + conversionOp->getResult(0)); + unrealizedConversions.push_back(conversionOp); + continue; + } + + // TODO: Other memref ops such as memref.collapse_shape/expand_shape + // should also be handled here. + + // Skip any ops that produce MemRef result or have MemRef region block + // arguments. These may need special handling (e.g., scf.for). + if (llvm::any_of(user->getResultTypes(), + [](Type t) { return isa(t); })) + continue; + if (llvm::any_of(user->getRegions(), [](Region &r) { + return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) { + return isa(bbArg.getType()); + }); + })) + continue; + + // For all other ops, we assume that we can directly replace the operand. + // This may have to be revised in the future; e.g., there may be ops that + // do not support non-identity layout maps. 
+ for (OpOperand &operand : user->getOpOperands()) { + if (auto castOp = + operand.get().getDefiningOp()) { + rewriter.updateRootInPlace( + user, [&]() { operand.set(conversion->getOperand(0)); }); + } + } + } + } + + // Erase all unrealized_conversion_cast ops without uses. + for (auto op : unrealizedConversions) + if (op->getUses().empty()) + rewriter.eraseOp(op); +} + +FailureOr memref::replaceWithIndependentOp(RewriterBase &rewriter, + memref::AllocaOp allocaOp, + ValueRange independencies) { + auto replacement = + memref::buildIndependentOp(rewriter, allocaOp, independencies); + if (failed(replacement)) + return failure(); + replaceAndPropagateMemRefType(rewriter, allocaOp, + replacement->getDefiningOp()); + return replacement; +} diff --git a/mlir/test/Dialect/MemRef/make-loop-independent.mlir b/mlir/test/Dialect/MemRef/make-loop-independent.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/MemRef/make-loop-independent.mlir @@ -0,0 +1,74 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect \ +// RUN: -test-transform-dialect-interpreter -canonicalize \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 - 1)> +// CHECK-LABEL: func @make_alloca_loop_independent( +// CHECK-SAME: %[[lb:.*]]: index, %[[ub:.*]]: index, %[[step:.*]]: index) +func.func @make_alloca_loop_independent(%lb: index, %ub: index, %step: index) { + %cst = arith.constant 5.5 : f32 + %c0 = arith.constant 0 : index + // CHECK: scf.for %[[iv:.*]] = %[[lb]] to %[[ub]] + scf.for %i = %lb to %ub step %step { + // CHECK: %[[sz:.*]] = affine.apply #[[$map]]()[%[[ub]]] + // CHECK: %[[alloca:.*]] = memref.alloca(%[[sz]]) + // CHECK: %[[subview:.*]] = memref.subview %[[alloca]][0] [%[[iv]]] [1] : memref to memref> + // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[subview]] : memref> to memref + %alloc = memref.alloca(%i) : memref + + // memref.subview has special handling. 
+ // CHECK: %[[subview2:.*]] = memref.subview %[[subview]][1] [5] [1] : memref> to memref<5xf32, strided<[1], offset: 1>> + %view = memref.subview %alloc[1][5][1] : memref to memref<5xf32, strided<[1], offset: 1>> + + // This op takes a memref but does not produce one. The new alloc is used + // directly. + // CHECK: "test.some_use"(%[[subview2]]) + "test.some_use"(%view) : (memref<5xf32, strided<[1], offset: 1>>) -> () + + // This op produces a memref, so the new alloc cannot be used directly. + // It is wrapped in an unrealized_conversion_cast. + // CHECK: "test.another_use"(%[[cast]]) : (memref) -> memref + "test.another_use"(%alloc) : (memref) -> (memref) + + // CHECK: memref.store %{{.*}}, %[[subview]] + memref.store %cst, %alloc[%c0] : memref + } + return +} +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.alloca"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.make_loop_independent %0 {num_loops = 1} +} + +// ----- + +// CHECK: #[[$map:.*]] = affine_map<(d0) -> (-d0 + 128)> +// CHECK-LABEL: func @make_alloca_loop_independent_static( +func.func @make_alloca_loop_independent_static(%step: index) { + %cst = arith.constant 5.5 : f32 + %c0 = arith.constant 0 : index + %ub = arith.constant 128 : index + // CHECK: scf.for %[[iv:.*]] = + scf.for %i = %c0 to %ub step %step { + // CHECK: %[[sz:.*]] = affine.apply #[[$map]](%[[iv]]) + %sz = affine.apply affine_map<(d0)[s0] -> (-d0 + s0)>(%i)[%ub] + + // CHECK: %[[alloca:.*]] = memref.alloca() : memref<128xf32> + // CHECK: %[[subview:.*]] = memref.subview %[[alloca]][0] [%[[sz]]] [1] : memref<128xf32> to memref> + %alloc = memref.alloca(%sz) : memref + + // CHECK: memref.store %{{.*}}, %[[subview]] + memref.store %cst, %alloc[%c0] : memref + + // CHECK: vector.print %[[sz]] + %dim = memref.dim %alloc, %c0 : memref + vector.print %dim : index + } + return +} +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): 
+ %0 = transform.structured.match ops{["memref.alloca"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.make_loop_independent %0 {num_loops = 1} +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -10514,6 +10514,7 @@ includes = ["include"], deps = [ ":AffineDialect", + ":AffineTransforms", ":AffineUtils", ":ArithDialect", ":ArithTransforms", @@ -10534,6 +10535,7 @@ ":Support", ":TensorDialect", ":Transforms", + ":ValueBoundsOpInterface", ":VectorDialect", "//llvm:Support", ], @@ -10586,6 +10588,7 @@ ":MemRefTransforms", ":NVGPUDialect", ":PDLDialect", + ":SCFDialect", ":TransformDialect", ":TransformUtils", ":VectorDialect",