diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -27,6 +27,8 @@
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 
+#include "llvm/ADT/STLExtras.h"
+
 namespace mlir {
 namespace linalg {
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -213,6 +213,19 @@
         return {range.begin(), range.begin() + $_op.getNumInputs()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the range over the input operands that are of buffer type.
+      }],
+      /*retTy=*/"SmallVector<Value, 4>",
+      /*methodName=*/"getInputBuffers",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return llvm::to_vector<4>(llvm::make_filter_range(
+            getInputs(), [](Value in){ return in.getType().isa<MemRefType>(); }));
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the subset of input operands that are of ranked tensor type.
@@ -416,6 +429,20 @@
         return getNumInputsAndOutputBuffers() + $_op.getNumInitTensors();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the `i`-th shaped operand value, which can be an arbitrary input
+        tensor/buffer, init tensor or output buffer.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getShapedOperand",
+      /*args=*/(ins "unsigned":$i),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(i < $_op.getNumShapedOperands());
+        return this->getOperation()->getOperand(i);
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the range over inputs, output buffers and init tensors.
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -84,10 +84,14 @@
 /// When non-null, the optional pointer `folder` is used to call into the
 /// `createAndFold` builder method. If `folder` is null, the regular `create`
 /// method is called.
-Optional<FusionInfo> fuseProducerOf(OpBuilder &b, LinalgOp consumer,
-                                    unsigned consumerIdx,
-                                    const LinalgDependenceGraph &graph,
-                                    OperationFolder *folder = nullptr);
+Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
+                                          unsigned consumerIdx,
+                                          const LinalgDependenceGraph &graph,
+                                          OperationFolder *folder = nullptr);
+/// Tensor counterpart of `fuseProducerOfBuffer`.
+Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
+                                          unsigned consumerIdx,
+                                          OperationFolder *folder);
 
 /// Fuse linalg operation on tensors, with the producer of the operand at
 /// position `consumerIdx` of the consumer.
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -147,13 +147,9 @@ } void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { - assert(src.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - assert(dst.hasBufferSemantics() && - "expected linalg op with buffer semantics"); for (auto srcView : src.getOutputBuffers()) { // W // RAW graph - for (auto dstView : dst.getInputs()) { // R + for (auto dstView : dst.getInputBuffers()) { // R if (aliases.alias(srcView, dstView)) { // if alias, fill RAW addDependenceElem(DependenceType::RAW, LinalgOpView{src.getOperation(), srcView}, @@ -169,9 +165,9 @@ } } } - for (auto srcView : src.getInputs()) { // R + for (auto srcView : src.getInputBuffers()) { // R // RAR graph - for (auto dstView : dst.getInputs()) { // R + for (auto dstView : dst.getInputBuffers()) { // R if (aliases.alias(srcView, dstView)) { // if alias, fill RAR addDependenceElem(DependenceType::RAR, LinalgOpView{src.getOperation(), srcView}, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -59,16 +59,16 @@ // a subset of the original loop ranges of `op`. // This is achieved by applying the `loopToOperandRangesMaps` permutation maps // to the `loopRanges` in order to obtain view ranges. 
-static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
-                                    ArrayRef<SubViewOp::Range> loopRanges) {
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+static LinalgOp cloneWithLoopRangesAndTypes(OpBuilder &b, Location loc,
+                                            LinalgOp op,
+                                            ArrayRef<SubViewOp::Range> loopRanges,
+                                            TypeRange resultTypes) {
   auto maps = op.indexing_maps();
   SmallVector<Value, 8> clonedViews;
   clonedViews.reserve(op.getNumInputsAndOutputs());
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
-  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
-  for (auto en : llvm::enumerate(ios)) {
+  for (auto en : llvm::enumerate(op.getShapedOperands())) {
     unsigned idx = en.index();
     auto map = maps[idx].cast<AffineMapAttr>().getValue();
     LLVM_DEBUG(dbgs() << "map: " << map << "\n");
@@ -94,13 +94,17 @@
       sizes.push_back(r.size);
       strides.push_back(r.stride);
     }
-    clonedViews.push_back(
-        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
+    if (view.getType().isa<MemRefType>())
+      clonedViews.push_back(
+          b.create<SubViewOp>(loc, view, offsets, sizes, strides));
+    else
+      clonedViews.push_back(
+          b.create<SubTensorOp>(loc, view, offsets, sizes, strides));
   }
   auto operands = op.getAssumedNonShapedOperands();
   clonedViews.append(operands.begin(), operands.end());
 
-  Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews);
+  Operation *clonedOp = op.clone(b, loc, resultTypes, clonedViews);
   // When the producer is an IndexedGenercOp, we have to transform its block
   // IV arguments according to the tiling of the consumer, i.e. offset them by
   // the values computed in `loopRanges`.
@@ -131,7 +135,6 @@
 // they must agree by construction (i.e. have the same size) and we just return
 // the first one.
 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
   auto maps = op.indexing_maps();
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
@@ -155,25 +158,29 @@
   llvm_unreachable("Expect to be able to extract a view defining loop range");
 }
 
+/// Fuses the producer of `producerIdx` into the loop immediately enclosing
+/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
+/// is needed just before the `consumer`.
+///
+/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
+/// 2 cases:
+///   1. Buffer case: `producerIdx` is the index of the buffer in
+///      `producer.getShapedOperands()`.
+///   2. Tensor case: `producerIdx` is the index of the tensor in
+///      `producer.getResults()`.
 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
                      LinalgOp consumer, unsigned consumerIdx,
                      OperationFolder *folder = nullptr) {
-  assert(producer.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(consumer.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-
-  auto subView = dyn_cast_or_null<SubViewOp>(
-      consumer.getBuffer(consumerIdx).getDefiningOp());
-  auto slice = dyn_cast_or_null<SliceOp>(
-      consumer.getBuffer(consumerIdx).getDefiningOp());
-  assert(subView || slice);
-  (void)subView;
-  (void)slice;
+  Operation *shapeProducingOp =
+      consumer.getShapedOperand(consumerIdx).getDefiningOp();
+  assert((isa<SubViewOp>(shapeProducingOp) ||
+          isa<SubTensorOp>(shapeProducingOp)) &&
+         "SubviewOp or SubTensorOp expected");
 
   // loopToOperandRangesMaps are permutations-only by construction:
   //   we can always identify a data dimension with a (at least one) loop
   //   dimension.
+  // TODO: extend this with range inference.
   AffineMap producerMap =
       producer.indexing_maps()[producerIdx].cast<AffineMapAttr>().getValue();
   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
@@ -190,7 +197,11 @@
   for (auto en : llvm::enumerate(producerMap.getResults())) {
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
     loopRanges[posInProducerLoop] =
-        subView.getOrCreateRanges(b, loc)[en.index()];
+        isa<SubViewOp>(shapeProducingOp)
+            ? cast<SubViewOp>(shapeProducingOp)
+                  .getOrCreateRanges(b, loc)[en.index()]
+            : cast<SubTensorOp>(shapeProducingOp)
+                  .getOrCreateRanges(b, loc)[en.index()];
   }
 
   // Iterate over all dimensions. For the dimensions not identified by the
@@ -209,7 +220,15 @@
     }
   }
 
-  return cloneWithLoopRanges(b, loc, producer, loopRanges);
+  // Pass updated resultTypes to support the tensor case: just copy the
+  // corresponding tensor type from the consumer if appropriate.
+  SmallVector<Type, 4> resultTypes =
+      llvm::to_vector<4>(producer.getOperation()->getResultTypes());
+  RankedTensorType tensorType =
+      consumer.getShapedType(consumerIdx).dyn_cast<RankedTensorType>();
+  if (tensorType)
+    resultTypes[producerIdx] = tensorType;
+  return cloneWithLoopRangesAndTypes(b, loc, producer, loopRanges, resultTypes);
 }
 
 // Encode structural fusion safety preconditions.
@@ -354,7 +373,7 @@
   return {};
 }
 
-Optional<FusionInfo> mlir::linalg::fuseProducerOf(
+Optional<FusionInfo> mlir::linalg::fuseProducerOfBuffer(
     OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
     const LinalgDependenceGraph &graph, OperationFolder *folder) {
   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
@@ -390,6 +409,63 @@
   return FusionInfo{producerOp, fusedProducer};
 }
 
+/// Walk back use-def chain through scf::For yields.
+/// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
+static void getProducerOfTensor(Value tensor, LinalgOp &producer,
+                                unsigned &outputIndex) {
+  if (!tensor.getType().isa<RankedTensorType>())
+    return;
+
+  while (true) {
+    if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
+      producer = linalgOp;
+      outputIndex = tensor.cast<OpResult>().getResultNumber();
+      return;
+    }
+    if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
+      tensor = subTensorOp.source();
+      continue;
+    }
+    if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
+      if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
+        tensor = forOp.getResult(blockArg.getArgNumber());
+        continue;
+      }
+    }
+    return;
+  }
+}
+
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
+                                   unsigned consumerIdx,
+                                   OperationFolder *folder) {
+  Value inputTensor = consumer.getInput(consumerIdx);
+  LinalgOp producerOp;
+  unsigned producerIdx;
+  getProducerOfTensor(inputTensor, producerOp, producerIdx);
+
+  // Must be a subtensor to guarantee there are loops we can fuse into.
+  auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
+  if (!subTensor || !producerOp) {
+    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subtensor)");
+    return {};
+  }
+
+  // Insert fused `producer` just before `consumer`.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(consumer.getOperation());
+  ScopedContext scope(b, consumer.getLoc());
+  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
+  auto fusedProducer =
+      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+
+  // Replace use.
+  consumer.getOperation()
+      ->getOpOperand(consumerIdx)
+      .set(fusedProducer.getOperation()->getResult(producerIdx));
+  return FusionInfo{producerOp, fusedProducer};
+}
+
 /// Returns the positions of the loop in `op` that can be tiled based on the
 /// operations that are to be fused with it. For example, in a
 ///
