diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -33,6 +33,7 @@ using edsc::op::operator+; using edsc::op::operator==; +using mlir::edsc::intrinsics::detail::ValueHandleArray; static SmallVector makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, @@ -81,6 +82,41 @@ return res; } +/// Clones every non-terminator op of the first block of `op`'s region through +/// the builder of the active ScopedContext, with the block arguments remapped +/// to `indexedValues`. Then, for each operand `i` of the block terminator, +/// emits a std_store of the remapped operand into `outputBuffers[i]` at +/// `indexing[i]`. `indexing` and `outputBuffers` must therefore hold exactly +/// one entry per terminator operand. +template +static void inlineRegionAndEmitStdStore(OpType op, + ArrayRef indexedValues, + ArrayRef indexing, + ArrayRef outputBuffers) { + auto &b = ScopedContext::getBuilder(); + auto &block = op.region().front(); + BlockAndValueMapping map; + map.map(block.getArguments(), indexedValues); + for (auto &nestedOp : block.without_terminator()) { + assert(nestedOp.getNumRegions() == 0 && "expected a non-nested region"); + auto *newOp = b.clone(nestedOp, map); + map.map(nestedOp.getResults(), newOp->getResults()); + } + + Operation &terminator = block.back(); + assert(isa(terminator) && + "expected an yield op in the end of the region"); + // Every yielded value needs its own output buffer and index list; the + // loop below indexes both arrays by terminator operand position. + assert(terminator.getNumOperands() == outputBuffers.size() && + indexing.size() == outputBuffers.size() && + "expected one output buffer and one index list per yielded value"); + for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) { + std_store(map.lookup(terminator.getOperand(i)), outputBuffers[i], + indexing[i]); + } +} + namespace { template class LinalgScopedEmitter {}; @@ -300,6 +336,8 @@ } // 1.b. Emit std_load from output views. + // TODO(mravishankar): Avoid the loads if the corresponding argument of the + // region has no uses. for (unsigned i = 0; i < nOutputs; ++i) { Value output = genericOp.getOutputBuffer(i); ValueHandleArray indexing(makeCanonicalAffineApplies( @@ -324,24 +362,16 @@ } // TODO(ntv): When a region inliner exists, use it. // 2. Inline region, currently only works for a single basic block. 
- BlockAndValueMapping map; - auto &block = genericOp.region().front(); - map.map(block.getArguments(), indexedValues); - for (auto &op : block.without_terminator()) { - assert(op.getNumRegions() == 0); - auto *newOp = b.clone(op, map); - map.map(op.getResults(), newOp->getResults()); - } - // 3. Emit std_store. - auto *yieldOp = cast(block.back()).getOperation(); - assert(yieldOp->getNumOperands() == nOutputs); + SmallVector indexing; + SmallVector outputBuffers; for (unsigned i = 0; i < nOutputs; ++i) { - ValueHandleArray indexing(makeCanonicalAffineApplies( + indexing.emplace_back(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(map.lookup(yieldOp->getOperand(i)), - genericOp.getOutputBuffer(i), indexing); + outputBuffers.push_back(genericOp.getOutputBuffer(i)); } + inlineRegionAndEmitStdStore(genericOp, indexedValues, indexing, + outputBuffers); } }; @@ -397,25 +427,17 @@ // 1.a. Emit std_load from input views. for (unsigned i = 0; i < nInputs; ++i) { Value input = indexedGenericOp.getInput(i); - if (input.getType().cast().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs)); - indexedValues[nLoops + i] = std_load(input, indexing); - } else { - indexedValues[nLoops + i] = std_load(input); - } + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs)); + indexedValues[nLoops + i] = std_load(input, indexing); } // 1.b. Emit std_load from output views. 
for (unsigned i = 0; i < nOutputs; ++i) { Value output = indexedGenericOp.getOutputBuffer(i); - if (output.getType().cast().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - indexedValues[nLoops + nInputs + i] = std_load(output, indexing); - } else { - indexedValues[nLoops + nInputs + i] = std_load(output); - } + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); + indexedValues[nLoops + nInputs + i] = std_load(output, indexing); } if (auto funcOp = indexedGenericOp.getFunction()) { @@ -426,40 +448,24 @@ // 3. Emit std_store. for (unsigned i = 0; i < nOutputs; ++i) { Value output = indexedGenericOp.getOutputBuffer(i); - if (output.getType().cast().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), output, indexing); - } else { - std_store(callOp->getResult(i), output); - } + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); + std_store(callOp->getResult(i), output, indexing); } return; } // TODO(ntv): When a region inliner exists, use it. // 2. Inline region, currently only works for a single basic block. - BlockAndValueMapping map; - auto &block = indexedGenericOp.region().front(); - map.map(block.getArguments(), indexedValues); - for (auto &op : block.without_terminator()) { - assert(op.getNumRegions() == 0); - auto *newOp = b.clone(op, map); - map.map(op.getResults(), newOp->getResults()); - } - // 3. Emit std_store. 
- auto *yieldOp = cast(block.back()).getOperation(); - assert(yieldOp->getNumOperands() == nOutputs); + SmallVector indexing; + SmallVector outputBuffers; for (unsigned i = 0; i < nOutputs; ++i) { - Value output = indexedGenericOp.getOutputBuffer(i); - if (output.getType().cast().getRank()) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - std_store(map.lookup(yieldOp->getOperand(i)), output, indexing); - } else { - std_store(map.lookup(yieldOp->getOperand(i)), output); - } + indexing.emplace_back(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); + outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i)); } + inlineRegionAndEmitStdStore(indexedGenericOp, indexedValues, indexing, + outputBuffers); } };