diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -351,10 +351,11 @@ class NamedStructuredOpTraits : public OpTrait::TraitBase { public: - llvm::Optional> referenceIterators(); - llvm::Optional> referenceIndexingMaps(); - std::function)> - emitScalarImplementation(); + static SmallVector referenceIterators(TypeRange inputTypes, + TypeRange outputTypes); + + static SmallVector referenceIndexingMaps(TypeRange inputTypes, + TypeRange outputTypes); }; } // namespace linalg diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -33,10 +33,9 @@ /// Forward declarations. template -static void buildNamedStructuredOpRegion(Builder &builder, - OperationState &result, - TypeRange operandTypes, - TypeRange tensorResultTypes); +static void buildNamedStructuredOpRegionAndAttributes( + Builder &builder, OperationState &result, TypeRange operandTypes, + TypeRange tensorResultTypes); template static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); template @@ -1085,9 +1084,10 @@ //===----------------------------------------------------------------------===// template -void buildNamedStructuredOpRegion(Builder &builder, OperationState &result, - TypeRange operandTypes, - TypeRange tensorResultTypes) { +void buildNamedStructuredOpRegionAndAttributes(Builder &builder, + OperationState &result, + TypeRange operandTypes, + TypeRange tensorResultTypes) { Region ®ion = *result.addRegion(); Block *body = new Block(); // TODO: atm all operands go through getElementTypeOrSelf, @@ -1102,12 +1102,24 @@ opBuilder.setInsertionPointToStart(®ion.front()); mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc()); NamedStructuredOpType::regionBuilder(*body); + + auto indexingMaps = builder.getAffineMapArrayAttr( + NamedStructuredOpType::referenceIndexingMaps(operandTypes, + tensorResultTypes)); + result.addAttribute(getIndexingMapsAttrName(), indexingMaps); + + auto iterators = + builder.getStrArrayAttr(NamedStructuredOpType::referenceIterators( + operandTypes, tensorResultTypes)); + result.addAttribute(getIteratorTypesAttrName(), iterators); } template static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { + std::array silentAttrNames{getIndexingMapsAttrName(), + getIteratorTypesAttrName()}; p << op.getOperationName() << ' '; - p.printOptionalAttrDict(op.getAttrs()); + p.printOptionalAttrDict(op.getAttrs(), silentAttrNames); p << ' ' << op.getOperands(); p << ": (" << op.getOperandTypes() << ")"; auto outputTensorTypes = op.getResultTypes(); @@ -1139,7 +1151,7 @@ if (!tensorResultTypes.empty()) result.addTypes(tensorResultTypes); - buildNamedStructuredOpRegion( + buildNamedStructuredOpRegionAndAttributes( parser.getBuilder(), result, operandTypes, tensorResultTypes); return parser.resolveOperands(operandsInfo, operandTypes, diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -78,11 +78,10 @@ return res; } -template -static void -inlineRegionAndEmitStdStore(OpType op, ArrayRef indexedValues, - ArrayRef> indexing, - ArrayRef outputBuffers) { +template +static void 
inlineRegionAndEmitStore(OpType op, ArrayRef indexedValues, + ArrayRef> indexing, + ArrayRef outputBuffers) { auto &b = ScopedContext::getBuilder(); auto &block = op.region().front(); BlockAndValueMapping map; @@ -95,10 +94,10 @@ Operation &terminator = block.back(); assert(isa(terminator) && - "expected an yield op in the end of the region"); + "expected a yield op in the end of the region"); for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) { - std_store(map.lookupOrDefault(terminator.getOperand(i)), outputBuffers[i], - ArrayRef{indexing[i].begin(), indexing[i].end()}); + IndexedValueType O(outputBuffers[i]); + O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i)); } } @@ -123,9 +122,36 @@ namespace { -// Generic loop emitter, to be specialized on an op-per op basis. -// TODO: Hook up to named ops interface and, later, retire when all named ops -// are auto-generated. +/// Emits the MLIR for the scalar part of the generic op by: +/// 1. Emitting load ops for each input and output view in order. This is +/// achieved by applying the appropriate input or output map to the +/// enclosing induction variables. +/// 2. Emitting a call to `op.fun()` that takes as arguments the scalars +/// from point 1. above. +/// 3. Emitting store ops to store the results of 2. to the output +/// views. +/// +/// An example output may resemble: +/// +/// ``` +/// loop.for %i = %c0 to %0 step %c1 { +/// loop.for %j = %c0 to %1 step %c1 { +/// loop.for %k = %c0 to %4 step %c1 { +/// %11 = load %arg0[%i, %j] : +/// memref +/// %12 = load %arg1[%i, %j, %k] : +/// memref +/// %13 = load %arg2[%i, %k, %j] : +/// memref +/// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32) +/// store %14#0, %arg1[%i, %j, %k] : +/// memref +/// store %14#1, %arg2[%i, %k, %j] : +/// memref +/// } +/// } +/// } +/// ``` template class LinalgScopedEmitter { public: @@ -133,9 +159,43 @@ LinalgOpType linalgOp) { assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); - llvm_unreachable("NYI"); - linalgOp.emitScalarImplementation()(ScopedContext::getBuilder(), - ScopedContext::getLocation(), allIvs); + auto b = ScopedContext::getBuilder(); + auto loc = ScopedContext::getLocation(); + unsigned nInputs = linalgOp.getNumInputs(); + unsigned nOutputs = linalgOp.getNumOutputs(); + SmallVector indexedValues; + indexedValues.reserve(nInputs + nOutputs); + + // TODO(mravishankar): Avoid the loads if the corresponding argument of the + // region has no uses. + // 1.a. Emit load from input views. + for (unsigned i = 0; i < nInputs; ++i) { + auto indexing = makeCanonicalAffineApplies( + b, loc, linalgOp.getInputIndexingMap(i), allIvs); + // Passing through IndexedValueType emits the proper load operation. + indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing)); + } + // 1.b. Emit load from output views. + for (unsigned i = 0; i < nOutputs; ++i) { + auto indexing = makeCanonicalAffineApplies( + b, loc, linalgOp.getOutputIndexingMap(i), allIvs); + // Passing through IndexedValueType emits the proper load operation. + indexedValues.push_back( + IndexedValueType(linalgOp.getOutputBuffer(i))(indexing)); + } + + // TODO(ntv): When a region inliner exists, use it. + // 2. Inline region, currently only works for a single basic block. + // 3. Emit store. 
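+    // Note: the loads above and the stores emitted below by
+    // inlineRegionAndEmitStore both go through IndexedValueType, so the same
+    // emitter yields affine.load/affine.store when instantiated with an
+    // affine indexed value and plain load/store when instantiated with
+    // StdIndexedValue.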
+ SmallVector, 8> indexing; + SmallVector outputBuffers; + for (unsigned i = 0; i < nOutputs; ++i) { + indexing.push_back(makeCanonicalAffineApplies( + b, loc, linalgOp.getOutputIndexingMap(i), allIvs)); + outputBuffers.push_back(linalgOp.getOutputBuffer(i)); + } + inlineRegionAndEmitStore(linalgOp, indexedValues, + indexing, outputBuffers); } }; @@ -231,7 +291,7 @@ public: /// Returns the input value of convOp. If the indices in `imIdx` is out of /// boundary, returns 0 instead. - static Value getConvOpInput(ConvOp convOp, IndexedValueType im, + static Value getConvOpInput(ConvOp convOp, StdIndexedValue im, MutableArrayRef imIdx) { // TODO(ntv): add a level of indirection to linalg.generic. if (!convOp.padding()) @@ -293,7 +353,11 @@ makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); SmallVector oIdx( makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); - IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output()); + + // Padded conv involves an affine.max in the memory access which is not + // allowed by affine.load. Override to always use an StdIndexedValue. + StdIndexedValue I(convOp.input()); + IndexedValueType F(convOp.filter()), O(convOp.output()); // Emit scalar form. Value paddedInput = getConvOpInput(convOp, I, imIdx); @@ -344,111 +408,36 @@ } }; -// Emits the MLIR for the scalar part of the generic op by: -// 1. Emitting std_load and std_store ops for each input and output -// view in order. This is achieved by applying the appropriate input or -// output map to the enclosing induction variables. -// 2. Emitting a call to `op.fun()` that takes as arguments the scalars -// from point 1. above. -// 3. Emitting std_store to store the results of 2. to the output -// views. -// -// An example output may resemble: -// -// ``` -// loop.for %i = %c0 to %0 step %c1 { -// loop.for %j = %c0 to %1 step %c1 { -// loop.for %k = %c0 to %4 step %c1 { -// %11 = load %arg0[%i, %j] : -// memref -// %12 = load %arg1[%i, %j, %k] : -// memref -// %13 = load %arg2[%i, %k, %j] : -// memref -// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32) -// store %14#0, %arg1[%i, %j, %k] : -// memref -// store %14#1, %arg2[%i, %k, %j] : -// memref -// } -// } -// } -// ``` -template -class LinalgScopedEmitter { -public: - static void emitScalarImplementation(ArrayRef allIvs, - GenericOp genericOp) { - assert(genericOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - auto b = ScopedContext::getBuilder(); - auto loc = ScopedContext::getLocation(); - unsigned nInputs = genericOp.getNumInputs(); - unsigned nOutputs = genericOp.getNumOutputs(); - SmallVector indexedValues(nInputs + nOutputs); - - // 1.a. Emit std_load from input views. - for (unsigned i = 0; i < nInputs; ++i) { - auto indexing = makeCanonicalAffineApplies( - b, loc, genericOp.getInputIndexingMap(i), allIvs); - indexedValues[i] = std_load(genericOp.getInput(i), indexing); - } - - // 1.b. Emit std_load from output views. - // TODO(mravishankar): Avoid the loads if the corresponding argument of the - // region has no uses. - for (unsigned i = 0; i < nOutputs; ++i) { - Value output = genericOp.getOutputBuffer(i); - auto indexing = makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs); - indexedValues[nInputs + i] = std_load(output, indexing); - } - - // TODO(ntv): When a region inliner exists, use it. - // 2. Inline region, currently only works for a single basic block. - // 3. Emit std_store. 
- SmallVector, 8> indexing; - SmallVector outputBuffers; - for (unsigned i = 0; i < nOutputs; ++i) { - indexing.push_back(makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - outputBuffers.push_back(genericOp.getOutputBuffer(i)); - } - inlineRegionAndEmitStdStore(genericOp, indexedValues, indexing, - outputBuffers); - } -}; - -// Emits the MLIR for the scalar part of the indexed generic op by: -// 1. Emitting std_load and std_store ops for each input and output view in -// order. This is achieved by applying the appropriate input or output map -// to the enclosing induction variables. -// 2. Emitting a call to `op.fun()` that takes as arguments the induction -// variables and the scalars from point 1. above. -// 3. Emitting std_store to store the results of 2. to the output views. -// -// An example output may resemble: -// -// ``` -// loop.for %i = %c0 to %0 step %c1 { -// loop.for %j = %c0 to %1 step %c1 { -// loop.for %k = %c0 to %4 step %c1 { -// %11 = load %arg0[%i, %j] : -// memref -// %12 = load %arg1[%i, %j, %k] : -// memref -// %13 = load %arg2[%i, %k, %j] : -// memref -// %14:2 = call @foo(%i, %j, %k, %11, %12, %13) : -// (index, index, index, f32, f32, f32) -> (f32, f32) -// store %14#0, %arg1[%i, %j, %k] : -// memref -// store %14#1, %arg2[%i, %k, %j] : -// memref -// } -// } -// } -// ``` +/// Emits the MLIR for the scalar part of the indexed generic op by: +/// 1. Emitting load ops for each input and output view in order. This is +/// achieved by applying the appropriate input or output map to the +/// enclosing induction variables. +/// 2. Emitting a call to `op.fun()` that takes as arguments the induction +/// variables and the scalars from point 1. above. +/// 3. Emitting store ops to store the results of 2. to the output views. +/// +/// An example output may resemble: +/// +/// ``` +/// loop.for %i = %c0 to %0 step %c1 { +/// loop.for %j = %c0 to %1 step %c1 { +/// loop.for %k = %c0 to %4 step %c1 { +/// %11 = load %arg0[%i, %j] : +/// memref +/// %12 = load %arg1[%i, %j, %k] : +/// memref +/// %13 = load %arg2[%i, %k, %j] : +/// memref +/// %14:2 = call @foo(%i, %j, %k, %11, %12, %13) : +/// (index, index, index, f32, f32, f32) -> (f32, f32) +/// store %14#0, %arg1[%i, %j, %k] : +/// memref +/// store %14#1, %arg2[%i, %k, %j] : +/// memref +/// } +/// } +/// } +/// ``` template class LinalgScopedEmitter { public: @@ -461,31 +450,33 @@ unsigned nInputs = indexedGenericOp.getNumInputs(); unsigned nOutputs = indexedGenericOp.getNumOutputs(); unsigned nLoops = allIvs.size(); - SmallVector indexedValues(nLoops + nInputs + nOutputs); - - for (unsigned i = 0; i < nLoops; ++i) { - indexedValues[i] = allIvs[i]; - } + SmallVector indexedValues; + indexedValues.reserve(nLoops + nInputs + nOutputs); + for (unsigned i = 0; i < nLoops; ++i) + indexedValues.push_back(allIvs[i]); - // 1.a. Emit std_load from input views. + // TODO(mravishankar): Avoid the loads if the corresponding argument of the + // region has no uses. + // 1.a. Emit load from input views. for (unsigned i = 0; i < nInputs; ++i) { - Value input = indexedGenericOp.getInput(i); auto indexing = makeCanonicalAffineApplies( b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs); - indexedValues[nLoops + i] = std_load(input, indexing); + // Pass input i through IndexedValueType emits the proper load operation. + indexedValues.push_back( + IndexedValueType(indexedGenericOp.getInput(i))(indexing)); } - - // 1.b. Emit std_load from output views. + // 1.b. Emit load from output views. 
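+    // The current output element is loaded as well: the trailing block
+    // arguments of the region receive these values, which is what enables
+    // read-modify-write updates such as accumulation.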
for (unsigned i = 0; i < nOutputs; ++i) { - Value output = indexedGenericOp.getOutputBuffer(i); auto indexing = makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs); - indexedValues[nLoops + nInputs + i] = std_load(output, indexing); + // Pass output i through IndexedValueType emits the proper load operation. + indexedValues.push_back( + IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing)); } // TODO(ntv): When a region inliner exists, use it. // 2. Inline region, currently only works for a single basic block. - // 3. Emit std_store. + // 3. Emit store. SmallVector, 8> indexing; SmallVector outputBuffers; for (unsigned i = 0; i < nOutputs; ++i) { @@ -493,19 +484,19 @@ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i)); } - inlineRegionAndEmitStdStore(indexedGenericOp, indexedValues, indexing, - outputBuffers); + inlineRegionAndEmitStore(indexedGenericOp, indexedValues, + indexing, outputBuffers); } }; -// This struct is for factoring out the implementation and support template -// instantiations in the following 2 cases: -// 1. Appending to a list of patterns via RewritePatternList. -// 2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`. -// The implementation must work both in DRR and inside a RewritePattern. As a -// consequence, (1) it is only allowed to emit new ops if the match is -// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an -// encompassing pattern must take care of the erasure logic. +/// This struct is for factoring out the implementation and support template +/// instantiations in the following 2 cases: +/// 1. Appending to a list of patterns via RewritePatternList. +/// 2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`. +/// The implementation must work both in DRR and inside a RewritePattern. As a +/// consequence, (1) it is only allowed to emit new ops if the match is +/// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an +/// encompassing pattern must take care of the erasure logic. template class LinalgOpToLoopsImpl { public: @@ -532,7 +523,7 @@ } }; -/// Generates loops nest using loop.parallel. loop.parallel is only used for the +/// Generates loop nest using loop.parallel. loop.parallel is only used for the /// outer parallel loops. All other loops are generated using loop.for /// operation. template @@ -652,7 +643,7 @@ } }; -// Helper classes for type list expansion. +/// Helper classes for type list expansion. template class RewritePatternList; @@ -680,16 +671,16 @@ >::build(patterns, ctx); } -// Local folding pattern for AffineApplyOp that we can apply greedily. -// This replaces AffineApplyOp by the proper value in cases where the associated -// map is trivial. A trivial map here is defined as a map with a single result -// and either: -// 1. Zero operand + returns a single AffineConstantExpr -// 2. One operand + returns a single AffineDimExpr -// 3. One operands + returns a single AffineSymbolExpr +/// Local folding pattern for AffineApplyOp that we can apply greedily. +/// This replaces AffineApplyOp by the proper value in cases where the +/// associated map is trivial. +/// A trivial map here is defined as a map with a single result and either: +/// 1. Zero operand + returns a single AffineConstantExpr +/// 2. One operand + returns a single AffineDimExpr +/// 3. 
One operand + returns a single AffineSymbolExpr // -// In the first case, the AffineApplyOp is replaced by a new constant. In the -// other cases, it is replaced by its unique operand. +/// In the first case, the AffineApplyOp is replaced by a new constant. In the +/// other cases, it is replaced by its unique operand. struct FoldAffineOp : public RewritePattern { FoldAffineOp(MLIRContext *context) : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {} diff --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir --- a/mlir/test/Dialect/Linalg/affine.mlir +++ b/mlir/test/Dialect/Linalg/affine.mlir @@ -1,13 +1,15 @@ // RUN: mlir-opt %s -convert-linalg-to-affine-loops | FileCheck %s // Test that we can lower all the way to LLVM without crashing, don't check results here. -// RUN: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1 +// RUN: mlir-opt %s -convert-linalg-to-affine-loops -convert-linalg-to-llvm -o=/dev/null 2>&1 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> // CHECK-DAG: #[[stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> +// CHECK-DAG: #[[clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)> + func @matmul(%arg0: memref, %M: index, %N: index, %K: index) { %c0 = constant 0 : index %c1 = constant 1 : index @@ -53,3 +55,69 @@ // CHECK: affine.for %{{.*}} = 0 to %[[Q]] { // CHECK: affine.for %{{.*}} = 0 to %[[Z0]] { // CHECK: %[[SUM:.*]] = affine.apply #[[stride2Dilation1]](%{{.*}}, %{{.*}}) + +func @conv_padding(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], + padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>, + strides = [1, 1]} : + memref, memref, memref + return +} +// CHECK-LABEL: func @conv_padding +// CHECK: %{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) { +// CHECK: %[[ZERO:.*]] = constant 0.000000e+00 : f32 +// CHECK: %[[Z0:.*]] = dim %arg0, 0 : memref +// CHECK: %[[Z1:.*]] = dim %arg0, 1 : memref +// CHECK: %[[Q:.*]] = dim %arg0, 2 : memref +// CHECK: %[[K:.*]] = dim %arg0, 3 : memref +// CHECK: %[[B:.*]] = dim %arg1, 0 : memref +// CHECK: %[[X0:.*]] = dim %arg2, 1 : memref +// CHECK: %[[X1:.*]] = dim %arg2, 2 : memref +// CHECK: affine.for %{{.*}} = 0 to %[[B]] { +// CHECK: affine.for %{{.*}} = 0 to %[[X0]] { +// CHECK: affine.for %{{.*}} = 0 to %[[X1]] { +// CHECK: affine.for %{{.*}} = 0 to %[[K]] { +// CHECK: affine.for %{{.*}} = 0 to %[[Q]] { +// CHECK: affine.for %{{.*}} = 0 to %[[Z0]] { +// CHECK: affine.for %{{.*}} = 0 to %[[Z1]] { +// CHECK: %[[SUM0:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}}) +// CHECK: %[[SUM1:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}}) +// CHECK: %[[IDX:.*]] = affine.max #[[clampMinMap]](%[[SUM0]]) +// CHECK: %[[IDY:.*]] = affine.max #[[clampMinMap]](%[[SUM1]]) +// Padded conv involves an affine.max in the memory access which is not +// allowed by affine.load. Override to always use an std.load. 
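+// Only the padded input access below goes through a plain (std) load; the
+// filter and output accesses still lower to affine.load / affine.store.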
+// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %[[IDX]], %[[IDY]], %{{.*}}] : memref +// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32 +// CHECK: %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref +// CHECK: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32 +// CHECK: %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref +// CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 +// CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref + +//----------------------------------------------------------------------------// +// Named ops to loops. +//----------------------------------------------------------------------------// +func @named_batch_matmul(%A: memref, %B: memref, %C: memref) { + linalg.batch_matmul %A, %B, %C : (memref, memref, memref) -> () + return +} +// CHECK-LABEL: @named_batch_matmul +// CHECK-SAME: %[[mA:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[mB:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[mC:[a-zA-Z0-9]+]]: memref +// CHECK: %[[B:.*]] = dim %[[mA]], 0 : memref +// CHECK: %[[M:.*]] = dim %[[mA]], 1 : memref +// CHECK: %[[K:.*]] = dim %[[mA]], 2 : memref +// CHECK: %[[N:.*]] = dim %[[mB]], 2 : memref +// CHECK: affine.for %[[b:.*]] = 0 to %[[B]] { +// CHECK: affine.for %[[m:.*]] = 0 to %[[M]] { +// CHECK: affine.for %[[n:.*]] = 0 to %[[N]] { +// CHECK: affine.for %[[k:.*]] = 0 to %[[K]] { +// CHECK: %[[va:.*]] = affine.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref +// CHECK: %[[vb:.*]] = affine.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref +// CHECK: %[[vc:.*]] = affine.load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref +// CHECK: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 +// CHECK: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECK: affine.store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -2,7 +2,7 @@ // RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s // Test that we can lower all the way to LLVM without crashing, don't check results here. -// RUN: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1 +// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -o=/dev/null 2>&1 // CHECKLOOP-DAG: #[[strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECKLOOP-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> @@ -354,7 +354,6 @@ // CHECKPARALLEL: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 // CHECKPARALLEL: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref - func @conv_padding(%arg0: memref, %arg1: memref, %arg2: memref) { @@ -854,8 +853,8 @@ // CHECKLOOP-SAME: %[[ARG1]]: memref // CHECKLOOP-SAME: %[[ARG2]]: memref // CHECKLOOP-NOT: loop.for -// CHECKLOOP-DAG: load %[[ARG0]][] -// CHECKLOOP-DAG: load %[[ARG1]][] +// CHECKLOOP: load %[[ARG0]][] +// CHECKLOOP: load %[[ARG1]][] // CHECKLOOP: addf // CHECKLOOP: store %{{.*}}, %[[ARG2]][] @@ -864,7 +863,50 @@ // CHECKPARALLEL-SAME: %[[ARG1]]: memref // CHECKPARALLEL-SAME: %[[ARG2]]: memref // CHECKPARALLEL-NOT: loop.for -// CHECKPARALLEL-DAG: load %[[ARG0]][] -// CHECKPARALLEL-DAG: load %[[ARG1]][] +// CHECKPARALLEL: load %[[ARG0]][] +// CHECKPARALLEL: load %[[ARG1]][] // CHECKPARALLEL: addf // CHECKPARALLEL: store %{{.*}}, %[[ARG2]][] + +//----------------------------------------------------------------------------// +// Named ops to loops. 
+//----------------------------------------------------------------------------// +func @named_batch_matmul(%A: memref, %B: memref, %C: memref) { + linalg.batch_matmul %A, %B, %C : (memref, memref, memref) -> () + return +} +// CHECKLOOP-LABEL: @named_batch_matmul +// CHECKLOOP-SAME: %[[mA:[a-zA-Z0-9]+]]: memref +// CHECKLOOP-SAME: %[[mB:[a-zA-Z0-9]+]]: memref +// CHECKLOOP-SAME: %[[mC:[a-zA-Z0-9]+]]: memref +// CHECKLOOP: %[[B:.*]] = dim %[[mA]], 0 : memref +// CHECKLOOP: %[[M:.*]] = dim %[[mA]], 1 : memref +// CHECKLOOP: %[[K:.*]] = dim %[[mA]], 2 : memref +// CHECKLOOP: %[[N:.*]] = dim %[[mB]], 2 : memref +// CHECKLOOP: loop.for %[[b:.*]] = %{{.*}} to %[[B]] step %{{.*}} { +// CHECKLOOP: loop.for %[[m:.*]] = %{{.*}} to %[[M]] step %{{.*}} { +// CHECKLOOP: loop.for %[[n:.*]] = %{{.*}} to %[[N]] step %{{.*}} { +// CHECKLOOP: loop.for %[[k:.*]] = %{{.*}} to %[[K]] step %{{.*}} { +// CHECKLOOP: %[[va:.*]] = load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref +// CHECKLOOP: %[[vb:.*]] = load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref +// CHECKLOOP: %[[vc:.*]] = load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref +// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 +// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKLOOP: store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref + +// CHECKPARALLEL-LABEL: @named_batch_matmul +// CHECKPARALLEL-SAME: %[[mA:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL-SAME: %[[mB:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL-SAME: %[[mC:[a-zA-Z0-9]+]]: memref +// CHECKPARALLEL: %[[B:.*]] = dim %[[mA]], 0 : memref +// CHECKPARALLEL: %[[M:.*]] = dim %[[mA]], 1 : memref +// CHECKPARALLEL: %[[K:.*]] = dim %[[mA]], 2 : memref +// CHECKPARALLEL: %[[N:.*]] = dim %[[mB]], 2 : memref +// CHECKPARALLEL: loop.parallel (%[[b:.*]], %[[m:.*]], %[[n:.*]]) = ({{.*}}) to (%[[B]], %[[M]], %[[N]]) step ({{.*}}) { +// CHECKPARALLEL: loop.for %[[k:.*]] = %{{.*}} to %[[K]] step %{{.*}} { +// CHECKPARALLEL: %[[va:.*]] = load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref +// CHECKPARALLEL: %[[vb:.*]] = load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref +// CHECKPARALLEL: %[[vc:.*]] = load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref +// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 +// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKPARALLEL: store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -7,15 +7,15 @@ // ODS-NEXT: NamedStructuredOpTraits // ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> // -// IMPL-LABEL: Test1Op::referenceIterators() { -// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } +// IMPL-LABEL: SmallVector Test1Op::referenceIterators +// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: Test1Op::referenceIndexingMaps() { +// IMPL: SmallVector Test1Op::referenceIndexingMaps // IMPL: AffineMap::get(2, 0, {d0, d1}, context), // IMPL-NEXT: AffineMap::get(2, 0, {d1}, context), // IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) }; // -// IMPL: Test1Op::regionBuilder(Block &block) { +// IMPL: void Test1Op::regionBuilder(Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); @@ -32,10 +32,10 @@ // ODS-NEXT: NamedStructuredOpTraits // ODS-NEXT: 
SingleBlockImplicitTerminator<"YieldOp"> // -// IMPL-LABEL: Test2Op::referenceIterators() { -// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } +// IMPL-LABEL: SmallVector Test2Op::referenceIterators +// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: Test2Op::referenceIndexingMaps() { +// IMPL: SmallVector Test2Op::referenceIndexingMaps // IMPL: AffineMap::get(3, 0, {d0, d2}, context), // IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}, context), // IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) }; @@ -57,10 +57,10 @@ // ODS-NEXT: NamedStructuredOpTraits // ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> // -// IMPL-LABEL: Test3Op::referenceIterators() { -// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } +// IMPL-LABEL: SmallVector Test3Op::referenceIterators +// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: Test3Op::referenceIndexingMaps() { +// IMPL: SmallVector Test3Op::referenceIndexingMaps // IMPL: AffineMap::get(4, 0, {d0, d1, d3}, context), // IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}, context), // IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) }; diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1472,7 +1472,7 @@ [{{ result.addOperands(views); result.addTypes(outputTypes); - buildNamedStructuredOpRegion<{0}>( + buildNamedStructuredOpRegionAndAttributes<{0}>( b, result, TypeRange(views), outputTypes); }]> ]; @@ -1481,7 +1481,13 @@ }]; let extraClassDeclaration = [{{ llvm::Optional> referenceIterators(); + static SmallVector referenceIterators( + TypeRange inputTypes, TypeRange outputTypes); + llvm::Optional> referenceIndexingMaps(); + static SmallVector referenceIndexingMaps( + TypeRange inputTypes, TypeRange outputTypes); + static void regionBuilder(Block &block); }]; })FMT"; @@ -1503,7 +1509,13 @@ ComprehensionParsingState &state) { const char *referenceReferenceIteratorsFmt = R"FMT( - llvm::Optional> {0}::referenceIterators() { + // This is temporary until we transition out of manually specified ops + // that should be auto-generated with linalg-ods-gen. + llvm::Optional> {0}::referenceIterators() {{ + llvm_unreachable("Unexpected missing `iterator_types` attribute."); + } + SmallVector {0}::referenceIterators( + TypeRange inputTypes, TypeRange outputTypes) { return SmallVector{{ {1} }; })FMT"; @@ -1536,15 +1548,27 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state) { + // 1. Generic string template for specifying reference indexing maps. const char *referenceIndexingMapsFmt = R"FMT( - llvm::Optional> {0}::referenceIndexingMaps() { - MLIRContext *context = getContext(); + // This is temporary until we transition out of manually specified ops that + // should be auto-generated with linalg-ods-gen. + llvm::Optional> {0}::referenceIndexingMaps() {{ + llvm_unreachable("Unexpected missing `indexing_maps` attribute."); + } + SmallVector {0}::referenceIndexingMaps( + TypeRange inputTypes, TypeRange outputTypes) { + assert(!inputTypes.empty() && "At least one input expected"); + MLIRContext *context = (*inputTypes.begin()).getContext(); AffineExpr {1}; bindDims(context, {1}); return SmallVector{{ {2} }; })FMT"; + // 2. 
Print a comma-separated list of identifiers for the AffineExpr in
+  // `state.dims`. These will replace the `{1}` placeholder in both
+  // `AffineExpr {1}` and `bindDims(context, {1})`, ensuring the AffineExpr
+  // identifiers are bound in the right order to the proper AffineDimExpr.
   std::string dimsStr;
   llvm::raw_string_ostream ss(dimsStr);
   llvm::interleaveComma(
@@ -1552,10 +1576,14 @@
       [&](std::pair p) { ss << p.second; });
   ss.flush();
 
+  // 3. Print a comma-separated list of AffineMap constructors that use the
+  // identifiers from 2. The AffineExprs compose using the common arithmetic
+  // operators on AffineExpr. These AffineMap constructors will replace the
+  // `{2}` placeholder in `return SmallVector{{ {2} };`.
   std::string mapsStr;
   llvm::raw_string_ostream mapsStringStream(mapsStr);
   SmallVector orderedUses(state.orderedTensorArgs.size());
-  for (auto it : state.orderedTensorArgs)
+  for (const auto &it : state.orderedTensorArgs)
     orderedUses[it.second] = it.first;
   llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
     assert(u.indexingMap);
@@ -1576,6 +1604,7 @@
   });
   mapsStringStream.flush();
 
+  // 4. Apply format to 1. using 2. and 3.
   os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr);
 }
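// Illustration (assembled from the format string above and the Test2Op IMPL
// checks in test-linalg-ods-gen.tc; the SmallVector element count is an
// assumption): for a matmul-like spec, the emitted definition is expected to
// look roughly as follows.
//
//   SmallVector<AffineMap, 8> Test2Op::referenceIndexingMaps(
//       TypeRange inputTypes, TypeRange outputTypes) {
//     assert(!inputTypes.empty() && "At least one input expected");
//     MLIRContext *context = (*inputTypes.begin()).getContext();
//     AffineExpr d0, d1, d2;
//     bindDims(context, d0, d1, d2);
//     return SmallVector<AffineMap, 8>{
//         AffineMap::get(3, 0, {d0, d2}, context),
//         AffineMap::get(3, 0, {d2, d1}, context),
//         AffineMap::get(3, 0, {d0, d1}, context)};
//   }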