@@ -702,35 +778,48 @@
   // Save original Linalg ops, we only want to make a pass over those.
   SmallVector<Operation *, 8> linalgOps;
-  f.walk([&](LinalgOp op) {
-    if (op.hasBufferSemantics())
-      linalgOps.push_back(op);
-  });
-
-  // TODO: LinalgDependenceGraph should be able to update itself.
-  // The current naive and expensive reconstruction of the graph should be
-  // removed.
+  f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
+
+  // Tile and Fuse for tensors inputs (TODO: all tensor operands).
   for (auto *op : llvm::reverse(linalgOps)) {
-    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
-         id < e; ++id) {
-      linalg::Aliases aliases;
-      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
-        auto *originalOp = info->originalProducer.getOperation();
-        eraseSet.insert(originalOp);
-        auto *originalOpInLinalgOpsVector =
-            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
-        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
+      if (en.value().getType().isa<MemRefType>()) {
+        // TODO: LinalgDependenceGraph should be able to update itself.
+        // The current naive and expensive reconstruction of the graph should be
+        // removed.
+        linalg::Aliases aliases;
+        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
+        if (auto info =
+                fuseProducerOfBuffer(b, op, en.index(), graph, &folder)) {
+          auto *originalOp = info->originalProducer.getOperation();
+          eraseSet.insert(originalOp);
+          auto *originalOpInLinalgOpsVector =
+              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
+          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
+        }
+      } else {
+        assert(en.value().getType().isa<RankedTensorType>());
+        // Tile and Fuse tensor input (TODO: init_tensors too).
+ if (en.index() >= linalgOp.getNumInputs()) + continue; + if (auto info = fuseProducerOfTensor(b, op, en.index(), &folder)) { + auto *originalOp = info->originalProducer.getOperation(); + auto *originalOpInLinalgOpsVector = + std::find(linalgOps.begin(), linalgOps.end(), originalOp); + *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); + } } } } - // The `fuseProducerOf` function performs structural checks and in particular - // that no covering read or write exist between the consumer and the producer. - // As a consequence, the only fusions that may occur preserve subsequent - // dependences and are guaranteed by construction to produce the whole view. - // We may thus erase the producer once it is fused. + // The `fuseProducerOfBuffer` function performs structural checks and in + // particular that no covering read or write exist between the consumer and + // the producer. As a consequence, the only fusions that may occur preserve + // subsequent dependences and are guaranteed by construction to produce the + // whole view. We may thus erase the producer once it is fused. 
   for (auto *e : eraseSet)
     e->erase();
