diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -1,3 +1,8 @@ +ods_def: +def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { + C(m, n) = std_addf(std_mulf(A(m, k), B(k, n))); +} + ods_def: def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) { C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(b, k, n))); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -225,36 +225,6 @@ let hasFolder = 1; } -def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> { - - let arguments = (ins AnyStridedMemRefOfRank<2>, - AnyStridedMemRefOfRank<2>, - AnyStridedMemRefOfRank<2>); - - let extraClassDeclaration = libraryCallName # [{ - llvm::Optional> referenceIterators() { - return SmallVector{ - getParallelIteratorTypeName(), - getParallelIteratorTypeName(), - getReductionIteratorTypeName()}; - } - - // A(i, r_k) * B(r_k, j) -> C(i, j) - llvm::Optional> referenceIndexingMaps() { - MLIRContext *context = getContext(); - AffineExpr i, j, r_k; - bindDims(context, i, j, r_k); - return SmallVector{ - AffineMap::get(3, 0, {i, r_k}, context), - AffineMap::get(3, 0, {r_k, j},context), - AffineMap::get(3, 0, {i, j}, context) - }; - } - }]; - - let hasFolder = 1; -} - /// A base class for pooling operation such as conv. The arguments must contain /// optional arguments `strides`, `dilations` and `padding` with following type: /// OptionalAttr:$strides diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -165,19 +165,16 @@ void vectorizeLinalgOp(OpBuilder &builder, Operation *op); /// Emits a loop nest of `LoopTy` with the proper body for `op`. -template +template Optional linalgLowerOpToLoops(OpBuilder &builder, Operation *op); /// Emits a loop nest of `scf.for` with the proper body for `op`. -template LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op); /// Emits a loop nest of `scf.parallel` with the proper body for `op`. -template LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op); /// Emits a loop nest of `affine.for` with the proper body for `op`. -template LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op); //===----------------------------------------------------------------------===// @@ -419,12 +416,12 @@ // TODO: Move lowering to library calls here. 
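The `ods_def` entry added above carries the same information as the hand-written `MatmulOp` ODS being deleted: the comprehension `C(m, n) = std_addf(std_mulf(A(m, k), B(k, n)))` implies the indexing maps and iterator types that `referenceIndexingMaps()` / `referenceIterators()` used to spell out in C++. As a rough sketch (dimension names `(m, n, k)` and the `linalg.generic` surface syntax of this period are illustrative, not part of this patch), the generated named op is equivalent to:

```mlir
#matmul_accesses = [
  affine_map<(m, n, k) -> (m, k)>,  // A
  affine_map<(m, n, k) -> (k, n)>,  // B
  affine_map<(m, n, k) -> (m, n)>   // C
]
#matmul_trait = {
  args_in = 2,
  args_out = 1,
  indexing_maps = #matmul_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}
// linalg.matmul %A, %B, %C : (...) behaves like:
linalg.generic #matmul_trait %A, %B, %C {
^bb0(%a: f32, %b: f32, %c: f32):
  %0 = mulf %a, %b : f32
  %1 = addf %c, %0 : f32
  linalg.yield %1 : f32
} : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
```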
return failure(); } else if (loweringType == LinalgLoweringType::Loops) { - if (failed(linalgOpToLoops(rewriter, op))) + if (failed(linalgOpToLoops(rewriter, op))) return failure(); } else if (loweringType == LinalgLoweringType::AffineLoops) { - if (failed(linalgOpToAffineLoops(rewriter, op))) + if (failed(linalgOpToAffineLoops(rewriter, op))) return failure(); - } else if (failed(linalgOpToParallelLoops(rewriter, op))) { + } else if (failed(linalgOpToParallelLoops(rewriter, op))) { return failure(); } rewriter.eraseOp(op); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -241,8 +241,11 @@ LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, - LinalgOpConversion, LinalgOpConversion>(ctx); + // TODO: collect all auto-generated named ops with a tblgen directive. + patterns.insert< + LinalgOpConversion, + LinalgOpConversion>(ctx); // clang-format on } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1128,10 +1128,6 @@ SmallVectorImpl &) { return foldMemRefCast(*this); } -LogicalResult MatmulOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} OpFoldResult ReshapeOp::fold(ArrayRef) { if (succeeded(foldMemRefCast(*this))) return getResult(); @@ -1193,7 +1189,7 @@ p << op.getOperationName() << ' '; p.printOptionalAttrDict(op.getAttrs(), silentAttrNames); p << ' ' << op.getOperands(); - p << ": (" << op.getOperandTypes() << ")"; + p << " : (" << op.getOperandTypes() << ")"; auto outputTensorTypes = op.getResultTypes(); if (!outputTensorTypes.empty()) p << " -> (" << outputTensorTypes << ")"; @@ -1205,8 +1201,8 @@ SmallVector operandsInfo; // Optional attributes may be added. - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseOperandList(operandsInfo)) + if (parser.parseOperandList(operandsInfo) || + parser.parseOptionalAttrDict(result.attributes)) return failure(); SmallVector operandTypes; @@ -1242,3 +1238,7 @@ SmallVectorImpl &) { return foldMemRefCast(*this); } +LogicalResult MatmulOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -80,6 +80,8 @@ static void inlineRegionAndEmitStore(OpType op, ArrayRef indexedValues, ArrayRef> indexing, ArrayRef outputBuffers) { + assert(op.getOperation()->getNumRegions() == 1 && + "Expected single region op"); auto &b = ScopedContext::getBuilderRef(); auto &block = op.region().front(); BlockAndValueMapping map; @@ -150,276 +152,224 @@ /// } /// } /// ``` -template -class LinalgScopedEmitter { -public: - static void emitScalarImplementation(ArrayRef allIvs, - LinalgOpType linalgOp) { - assert(linalgOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - auto &b = ScopedContext::getBuilderRef(); - auto loc = ScopedContext::getLocation(); - unsigned nInputs = linalgOp.getNumInputs(); - unsigned nOutputs = linalgOp.getNumOutputs(); - SmallVector indexedValues; - indexedValues.reserve(nInputs + nOutputs); - - // TODO(mravishankar): Avoid the loads if the corresponding argument of the - // region has no uses. 
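The printer/parser changes above move the named ops away from the old `linalg.matmul(%a, %b, %c) : type, type, type` form: operands are now printed first, the optional attribute dictionary follows the operand list, and the operand types are wrapped in parentheses. A sketch of the new round-trip form (memref types shortened for illustration; the tests updated below use strided layouts):

```mlir
// Plain form.
linalg.matmul %A, %B, %C
  : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)

// With an attribute dictionary, as exercised by the transform-driven tests
// further down in this patch.
linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "START"}
  : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
```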
- // 1.a. Emit load from input views. - for (unsigned i = 0; i < nInputs; ++i) { - auto indexing = makeCanonicalAffineApplies( - b, loc, linalgOp.getInputIndexingMap(i), allIvs); - // Passing through IndexedValueType emits the proper load operation. - indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing)); - } - // 1.b. Emit load from output views. - for (unsigned i = 0; i < nOutputs; ++i) { - auto indexing = makeCanonicalAffineApplies( - b, loc, linalgOp.getOutputIndexingMap(i), allIvs); - // Passing through IndexedValueType emits the proper load operation. - indexedValues.push_back( - IndexedValueType(linalgOp.getOutputBuffer(i))(indexing)); - } +// TODO: need a LinalgStructuredOpInterface. +template +void emitScalarImplementation(ArrayRef allIvs, + LinalgStructuredOpType linalgOp) { + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + auto &b = ScopedContext::getBuilderRef(); + auto loc = ScopedContext::getLocation(); + unsigned nInputs = linalgOp.getNumInputs(); + unsigned nOutputs = linalgOp.getNumOutputs(); + SmallVector indexedValues; + indexedValues.reserve(nInputs + nOutputs); + + // TODO(mravishankar): Avoid the loads if the corresponding argument of the + // region has no uses. + // 1.a. Emit load from input views. + for (unsigned i = 0; i < nInputs; ++i) { + auto indexing = makeCanonicalAffineApplies( + b, loc, linalgOp.getInputIndexingMap(i), allIvs); + // Passing through IndexedValueType emits the proper load operation. + indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing)); + } + // 1.b. Emit load from output views. + for (unsigned i = 0; i < nOutputs; ++i) { + auto indexing = makeCanonicalAffineApplies( + b, loc, linalgOp.getOutputIndexingMap(i), allIvs); + // Passing through IndexedValueType emits the proper load operation. + indexedValues.push_back( + IndexedValueType(linalgOp.getOutputBuffer(i))(indexing)); + } - // TODO(ntv): When a region inliner exists, use it. - // 2. Inline region, currently only works for a single basic block. - // 3. Emit store. - SmallVector, 8> indexing; - SmallVector outputBuffers; - for (unsigned i = 0; i < nOutputs; ++i) { - indexing.push_back(makeCanonicalAffineApplies( - b, loc, linalgOp.getOutputIndexingMap(i), allIvs)); - outputBuffers.push_back(linalgOp.getOutputBuffer(i)); - } - inlineRegionAndEmitStore(linalgOp, indexedValues, - indexing, outputBuffers); + // TODO(ntv): When a region inliner exists, use it. + // 2. Inline region, currently only works for a single basic block. + // 3. Emit store. 
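The refactored `emitScalarImplementation` above loads one element per input/output view, inlines the op's single-block region over those loads, and stores the yielded values back. For the matmul spec added at the top of this patch, the `scf.for` lowering is roughly the following (a sketch assuming dynamic `f32` memrefs; the `loops.mlir` CHECKLOOP lines updated below check this shape):

```mlir
%c0 = constant 0 : index
%c1 = constant 1 : index
%M = dim %A, %c0 : memref<?x?xf32>
%N = dim %B, %c1 : memref<?x?xf32>
%K = dim %A, %c1 : memref<?x?xf32>
scf.for %i = %c0 to %M step %c1 {
  scf.for %j = %c0 to %N step %c1 {
    scf.for %k = %c0 to %K step %c1 {
      %a = load %A[%i, %k] : memref<?x?xf32>
      %b = load %B[%k, %j] : memref<?x?xf32>
      %c = load %C[%i, %j] : memref<?x?xf32>
      %0 = mulf %a, %b : f32
      %1 = addf %c, %0 : f32
      store %1, %C[%i, %j] : memref<?x?xf32>
    }
  }
}
```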
+ SmallVector, 8> indexing; + SmallVector outputBuffers; + for (unsigned i = 0; i < nOutputs; ++i) { + indexing.push_back(makeCanonicalAffineApplies( + b, loc, linalgOp.getOutputIndexingMap(i), allIvs)); + outputBuffers.push_back(linalgOp.getOutputBuffer(i)); } -}; + inlineRegionAndEmitStore(linalgOp, indexedValues, indexing, + outputBuffers); +} template -class LinalgScopedEmitter { -public: - static void emitScalarImplementation(ArrayRef allIvs, CopyOp copyOp) { - assert(copyOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - auto nPar = copyOp.getNumParallelLoops(); - assert(nPar == allIvs.size()); - auto inputIvs = - permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation()); - auto outputIvs = - permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); - SmallVector iivs(inputIvs.begin(), inputIvs.end()); - SmallVector oivs(outputIvs.begin(), outputIvs.end()); - IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); - // Emit the proper scalar assignment, whether we are dealing with a 0-D or - // an n-D loop nest; with or without permutations. - // clang-format off +void emitScalarImplementation(ArrayRef allIvs, CopyOp copyOp) { + assert(copyOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + auto nPar = copyOp.getNumParallelLoops(); + assert(nPar == allIvs.size()); + auto inputIvs = + permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation()); + auto outputIvs = + permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); + SmallVector iivs(inputIvs.begin(), inputIvs.end()); + SmallVector oivs(outputIvs.begin(), outputIvs.end()); + IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); + // Emit the proper scalar assignment, whether we are dealing with a 0-D or + // an n-D loop nest; with or without permutations. + // clang-format off nPar > 0 ? O(oivs) = I(iivs) : O() = I(); - // clang-format on - } -}; - -template -class LinalgScopedEmitter { -public: - static void emitScalarImplementation(ArrayRef allIvs, FillOp fillOp) { - assert(fillOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - auto nPar = fillOp.getNumParallelLoops(); - assert(nPar == allIvs.size()); - auto ivs = SmallVector(allIvs.begin(), allIvs.begin() + nPar); - IndexedValueType O(fillOp.getOutputBuffer(0)); - // Emit the proper scalar assignment, whether we are dealing with a 0-D or - // an n-D loop nest; with or without permutations. - nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value(); - } -}; + // clang-format on +} template -class LinalgScopedEmitter { -public: - static void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp) { - assert(dotOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - assert(allIvs.size() == 1); - Value r_i(allIvs[0]); - IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), - C(dotOp.getOutputBuffer(0)); - // Emit scalar form. - C() = C() + A(r_i) * B(r_i); - } -}; +void emitScalarImplementation(ArrayRef allIvs, FillOp fillOp) { + assert(fillOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + auto nPar = fillOp.getNumParallelLoops(); + assert(nPar == allIvs.size()); + auto ivs = SmallVector(allIvs.begin(), allIvs.begin() + nPar); + IndexedValueType O(fillOp.getOutputBuffer(0)); + // Emit the proper scalar assignment, whether we are dealing with a 0-D or + // an n-D loop nest; with or without permutations. + nPar > 0 ? 
O(ivs) = fillOp.value() : O() = fillOp.value(); +} template -class LinalgScopedEmitter { -public: - static void emitScalarImplementation(ArrayRef allIvs, - MatvecOp matvecOp) { - assert(matvecOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - assert(allIvs.size() == 2); - Value i(allIvs[0]), r_j(allIvs[1]); - IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), - C(matvecOp.getOutputBuffer(0)); - // Emit scalar form. - C(i) = C(i) + A(i, r_j) * B(r_j); - } -}; - +void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp) { + assert(dotOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(allIvs.size() == 1); + Value r_i(allIvs[0]); + IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), + C(dotOp.getOutputBuffer(0)); + // Emit scalar form. + C() = C() + A(r_i) * B(r_i); +} template -class LinalgScopedEmitter { -public: - static void emitScalarImplementation(ArrayRef allIvs, - MatmulOp matmulOp) { - assert(matmulOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - assert(allIvs.size() == 3); - Value i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); - IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)), - C(matmulOp.getOutputBuffer(0)); - // Emit scalar form. - C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j); - } -}; +void emitScalarImplementation(ArrayRef allIvs, MatvecOp matvecOp) { + assert(matvecOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(allIvs.size() == 2); + Value i(allIvs[0]), r_j(allIvs[1]); + IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), + C(matvecOp.getOutputBuffer(0)); + // Emit scalar form. + C(i) = C(i) + A(i, r_j) * B(r_j); +} template -class LinalgScopedEmitter { -public: - /// Returns the input value of convOp. If the indices in `imIdx` is out of - /// boundary, returns 0 instead. - static Value getConvOpInput(ConvOp convOp, StdIndexedValue im, - MutableArrayRef imIdx) { - // TODO(ntv): add a level of indirection to linalg.generic. - if (!convOp.padding()) - return im(imIdx); - - auto *context = ScopedContext::getContext(); - Value zeroIndex = std_constant_index(0); - SmallVector conds; - SmallVector clampedImIdx; - for (auto iter : llvm::enumerate(imIdx)) { - int idx = iter.index(); - auto dim = iter.value(); - // Only need to iterate over the window dimensions. - if (idx == 0 || idx == static_cast(imIdx.size()) - 1) { - clampedImIdx.push_back(dim); - continue; - } - - using edsc::op::operator<; - using edsc::op::operator>=; - using edsc::op::operator||; - Value leftOutOfBound = dim < zeroIndex; - if (conds.empty()) - conds.push_back(leftOutOfBound); - else - conds.push_back(conds.back() || leftOutOfBound); - Value rightBound = std_dim(convOp.input(), idx); - conds.push_back(conds.back() || (dim >= rightBound)); - - // When padding is involved, the indices will only be shifted to negative, - // so having a max op is enough. - auto maxMap = AffineMap::get(/*dimCount=*/1, 0, - {getAffineDimExpr(/*position=*/0, context), - getAffineConstantExpr(0, context)}, - context); - clampedImIdx.push_back( - affine_max(dim.getType(), maxMap, ValueRange{dim})); +Value getConvOpInput(ConvOp convOp, StdIndexedValue im, + MutableArrayRef imIdx) { + // TODO(ntv): add a level of indirection to linalg.generic. 
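The dot and matvec emitters above write the reduction directly in indexed form (`C() = C() + A(r_i) * B(r_i)` and `C(i) = C(i) + A(i, r_j) * B(r_j)`). As a sketch of what that expands to for `linalg.dot` on `f32` buffers, the result lives in a 0-d memref and the whole op becomes a single reduction loop (loop bound and value names are illustrative):

```mlir
scf.for %i = %c0 to %K step %c1 {
  %a = load %A[%i] : memref<?xf32>
  %b = load %B[%i] : memref<?xf32>
  %c = load %C[] : memref<f32>
  %0 = mulf %a, %b : f32
  %1 = addf %c, %0 : f32
  store %1, %C[] : memref<f32>
}
```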
+ if (!convOp.padding()) + return im(imIdx); + + auto *context = ScopedContext::getContext(); + Value zeroIndex = std_constant_index(0); + SmallVector conds; + SmallVector clampedImIdx; + for (auto iter : llvm::enumerate(imIdx)) { + int idx = iter.index(); + auto dim = iter.value(); + // Only need to iterate over the window dimensions. + if (idx == 0 || idx == static_cast(imIdx.size()) - 1) { + clampedImIdx.push_back(dim); + continue; } - auto &b = ScopedContext::getBuilderRef(); - Type type = convOp.input().getType().cast().getElementType(); - Value zero = std_constant(type, b.getZeroAttr(type)); - Value readInput = im(clampedImIdx); - return conds.empty() ? readInput - : (Value)std_select(conds.back(), zero, readInput); + using edsc::op::operator<; + using edsc::op::operator>=; + using edsc::op::operator||; + Value leftOutOfBound = dim < zeroIndex; + if (conds.empty()) + conds.push_back(leftOutOfBound); + else + conds.push_back(conds.back() || leftOutOfBound); + Value rightBound = std_dim(convOp.input(), idx); + conds.push_back(conds.back() || (dim >= rightBound)); + + // When padding is involved, the indices will only be shifted to negative, + // so having a max op is enough. + auto maxMap = AffineMap::get(/*dimCount=*/1, 0, + {getAffineDimExpr(/*position=*/0, context), + getAffineConstantExpr(0, context)}, + context); + clampedImIdx.push_back(affine_max(dim.getType(), maxMap, ValueRange{dim})); } - /// Returns true is `convOp` has a non-zero padding. - static bool hasPadding(ConvOp convOp) { - for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) { - if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0) - return true; - } - return false; - } + auto &b = ScopedContext::getBuilderRef(); + Type type = convOp.input().getType().cast().getElementType(); + Value zero = std_constant(type, b.getZeroAttr(type)); + Value readInput = im(clampedImIdx); + return conds.empty() ? readInput + : (Value)std_select(conds.back(), zero, readInput); +} - static void emitScalarImplementation(ArrayRef allIvs, ConvOp convOp) { - assert(convOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - auto &b = ScopedContext::getBuilderRef(); - auto loc = ScopedContext::getLocation(); - auto mapsRange = convOp.indexing_maps().getAsRange(); - auto maps = llvm::to_vector<8>(llvm::map_range( - mapsRange, [](AffineMapAttr a) { return a.getValue(); })); - SmallVector fIdx( - makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); - SmallVector imIdx( - makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); - SmallVector oIdx( - makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); - - IndexedValueType F(convOp.filter()), O(convOp.output()); - - // Emit scalar form. Padded conv involves an affine.max in the memory access - // which is not allowed by affine.load. Override to use an StdIndexedValue - // when there is non-zero padding. - if (hasPadding(convOp)) { - StdIndexedValue I(convOp.input()); - Value paddedInput = getConvOpInput(convOp, I, imIdx); - O(oIdx) += F(fIdx) * paddedInput; - } else { - IndexedValueType I(convOp.input()); - O(oIdx) += F(fIdx) * I(imIdx); - } +/// Returns true is `convOp` has a non-zero padding. 
+static bool hasPadding(ConvOp convOp) { + for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) { + if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0) + return true; } -}; + return false; +} template -class LinalgScopedEmitter { -public: - static void emitScalarImplementation(ArrayRef allIvs, - PoolingMaxOp op) { - auto indices = getInputAndOutputIndices(allIvs, op); - // Emit scalar form. - Value lhs = std_load(op.output(), indices.outputs); - Value rhs = std_load(op.input(), indices.inputs); - using edsc::op::operator>; - Value maxValue = std_select(lhs > rhs, lhs, rhs); - std_store(maxValue, op.output(), indices.outputs); +static void emitScalarImplementation(ArrayRef allIvs, ConvOp convOp) { + assert(convOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + auto &b = ScopedContext::getBuilderRef(); + auto loc = ScopedContext::getLocation(); + auto mapsRange = convOp.indexing_maps().getAsRange(); + auto maps = llvm::to_vector<8>( + llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); + SmallVector fIdx( + makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); + SmallVector imIdx( + makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); + SmallVector oIdx( + makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); + + IndexedValueType F(convOp.filter()), O(convOp.output()); + + // Emit scalar form. Padded conv involves an affine.max in the memory access + // which is not allowed by affine.load. Override to use an StdIndexedValue + // when there is non-zero padding. + if (hasPadding(convOp)) { + StdIndexedValue I(convOp.input()); + Value paddedInput = getConvOpInput(convOp, I, imIdx); + O(oIdx) += F(fIdx) * paddedInput; + } else { + IndexedValueType I(convOp.input()); + O(oIdx) += F(fIdx) * I(imIdx); } -}; +} template -class LinalgScopedEmitter { -public: - static void emitScalarImplementation(ArrayRef allIvs, - PoolingMinOp op) { - auto indices = getInputAndOutputIndices(allIvs, op); - // Emit scalar form. - Value lhs = std_load(op.output(), indices.outputs); - Value rhs = std_load(op.input(), indices.inputs); - using edsc::op::operator<; - Value minValue = std_select(lhs < rhs, lhs, rhs); - std_store(minValue, op.output(), indices.outputs); - } -}; - +void emitScalarImplementation(ArrayRef allIvs, PoolingMaxOp op) { + auto indices = getInputAndOutputIndices(allIvs, op); + // Emit scalar form. + Value lhs = std_load(op.output(), indices.outputs); + Value rhs = std_load(op.input(), indices.inputs); + using edsc::op::operator>; + Value maxValue = std_select(lhs > rhs, lhs, rhs); + std_store(maxValue, op.output(), indices.outputs); +} template -class LinalgScopedEmitter { -public: - static void emitScalarImplementation(ArrayRef allIvs, - PoolingSumOp op) { - auto indices = getInputAndOutputIndices(allIvs, op); - IndexedValueType input(op.input()), output(op.output()); - - // Emit scalar form. - output(indices.outputs) += input(indices.inputs); - } -}; +void emitScalarImplementation(ArrayRef allIvs, PoolingMinOp op) { + auto indices = getInputAndOutputIndices(allIvs, op); + // Emit scalar form. 
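The pooling max/min emitters select between the current output element and the loaded input element through the EDSC comparison operators. For `f32` elements and the `scf` flavor, the emitted body inside the pooling loop nest is roughly a compare plus select (a sketch; the index names are illustrative and the predicate would differ for integer element types):

```mlir
%lhs = load %O[%x, %y] : memref<?x?xf32>
%rhs = load %I[%ix, %iy] : memref<?x?xf32>
%cond = cmpf "ogt", %lhs, %rhs : f32
%max = select %cond, %lhs, %rhs : f32
store %max, %O[%x, %y] : memref<?x?xf32>
```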
+ Value lhs = std_load(op.output(), indices.outputs); + Value rhs = std_load(op.input(), indices.inputs); + using edsc::op::operator<; + Value minValue = std_select(lhs < rhs, lhs, rhs); + std_store(minValue, op.output(), indices.outputs); +} +template +void emitScalarImplementation(ArrayRef allIvs, PoolingSumOp op) { + auto indices = getInputAndOutputIndices(allIvs, op); + IndexedValueType input(op.input()), output(op.output()); + // Emit scalar form. + output(indices.outputs) += input(indices.inputs); +} /// Emits the MLIR for the scalar part of the indexed generic op by: /// 1. Emitting load ops for each input and output view in order. This is /// achieved by applying the appropriate input or output map to the @@ -451,55 +401,52 @@ /// } /// ``` template -class LinalgScopedEmitter { -public: - static void emitScalarImplementation(ArrayRef allIvs, - IndexedGenericOp indexedGenericOp) { - assert(indexedGenericOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - auto &b = ScopedContext::getBuilderRef(); - auto loc = ScopedContext::getLocation(); - unsigned nInputs = indexedGenericOp.getNumInputs(); - unsigned nOutputs = indexedGenericOp.getNumOutputs(); - unsigned nLoops = allIvs.size(); - SmallVector indexedValues; - indexedValues.reserve(nLoops + nInputs + nOutputs); - for (unsigned i = 0; i < nLoops; ++i) - indexedValues.push_back(allIvs[i]); - - // TODO(mravishankar): Avoid the loads if the corresponding argument of the - // region has no uses. - // 1.a. Emit load from input views. - for (unsigned i = 0; i < nInputs; ++i) { - auto indexing = makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs); - // Pass input i through IndexedValueType emits the proper load operation. - indexedValues.push_back( - IndexedValueType(indexedGenericOp.getInput(i))(indexing)); - } - // 1.b. Emit load from output views. - for (unsigned i = 0; i < nOutputs; ++i) { - auto indexing = makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs); - // Pass output i through IndexedValueType emits the proper load operation. - indexedValues.push_back( - IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing)); - } +static void emitScalarImplementation(ArrayRef allIvs, + IndexedGenericOp indexedGenericOp) { + assert(indexedGenericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + auto &b = ScopedContext::getBuilderRef(); + auto loc = ScopedContext::getLocation(); + unsigned nInputs = indexedGenericOp.getNumInputs(); + unsigned nOutputs = indexedGenericOp.getNumOutputs(); + unsigned nLoops = allIvs.size(); + SmallVector indexedValues; + indexedValues.reserve(nLoops + nInputs + nOutputs); + for (unsigned i = 0; i < nLoops; ++i) + indexedValues.push_back(allIvs[i]); + + // TODO(mravishankar): Avoid the loads if the corresponding argument of the + // region has no uses. + // 1.a. Emit load from input views. + for (unsigned i = 0; i < nInputs; ++i) { + auto indexing = makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs); + // Pass input i through IndexedValueType emits the proper load operation. + indexedValues.push_back( + IndexedValueType(indexedGenericOp.getInput(i))(indexing)); + } + // 1.b. Emit load from output views. + for (unsigned i = 0; i < nOutputs; ++i) { + auto indexing = makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs); + // Pass output i through IndexedValueType emits the proper load operation. 
+ indexedValues.push_back( + IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing)); + } - // TODO(ntv): When a region inliner exists, use it. - // 2. Inline region, currently only works for a single basic block. - // 3. Emit store. - SmallVector, 8> indexing; - SmallVector outputBuffers; - for (unsigned i = 0; i < nOutputs; ++i) { - indexing.push_back(makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i)); - } - inlineRegionAndEmitStore(indexedGenericOp, indexedValues, - indexing, outputBuffers); + // TODO(ntv): When a region inliner exists, use it. + // 2. Inline region, currently only works for a single basic block. + // 3. Emit store. + SmallVector, 8> indexing; + SmallVector outputBuffers; + for (unsigned i = 0; i < nOutputs; ++i) { + indexing.push_back(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); + outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i)); } -}; + inlineRegionAndEmitStore(indexedGenericOp, indexedValues, + indexing, outputBuffers); +} template Optional linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) { @@ -524,8 +471,7 @@ if (!invertedMap) return {}; if (invertedMap.isEmpty()) { - LinalgScopedEmitter::emitScalarImplementation( - {}, linalgOp); + emitScalarImplementation({}, linalgOp); return LinalgLoops(); } @@ -537,9 +483,7 @@ GenerateLoopNest::doit( allIvs, loopRanges, linalgOp.iterator_types().getValue(), [&] { SmallVector allIvValues(allIvs.begin(), allIvs.end()); - LinalgScopedEmitter::emitScalarImplementation(allIvValues, - linalgOp); + emitScalarImplementation(allIvValues, linalgOp); }); // Number of loop ops might be different from the number of ivs since some // loops like affine.parallel and scf.parallel have multiple ivs. @@ -573,32 +517,15 @@ } }; -/// Helper classes for type list expansion. -template -class RewritePatternList; - -template -class RewritePatternList { -public: - static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {} -}; - -template -class RewritePatternList { -public: - static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert>(ctx); - RewritePatternList::build(patterns, ctx); - } -}; +template +void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert>(ctx); +} -/// Populate the given list with patterns that convert from Linalg to loops. -template -void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) { - RewritePatternList::build(patterns, ctx); +template +void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) { + (void)std::initializer_list{ + 0, (insertOnePattern(patterns, ctx), 0)...}; } /// Local folding pattern for AffineApplyOp that we can apply greedily. @@ -640,17 +567,21 @@ } // namespace template -static void lowerLinalgToLoopsImpl(Operation *op, MLIRContext *context) { +static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) { OwningRewritePatternList patterns; // Canonicalization and folding patterns applied greedily allow cleaning up // the emitted IR on the fly. // TODO(ntv) fold view and subview ops? - FillRewritePatterns(patterns, context); + insertPatterns(patterns, context); + DimOp::getCanonicalizationPatterns(patterns, context); AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.insert(context); // Just apply the patterns greedily. 
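`linalgOpToLoopsImpl` / `lowerLinalgToLoopsImpl` stay parameterized on the loop type, so the same scalar emitters feed `scf.for`, `affine.for`, and `scf.parallel` nests. With the affine flavor (exercised by the `affine.mlir` test updated below), the matmul body is emitted with `affine.for` / `affine.load` / `affine.store` instead; a sketch:

```mlir
// %M, %N, %K are dim ops on the operands, as in the scf sketch above.
affine.for %i = 0 to %M {
  affine.for %j = 0 to %N {
    affine.for %k = 0 to %K {
      %a = affine.load %A[%i, %k] : memref<?x?xf32>
      %b = affine.load %B[%k, %j] : memref<?x?xf32>
      %c = affine.load %C[%i, %j] : memref<?x?xf32>
      %0 = mulf %a, %b : f32
      %1 = addf %c, %0 : f32
      affine.store %1, %C[%i, %j] : memref<?x?xf32>
    }
  }
}
```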
- applyPatternsAndFoldGreedily(op, patterns); + applyPatternsAndFoldGreedily(funcOp, patterns); } namespace { @@ -687,60 +618,74 @@ return std::make_unique(); } +// TODO: gradually remove this layer as more ops become "named". +template +Optional linalgOpToLoopsImplSwitch(Operation *op, + OpBuilder &builder) { + assert(isa(op) && "LinalgOp expected"); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + + // TODO: Cases below are generic and need a LinalgStructuredOpInterface. + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); + llvm_unreachable("Unexpected op in linalgOpToLoopsImpl"); +} + /// Emits a loop nest with the proper body for `op`. -template +template Optional mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op) { - return linalgOpToLoopsImpl(op, builder); + return linalgOpToLoopsImplSwitch(op, builder); } -/// Emits a loop nest of `scf.for` with the proper body for `op`. -template -LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { - Optional loops = - linalgLowerOpToLoops(builder, op); - return loops ? success() : failure(); -} +template Optional +mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op); +template Optional +mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op); +template Optional +mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op); /// Emits a loop nest of `affine.for` with the proper body for `op`. -template LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op) { - Optional loops = - linalgLowerOpToLoops(builder, op); + Optional loops = linalgLowerOpToLoops(builder, op); + return loops ? success() : failure(); +} + +/// Emits a loop nest of `scf.for` with the proper body for `op`. +LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { + Optional loops = linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } /// Emits a loop nest of `scf.parallel` with the proper body for `op`. -template LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op) { Optional loops = - linalgLowerOpToLoops(builder, op); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } - -// TODO Need to make these instantiations more future-proof to avoid the need to -// update as soon as we add new ops. 
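The parallel entry point, `linalgOpToParallelLoops`, maps the contiguous leading parallel iterators to a single `scf.parallel` and keeps the remaining (reduction) iterators as sequential `scf.for` loops. For matmul (parallel, parallel, reduction) that gives, roughly (a sketch with illustrative names):

```mlir
// %M, %N, %K are dim ops on the operands, as in the scf sketch above.
scf.parallel (%i, %j) = (%c0, %c0) to (%M, %N) step (%c1, %c1) {
  scf.for %k = %c0 to %K step %c1 {
    %a = load %A[%i, %k] : memref<?x?xf32>
    %b = load %B[%k, %j] : memref<?x?xf32>
    %c = load %C[%i, %j] : memref<?x?xf32>
    %0 = mulf %a, %b : f32
    %1 = addf %c, %0 : f32
    store %1, %C[%i, %j] : memref<?x?xf32>
  }
}
```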
-#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \ - template LogicalResult mlir::linalg::linalgOpToLoops( \ - OpBuilder & builder, Operation * op); \ - template LogicalResult mlir::linalg::linalgOpToAffineLoops( \ - OpBuilder & builder, Operation * op); \ - template LogicalResult mlir::linalg::linalgOpToParallelLoops( \ - OpBuilder & builder, Operation * op); \ - template Optional \ - mlir::linalg::linalgLowerOpToLoops( \ - OpBuilder & builder, Operation * op); - -INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp) -INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp) -INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp) -INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp) -INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp) -INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp) -INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMaxOp) -INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMinOp) -INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingSumOp) -INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp) -INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp) diff --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir --- a/mlir/test/Dialect/Linalg/affine.mlir +++ b/mlir/test/Dialect/Linalg/affine.mlir @@ -15,7 +15,7 @@ %A = view %arg0[%c0][%M, %K] : memref to memref %B = view %arg0[%c0][%K, %N] : memref to memref %C = view %arg0[%c0][%M, %N] : memref to memref - linalg.matmul(%A, %B, %C) : memref, memref, memref + linalg.matmul %A, %B, %C : (memref, memref, memref) return } diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -14,8 +14,8 @@ // CHECK: linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref %4 = linalg.slice %3[%r0, %r0] : memref, !linalg.range, !linalg.range, memref - // CHECK: linalg.matmul{{.*}}: memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32> - linalg.matmul(%3, %3, %3) : memref, memref, memref + // CHECK: linalg.matmul{{.*}}: (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>) + linalg.matmul %3, %3, %3 : (memref, memref, memref) return %4: memref } diff --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir --- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir +++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir @@ -12,7 +12,7 @@ %0 = dim %C, %c0 : memref %1 = dim %C, %c1 : memref %2 = dim %D, %c1 : memref - linalg.matmul(%A, %B, %C) : memref, memref, memref + linalg.matmul %A, %B, %C : (memref, memref, memref) scf.for %arg5 = %c0 to %0 step %c20 { scf.for %arg6 = %c0 to %2 step %c30 { scf.for %arg7 = %c0 to %1 step %c40 { @@ -28,7 +28,7 @@ %14 = std.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref to memref %16 = std.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref to memref %17 = std.subview %8[%arg8, %arg9][%c2, %c4][%c1, %c1] : memref to memref - linalg.matmul(%14, %16, %17) : memref, memref, memref + linalg.matmul %14, %16, %17 : (memref, memref, memref) } } } diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -14,10 +14,10 @@ %0 = dim %A, %c0 : memref %1 = dim %A, %c1 : memref %2 = dim %B, %c1 : memref - linalg.matmul(%A, %B, %C) : - memref, - memref, - memref + linalg.matmul %A, %B, %C : + (memref, + memref, + memref) scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { scf.for %arg7 = %c0 to %1 step %c4 { @@ -30,10 +30,10 @@ %8 = std.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] : 
memref to memref - linalg.matmul(%5, %7, %8) : - memref, - memref, - memref + linalg.matmul %5, %7, %8 : + (memref, + memref, + memref) } } } @@ -61,10 +61,10 @@ %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul(%A, %B, %C) : - memref, - memref, - memref + linalg.matmul %A, %B, %C : + (memref, + memref, + memref) %0 = dim %C, %c0 : memref %1 = dim %C, %c1 : memref %2 = dim %D, %c1 : memref @@ -80,10 +80,10 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : - memref, - memref, - memref + linalg.matmul %5, %7, %8 : + (memref, + memref, + memref) } } } @@ -113,10 +113,10 @@ %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul(%A, %B, %C) : - memref, - memref, - memref + linalg.matmul %A, %B, %C : + (memref, + memref, + memref) %0 = dim %D, %c0 : memref %1 = dim %D, %c1 : memref %2 = dim %C, %c1 : memref @@ -132,10 +132,10 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : - memref, - memref, - memref + linalg.matmul %5, %7, %8 : + (memref, + memref, + memref) } } } @@ -165,14 +165,14 @@ %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul(%A, %B, %C) : - memref, - memref, - memref - linalg.matmul(%A, %B, %D) : - memref, - memref, - memref + linalg.matmul %A, %B, %C : + (memref, + memref, + memref) + linalg.matmul %A, %B, %D : + (memref, + memref, + memref) %0 = dim %C, %c0 : memref %1 = dim %C, %c1 : memref %2 = dim %D, %c1 : memref @@ -188,10 +188,10 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : - memref, - memref, - memref + linalg.matmul %5, %7, %8 : + (memref, + memref, + memref) } } } @@ -227,14 +227,14 @@ %0 = dim %B, %c1 : memref %1 = dim %D, %c0 : memref %2 = dim %D, %c1 : memref - linalg.matmul(%A, %B, %C) : - memref, - memref, - memref - linalg.matmul(%C, %B, %D) : - memref, - memref, - memref + linalg.matmul %A, %B, %C : + (memref, + memref, + memref) + linalg.matmul %C, %B, %D : + (memref, + memref, + memref) scf.for %arg5 = %c0 to %1 step %c2 { scf.for %arg6 = %c0 to %0 step %c3 { scf.for %arg7 = %c0 to %2 step %c4 { @@ -247,10 +247,10 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : - memref, - memref, - memref + linalg.matmul %5, %7, %8 : + (memref, + memref, + memref) } } } @@ -275,9 +275,9 @@ // CHECK-DAG: %[[A_I0:.*]] = subview %[[A]][%[[I]], %{{.*}}] // CHECK-DAG: %[[B_00:.*]] = subview %[[B]][%{{.*}}, %{{.*}}] // CHECK-DAG: %[[C_I0_:.*]] = subview %[[C]][%[[I]], %{{.*}}] -// CHECK: linalg.matmul(%[[A_I0]], %[[B_00]], %[[C_I0_]]) -// CHECK: linalg.matmul(%[[C_I0]], %[[B_0K]], %[[D_IK_]]) -// CHECK: linalg.matmul(%[[D_IK]], %[[B_KJ]], %[[E_IJ]]) +// CHECK: linalg.matmul %[[A_I0]], %[[B_00]], %[[C_I0_]] +// CHECK: linalg.matmul %[[C_I0]], %[[B_0K]], %[[D_IK_]] +// CHECK: linalg.matmul %[[D_IK]], %[[B_KJ]], %[[E_IJ]] // ----- @@ -297,14 +297,14 @@ %c3 = constant 3 : index %c2 = constant 2 : index %0 = dim %C, %c1 : memref - linalg.matmul(%A, %B, %C) : - memref, - memref, - memref - linalg.matmul(%A, %C, %E) : - memref, - memref, - memref + linalg.matmul %A, %B, %C : + (memref, + memref, + memref) + linalg.matmul %A, %C, %E : + (memref, + memref, + memref) %1 = dim %C, %c0 : memref %2 = dim %D, %c1 : memref scf.for %arg5 = %c0 to %1 step %c2 { @@ -322,10 +322,10 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - 
linalg.matmul(%5, %7, %8) : - memref, - memref, - memref + linalg.matmul %5, %7, %8 : + (memref, + memref, + memref) } } } @@ -359,14 +359,14 @@ %2 = dim %C, %c1 : memref %3 = dim %C, %c0 : memref %4 = dim %D, %c1 : memref - linalg.matmul(%A, %C, %E) : - memref, - memref, - memref - linalg.matmul(%A, %B, %C) : - memref, - memref, - memref + linalg.matmul %A, %C, %E : + (memref, + memref, + memref) + linalg.matmul %A, %B, %C : + (memref, + memref, + memref) scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { scf.for %arg7 = %c0 to %1 step %c4 { @@ -379,10 +379,10 @@ %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%7, %9, %10) : - memref, - memref, - memref + linalg.matmul %7, %9, %10 : + (memref, + memref, + memref) } } } @@ -398,10 +398,10 @@ %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%7, %9, %10) : - memref, - memref, - memref + linalg.matmul %7, %9, %10 : + (memref, + memref, + memref) } } } @@ -414,7 +414,7 @@ // CHECK: %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref // CHECK: %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref // CHECK: %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref -// CHECK: linalg.matmul(%[[A]], %[[C]], %[[E]]) +// CHECK: linalg.matmul %[[A]], %[[C]], %[[E]] // CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} { // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} { @@ -445,14 +445,14 @@ %c2 = constant 2 : index %0 = dim %A, %c0 : memref %1 = dim %A, %c1 : memref - linalg.matmul(%A, %C, %D) : - memref, - memref, - memref - linalg.matmul(%A, %B, %C) : - memref, - memref, - memref + linalg.matmul %A, %C, %D : + (memref, + memref, + memref) + linalg.matmul %A, %B, %C : + (memref, + memref, + memref) %2 = dim %D, %c1 : memref scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { @@ -469,10 +469,10 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : - memref, - memref, - memref + linalg.matmul %5, %7, %8 : + (memref, + memref, + memref) } } } @@ -742,10 +742,10 @@ %B = alloca(%dim, %dim)[%s0, %s1] : memref %C = alloc(%dim, %dim)[%s0, %s1] : memref - linalg.matmul(%A, %B, %C) : - memref, - memref, - memref + linalg.matmul %A, %B, %C : + (memref, + memref, + memref) scf.for %i = %c0 to %dim step %c2 { scf.for %j = %c0 to %dim step %c3 { @@ -759,10 +759,10 @@ %2 = std.subview %C[%i, %j][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%0, %1, %2) : - memref, - memref, - memref + linalg.matmul %0, %1, %2 : + (memref, + memref, + memref) } } } diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -33,7 +33,7 @@ %A = view %arg0[%c0][%M, %K] : memref to memref %B = view %arg0[%c0][%K, %N] : memref to memref %C = view %arg0[%c0][%M, %N] : memref to memref - linalg.matmul(%A, %B, %C) : memref, memref, memref + linalg.matmul %A, %B, %C : (memref, memref, memref) return } // CHECKLOOP-LABEL: func @matmul(%{{.*}}: memref, diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -26,7 +26,10 @@ %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref %17 = std.subview %5[%arg4, %arg5][%c2, 
%c3][1, 1] : memref to memref - linalg.matmul(%11, %14, %17) : memref, memref, memref + linalg.matmul %11, %14, %17 : + (memref, + memref, + memref) } } } @@ -60,9 +63,14 @@ // CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref, memref // CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref, memref // -// CHECK: linalg.matmul(%[[partialA]], %[[partialB]], %[[partialC]]) : memref, memref, memref +// CHECK: linalg.matmul %[[partialA]], %[[partialB]], %[[partialC]] : +// CHECK: memref, +// CHECK: memref, +// CHECK: memref // -// CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref, memref +// CHECK: linalg.copy(%[[partialC]], %[[vC]]) : +// CHECK: memref, +// CHECK: memref // // CHECK: dealloc %[[tmpA]] : memref<32xi8> // CHECK: dealloc %[[tmpB]] : memref<48xi8> @@ -88,7 +96,10 @@ %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref - linalg.matmul(%11, %14, %17) : memref, memref, memref + linalg.matmul %11, %14, %17 : + (memref, + memref, + memref) } } } @@ -122,72 +133,15 @@ // CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref, memref // CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref, memref // -// CHECK: linalg.matmul(%[[partialA_f64]], %[[partialB_f64]], %[[partialC_f64]]) : memref, memref, memref +// CHECK: linalg.matmul %[[partialA_f64]], %[[partialB_f64]], %[[partialC_f64]] : +// CHECK: memref, +// CHECK: memref, +// CHECK: memref // -// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref, memref +// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : +// CHECK: memref, +// CHECK: memref // // CHECK: dealloc %[[tmpA_f64]] : memref<64xi8> // CHECK: dealloc %[[tmpB_f64]] : memref<96xi8> // CHECK: dealloc %[[tmpC_f64]] : memref<48xi8> - -// ----- - -func @matmul_i32(%A: memref, %M: index, %N: index, %K: index) { - %c4 = constant 4 : index - %c3 = constant 3 : index - %c2 = constant 2 : index - %c0 = constant 0 : index - %c1 = constant 1 : index - %3 = view %A[%c0][%M, %K] : memref to memref - %4 = view %A[%c0][%K, %N] : memref to memref - %5 = view %A[%c0][%M, %N] : memref to memref - %6 = dim %3, %c0 : memref - %7 = dim %3, %c1 : memref - %8 = dim %4, %c1 : memref - scf.for %arg4 = %c0 to %6 step %c2 { - scf.for %arg5 = %c0 to %8 step %c3 { - scf.for %arg6 = %c0 to %7 step %c4 { - %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref - %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref - %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref - linalg.matmul(%11, %14, %17) : memref, memref, memref - } - } - } - return -} - -// CHECK-LABEL: func @matmul_i32(%{{.*}}: memref, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { -// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { -// CHECK: %[[vA_i32:.*]] = subview {{.*}} : memref -// CHECK: %[[vB_i32:.*]] = subview {{.*}} : memref -// CHECK: %[[vC_i32:.*]] = subview {{.*}} : memref -/// -// CHECK: %[[tmpA_i32:.*]] = alloc() : memref<32xi8> -// CHECK: %[[fullA_i32:.*]] = std.view %[[tmpA_i32]][{{.*}}][{{.*}}] : memref<32xi8> to memref -// DYNAMIC: std.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref -// CHECK: %[[partialA_i32:.*]] = subview %[[fullA_i32]][%{{.*}}, %{{.*}}] : memref to memref -/// -// CHECK: %[[tmpB_i32:.*]] = alloc() : memref<48xi8> -// CHECK: %[[fullB_i32:.*]] 
= std.view %[[tmpB_i32]][{{.*}}][{{.*}}] : memref<48xi8> to memref -// DYNAMIC: std.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref -// CHECK: %[[partialB_i32:.*]] = subview %[[fullB_i32]][%{{.*}}, %{{.*}}] : memref to memref -/// -// CHECK: %[[tmpC_i32:.*]] = alloc() : memref<24xi8> -// CHECK: %[[fullC_i32:.*]] = std.view %[[tmpC_i32]][{{.*}}][{{.*}}] : memref<24xi8> to memref -// DYNAMIC: std.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref -// CHECK: %[[partialC_i32:.*]] = subview %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref to memref - -// CHECK: linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref, memref -// CHECK: linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref, memref -// CHECK: linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref, memref -// -// CHECK: linalg.matmul(%[[partialA_i32]], %[[partialB_i32]], %[[partialC_i32]]) : memref, memref, memref -// -// CHECK: linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref, memref -// -// CHECK: dealloc %[[tmpA_i32]] : memref<32xi8> -// CHECK: dealloc %[[tmpB_i32]] : memref<48xi8> -// CHECK: dealloc %[[tmpC_i32]] : memref<24xi8> diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir --- a/mlir/test/Dialect/Linalg/promotion_options.mlir +++ b/mlir/test/Dialect/Linalg/promotion_options.mlir @@ -2,8 +2,8 @@ func @gemm(%a : memref, %b : memref, %c : memref) { - linalg.matmul(%a, %b, %c) {__internal_linalg_transform__ = "START"} - : memref, memref, memref + linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "START"} + : (memref, memref, memref) return } @@ -26,7 +26,7 @@ // CHECK: linalg.copy(%[[T7]], %[[T19]]) // CHECK: linalg.fill(%[[T21]], %[[C42]]) // CHECK: linalg.copy(%[[T17]], %[[T21]]) -// CHECK: linalg.matmul(%[[T19]], %[[T12]], %[[T21]]) +// CHECK: linalg.matmul %[[T19]], %[[T12]], %[[T21]] // CHECK-NOT: linalg.fill // CHECK: linalg.copy(%[[T21]], %[[T17]]) // CHECK: dealloc %[[T18]] diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -83,9 +83,9 @@ %arg1: memref, %arg2: memref, %arg3: memref) { - linalg.matmul(%arg0, %arg0, %arg0) : memref, + linalg.matmul %arg0, %arg0, %arg0 : (memref, memref, - memref + memref) linalg.matvec(%arg0, %arg1, %arg2) : memref, memref, memref @@ -95,10 +95,10 @@ return } // CHECK-LABEL: func @ops(% -// CHECK-NEXT: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : +// CHECK-NEXT: linalg.matmul %{{.*}}, %{{.*}}, %{{.*}} : +// CHECK-SAME: (memref, // CHECK-SAME: memref, -// CHECK-SAME: memref, -// CHECK-SAME: memref +// CHECK-SAME: memref) // CHECK-NEXT: linalg.matvec(%{{.*}}, %{{.*}}, %{{.*}}) : // CHECK-SAME: memref, // CHECK-SAME: memref, diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -20,12 +20,21 @@ // TILE-234-DAG: #[[$bound_map_3:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)> // TILE-234-DAG: #[[$bound_map_4:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> +// TILE-2-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 10)> +// TILE-02-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 12)> +// TILE-002-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 16)> + // TILE-2-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> // TILE-02-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> // 
TILE-234-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> -func @matmul(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.matmul(%arg0, %arg1, %arg2) : memref, memref, memref +func @matmul(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.matmul %arg0, %arg1, %arg2 : + (memref, + memref, + memref) return } // TILE-2-LABEL: func @matmul( @@ -41,7 +50,10 @@ // TILE-2: %[[szK:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localK]]] // TILE-2: %[[N:.*]] = dim %{{.*}}, %c1 : memref // TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[szK]], %[[N]]] [1, 1] : memref to memref -// TILE-2: linalg.matmul(%[[sAi]], %{{.*}}, %[[sCi]]) : memref, memref, memref +// TILE-2: linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]] : +// TILE-2: (memref, +// TILE-2: memref, +// TILE-2: memref) // TILE-02-LABEL: func @matmul( // TILE-02-DAG: %[[C0:.*]] = constant 0 : index @@ -56,7 +68,10 @@ // TILE-02: %[[localK:.*]] = dim %{{.*}}, %c1 // TILE-02: %[[szK:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localK]]] // TILE-02: %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szK]]] [1, 1] : memref to memref -// TILE-02: linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) : memref, memref, memref +// TILE-02: linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] : +// TILE-02: (memref, +// TILE-02: memref, +// TILE-02: memref) // TILE-002-LABEL: func @matmul( // TILE-002-DAG: %[[C0:.*]] = constant 0 : index @@ -71,7 +86,10 @@ // TILE-002: %[[szK:.*]] = affine.min #[[$bound_map]](%[[K]])[%[[localK]]] // TILE-002: %[[N:.*]] = dim %{{.*}}, %c1 : memref // TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[szK]], %[[N]]] [1, 1] : memref to memref -// TILE-002: linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) : memref, memref, memref +// TILE-002: linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} : +// TILE-002: (memref, +// TILE-002: memref, +// TILE-002: memref) // TILE-234-LABEL: func @matmul( // TILE-234-DAG: %[[C0:.*]] = constant 0 : index @@ -100,14 +118,22 @@ // TILE-234: %[[szN:.*]] = affine.min #[[$bound_map_3]](%[[J]])[%[[localN]]] // TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [1, 1] : memref to memref // -// TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref, memref, memref +// TILE-234: linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] : +// TILE-234: (memref, +// TILE-234: memref, +// TILE-234: memref) // When the buffer shapes are known at compile time, it is possible to avoid // the "min" in subview size computation. This test uses buffer sizes divisible // by respective tile sizes (M=10 divisble by 2, N=12 divisible by 2 and 3, // K=16 divisble by 2 and 4). 
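The new `$bound_map_static` captures above pin down the tile-size min maps that appear when the shapes are static, e.g. `affine_map<(d0) -> (2, -d0 + 10)>` for tile size 2 on the M=10 dimension. Because 10 is divisible by 2, the min always evaluates to the tile size, which is what lets later folding remove the min and give the subview a static size; a minimal sketch of the pattern the checks match:

```mlir
#bound_map_static = affine_map<(d0) -> (2, -d0 + 10)>

scf.for %i = %c0 to %c10 step %c2 {
  // Always 2 here, so the min is foldable and the subsequent subview can get
  // a static size.
  %szM = affine.min #bound_map_static(%i)
}
```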
-func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>, %arg1: memref<16x12xf32, offset: ?, strides: [?, 1]>, %arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>) {
-  linalg.matmul(%arg0, %arg1, %arg2) : memref<10x16xf32, offset: ?, strides: [?, 1]>, memref<16x12xf32, offset: ?, strides: [?, 1]>, memref<10x12xf32, offset: ?, strides: [?, 1]>
+func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
+                    %arg1: memref<16x12xf32, offset: ?, strides: [?, 1]>,
+                    %arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>) {
+  linalg.matmul %arg0, %arg1, %arg2 :
+    (memref<10x16xf32, offset: ?, strides: [?, 1]>,
+     memref<16x12xf32, offset: ?, strides: [?, 1]>,
+     memref<10x12xf32, offset: ?, strides: [?, 1]>)
   return
 }
 // TILE-2-LABEL: func @matmul_static(
@@ -118,33 +144,39 @@
 // TILE-2-DAG: %[[C2:.*]] = constant 2 : index
 // TILE-2-DAG: %[[M:.*]] = constant 10 : index
 // TILE-2: scf.for %[[I:.*]] = %{{.*}} to %[[M]] step %{{.*}} {
-// TILE-2: %[[MIN2:.*]] = affine.min #map2(%[[I]])
+// TILE-2: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[I]])
 // TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN2]], 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<?x16xf32, #[[$strided2D]]>
-// TILE-2: %[[MIN22:.*]] = affine.min #map2(%[[I]])
+// TILE-2: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[I]])
 // TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN22]], 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
-// TILE-2: linalg.matmul(%[[sAi]], %{{.*}}, %[[sCi]])
+// TILE-2: linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]]
 // TILE-02-LABEL: func @matmul_static(
 // TILE-02-DAG: %[[C0:.*]] = constant 0 : index
 // TILE-02-DAG: %[[C2:.*]] = constant 2 : index
 // TILE-02-DAG: %[[N:.*]] = constant 12 : index
 // TILE-02: scf.for %[[J:.*]] = %{{.*}} to %[[N]] step %{{.*}} {
-// TILE-02: %[[MIN2:.*]] = affine.min #map2(%[[J]])
+// TILE-02: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[J]])
 // TILE-02: %[[sBj:.*]] = subview %{{.*}}[0, %[[J]]] [16, %[[MIN2]]] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x?xf32, #[[$strided2D]]>
-// TILE-02: %[[MIN22:.*]] = affine.min #map2(%[[J]])
+// TILE-02: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[J]])
 // TILE-02: %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [10, %[[MIN22]]] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
-// TILE-02: linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) : memref<10x16xf32, #[[$strided2D]]>, memref<16x?xf32, #[[$strided2D]]>, memref<10x?xf32, #[[$strided2D]]>
+// TILE-02: linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] :
+// TILE-02: (memref<10x16xf32, #[[$strided2D]]>,
+// TILE-02: memref<16x?xf32, #[[$strided2D]]>,
+// TILE-02: memref<10x?xf32, #[[$strided2D]]>)
 // TILE-002-LABEL: func @matmul_static(
 // TILE-002-DAG: %[[C0:.*]] = constant 0 : index
 // TILE-002-DAG: %[[C2:.*]] = constant 2 : index
 // TILE-002-DAG: %[[C16:.*]] = constant 16 : index
 // TILE-002: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} {
-// TILE-002: %[[MIN2:.*]] = affine.min #map2(%[[K]])
+// TILE-002: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[K]])
 // TILE-002: %[[sAj:.*]] = subview %{{.*}}[0, %[[K]]] [10, %[[MIN2]]] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
-// TILE-002: %[[MIN22:.*]] = affine.min #map2(%[[K]])
+// TILE-002: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[K]])
 // TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[MIN22]], 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
-// TILE-002: linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) : memref<10x?xf32, #[[$strided2D]]>, memref<?x12xf32, #[[$strided2D]]>, memref<10x12xf32, #[[$strided2D]]>
+// TILE-002: linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} :
+// TILE-002: (memref<10x?xf32, #[[$strided2D]]>,
+// TILE-002: memref<?x12xf32, #[[$strided2D]]>,
+// TILE-002: memref<10x12xf32, #[[$strided2D]]>)
 // TILE-234-LABEL: func @matmul_static(
 // TILE-234-DAG: %[[C0:.*]] = constant 0 : index
@@ -161,7 +193,10 @@
 // TILE-234: %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
 // TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
 //
-// TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
+// TILE-234: linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] :
+// TILE-234: (memref<?x?xf32, #[[$strided2D]]>,
+// TILE-234: memref<?x?xf32, #[[$strided2D]]>,
+// TILE-234: memref<?x?xf32, #[[$strided2D]]>)

 func @matvec(%arg0: memref, %arg1: memref, %arg2: memref) {
   linalg.matvec(%arg0, %arg1, %arg2) : memref, memref, memref
diff --git a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
--- a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
+++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
@@ -6,8 +6,8 @@
              %arg1 : memref,
              %arg2 : memref)
 {
-  linalg.matmul(%arg0, %arg1, %arg2)
-    : memref, memref, memref
+  linalg.matmul %arg0, %arg1, %arg2
+    : (memref, memref, memref)
   return
 }
 // CHECK-LABEL: func @gemm
@@ -21,7 +21,7 @@
 // CHECK: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
 // CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG5]], %[[ARG4]]]
 // CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
-// CHECK: linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]])
+// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
 // TILE1-LABEL: func @gemm
 // TILE1-DAG: %[[C2:.*]] = constant 2 : index
@@ -30,7 +30,7 @@
 // TILE1: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
 // TILE1: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], 0]
 // TILE1-NOT: subview
-// TILE1: linalg.matmul(%[[SV1]], %{{.*}}, %[[SV3]])
+// TILE1: linalg.matmul %[[SV1]], %{{.*}}, %[[SV3]]
 // TILE2-LABEL: func @gemm
 // TILE2-DAG: %[[C2:.*]] = constant 2 : index
@@ -40,7 +40,7 @@
 // TILE2: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
 // TILE2: %[[SV2:.*]] = subview %{{.*}}[0, %[[ARG4]]]
 // TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
-// TILE2: linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]])
+// TILE2: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
 // -----
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -4,10 +4,10 @@
 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
              %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
              %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
-  linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "START"} :
-    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
-    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
-    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
+  linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "START"} :
+    (memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+     memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+     memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
   return
 }
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -53,10 +53,10 @@
 func @matmul(%A: memref,
              %B: memref,
              %C: memref) {
-  linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "MEM" } :
-    memref,
-    memref,
-    memref
+  linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "MEM" } :
+    (memref,
+     memref,
+     memref)
   return
 }
 // CHECK-LABEL: func @matmul
@@ -85,7 +85,10 @@
 // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] {
 // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] {
 // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
-// CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref
+// CHECK: linalg.matmul {{.*}}, {{.*}}, {{.*}} : (
+// CHECK: memref,
+// CHECK: memref,
+// CHECK: memref)
 #matmul_trait = {
   args_in = 2,
@@ -117,8 +120,8 @@
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                            %C: memref<8x32xf32>) {
-  linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "VECTORIZE"} :
-    memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
+  linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "VECTORIZE"} :
+    (memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>)
   return
 }
 // CHECK-LABEL: func @vectorization_test_2
@@ -216,10 +219,10 @@
 func @matmul_perm(%A: memref,
                   %B: memref,
                   %C: memref) {
-  linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} :
-    memref,
-    memref,
-    memref
+  linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "__with_perm__"} :
+    (memref,
+     memref,
+     memref)
   return
 }
 // CHECK-LABEL: func @matmul_perm
@@ -242,7 +245,10 @@
 // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
 // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
 // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
-// CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref
+// CHECK: linalg.matmul {{.*}}, {{.*}}, {{.*}} : (
+// CHECK: memref,
+// CHECK: memref,
+// CHECK: memref)
 func @promote_subview_matmul(%arg0: memref,
                              %arg1: memref,
@@ -264,10 +270,10 @@
             memref to memref
         %5 = subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
             memref to memref
-        linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_views_"} :
-          memref,
-          memref,
-          memref
+        linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_views_"} :
+          (memref,
+           memref,
+           memref)
       }
     }
   }
@@ -296,7 +302,8 @@
 // CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref, memref
 // CHECK: linalg.copy(%[[s1]], %[[l1]]) : memref, memref
 // CHECK: linalg.copy(%[[s2]], %[[l2]]) : memref, memref
-// CHECK: linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref, memref, memref
+// CHECK: linalg.matmul %[[v0]], %[[v1]], %[[v2]] :
+// CHECK: (memref, memref, memref)
 func @promote_first_subview_matmul(%arg0: memref,
                                    %arg1: memref,
                                    %arg2: memref) {
@@ -318,10 +325,10 @@
             memref to memref
         %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
             memref to memref
-        linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_first_view_"} :
-          memref,
-          memref,
-          memref
+        linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_first_view_"} :
+          (memref,
+           memref,
+           memref)
      }
    }
  }
@@ -350,7 +357,10 @@
 // CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref, memref
 // CHECK-NOT: linalg.copy(%[[s1]], %[[l1]]) : memref, memref
 // CHECK-NOT: linalg.copy(%[[s2]], %[[l2]]) : memref, memref^
-// CHECK: linalg.matmul(%[[v0]], %[[s1]], %[[s2]]) : memref, memref, memref
+// CHECK: linalg.matmul %[[v0]], %[[s1]], %[[s2]] :
+// CHECK: (memref,
+// CHECK: memref,
+// CHECK: memref)
 func @aligned_promote_fill(%arg0: memref) {
   %c2000 = constant 2000 : index
@@ -377,8 +387,8 @@
 func @tile_permute_parallel_loop(%arg0: memref,
                                  %arg1: memref,
                                  %arg2: memref) {
-  linalg.matmul(%arg0, %arg1, %arg2) {__internal_linalg_transform__ = "par__with_perm__"}
-    : memref, memref, memref
+  linalg.matmul %arg0, %arg1, %arg2 {__internal_linalg_transform__ = "par__with_perm__"}
+    : (memref, memref, memref)
   return
 }
 // CHECK-LABEL: func @tile_permute_parallel_loop
diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
@@ -83,7 +83,7 @@
   %B = view %bB[%c0][%c16, %c2] : memref to memref
   %C = view %bC[%c0][%c2, %c2] : memref to memref
-  linalg.matmul(%A, %B, %C) : memref, memref, memref
+  linalg.matmul %A, %B, %C : (memref, memref, memref)
   %res = load %C[%c0, %c1] : memref
   dealloc %bC : memref
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1474,6 +1474,10 @@
                       TypeRange inputTypes, TypeRange outputTypes);
     static void regionBuilder(Block &block);
+
+    std::string getLibraryCallName() {{
+      return generateLibraryCallName(getOperation());
+    }
   }];
 })FMT";