+
+  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
 }
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map2 = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+#map3 = affine_map<(d0, d1) -> (2, d0 - d1)>
+#map4 = affine_map<(d0, d1) -> (3, d0 - d1)>
+
+func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %t0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+                     init(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+  %c4 = constant 4 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c3 = constant 3 : index
+  %c1 = constant 1 : index
+  %0 = dim %t0, %c0 : tensor<?x?xf32>
+  %1 = dim %t0, %c1 : tensor<?x?xf32>
+  %2 = dim %arg1, %c1 : tensor<?x?xf32>
+  %3 = scf.for %arg3 = %c0 to %0 step %c2 iter_args(%arg4 = %arg2) -> (tensor<?x?xf32>) {
+    %4 = scf.for %arg5 = %c0 to %2 step %c3 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {
+      %5 = scf.for %arg7 = %c0 to %1 step %c4 iter_args(%arg8 = %arg6) -> (tensor<?x?xf32>) {
+        %6 = dim %t0, %c0 : tensor<?x?xf32>
+        %7 = affine.min #map0(%arg3)[%6]
+        %8 = dim %t0, %c1 : tensor<?x?xf32>
+        %9 = affine.min #map1(%arg7)[%8]
+        %10 = subtensor %t0[%arg3, %arg7] [%7, %9] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+        %11 = dim %arg1, %c0 : tensor<?x?xf32>
+        %12 = affine.min #map1(%arg7)[%11]
+        %13 = dim %arg1, %c1 : tensor<?x?xf32>
+        %14 = affine.min #map2(%arg5)[%13]
+        %15 = subtensor %arg1[%arg7, %arg5] [%12, %14] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+        %16 = dim %arg8, %c0 : tensor<?x?xf32>
+        %17 = affine.min #map3(%16, %arg3)
+        %18 = dim %arg8, %c1 : tensor<?x?xf32>
+        %19 = affine.min #map4(%18, %arg5)
+        %20 = subtensor %arg8[%arg3, %arg5] [%17, %19] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+        %21 = linalg.matmul ins(%10, %15 : tensor<?x?xf32>, tensor<?x?xf32>) init(%20 : tensor<?x?xf32>) -> tensor<?x?xf32>
+        %22 = subtensor_insert %21 into %arg8[%arg3, %arg5] [%17, %19] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
+        scf.yield %22 : tensor<?x?xf32>
+      }
+      scf.yield %5 : tensor<?x?xf32>
+    }
+    scf.yield %4 : tensor<?x?xf32>
+  }
+  return %3 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @matmul_tensors(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
+//  CHECK-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
+//  CHECK-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: scf.for %[[I:[0-9a-z]*]]
+// CHECK-NEXT: scf.for %[[J:[0-9a-z]*]]
+// CHECK-NEXT: scf.for %[[K:[0-9a-z]*]]
+//
+// subtensor of the original program, first one refers to the unfused matmul and becomes a dead SSA value.
+// CHECK: subtensor %{{.*}}[%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+//
+// subtensors of the producing matmul.
+// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK-NEXT: %[[stB2:.*]] = subtensor %[[B]][%[[K]], %[[C0]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK-NEXT: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[C0]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK-NEXT: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) init(%[[stC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<?x?xf32>, tensor<?x?xf32>) init(%[[stF]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT: subtensor_insert %[[stG]]