diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -351,10 +351,11 @@ class NamedStructuredOpTraits : public OpTrait::TraitBase {
 public:
-  llvm::Optional> referenceIterators();
-  llvm::Optional> referenceIndexingMaps();
-  std::function)>
-  emitScalarImplementation();
+  static SmallVector referenceIterators(TypeRange inputTypes,
+                                        TypeRange outputTypes);
+
+  static SmallVector referenceIndexingMaps(TypeRange inputTypes,
+                                           TypeRange outputTypes);
 };
 } // namespace linalg
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -33,10 +33,9 @@
 /// Forward declarations.
 template
-static void buildNamedStructuredOpRegion(Builder &builder,
-                                         OperationState &result,
-                                         TypeRange operandTypes,
-                                         TypeRange tensorResultTypes);
+static void buildNamedStructuredOpRegionAndAttributes(
+    Builder &builder, OperationState &result, TypeRange operandTypes,
+    TypeRange tensorResultTypes);
 template
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
 template
@@ -1085,9 +1084,10 @@
 //===----------------------------------------------------------------------===//
 template
-void buildNamedStructuredOpRegion(Builder &builder, OperationState &result,
-                                  TypeRange operandTypes,
-                                  TypeRange tensorResultTypes) {
+void buildNamedStructuredOpRegionAndAttributes(Builder &builder,
+                                               OperationState &result,
+                                               TypeRange operandTypes,
+                                               TypeRange tensorResultTypes) {
   Region &region = *result.addRegion();
   Block *body = new Block();
   // TODO: atm all operands go through getElementTypeOrSelf,
@@ -1102,12 +1102,24 @@
   opBuilder.setInsertionPointToStart(&region.front());
   mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc());
   NamedStructuredOpType::regionBuilder(*body);
+
+  auto indexingMaps = builder.getAffineMapArrayAttr(
+      NamedStructuredOpType::referenceIndexingMaps(operandTypes,
+                                                   tensorResultTypes));
+  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
+
+  auto iterators =
+      builder.getStrArrayAttr(NamedStructuredOpType::referenceIterators(
+          operandTypes, tensorResultTypes));
+  result.addAttribute(getIteratorTypesAttrName(), iterators);
 }
 
 template
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
+  std::array silentAttrNames{getIndexingMapsAttrName(),
+                             getIteratorTypesAttrName()};
   p << op.getOperationName() << ' ';
-  p.printOptionalAttrDict(op.getAttrs());
+  p.printOptionalAttrDict(op.getAttrs(), silentAttrNames);
   p << ' ' << op.getOperands();
   p << ": (" << op.getOperandTypes() << ")";
   auto outputTensorTypes = op.getResultTypes();
@@ -1139,7 +1151,7 @@
   if (!tensorResultTypes.empty())
     result.addTypes(tensorResultTypes);
 
-  buildNamedStructuredOpRegion(
+  buildNamedStructuredOpRegionAndAttributes(
       parser.getBuilder(), result, operandTypes, tensorResultTypes);
 
   return parser.resolveOperands(operandsInfo, operandTypes,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -78,11 +78,10 @@
   return res;
 }
 
-template
-static void
-inlineRegionAndEmitStdStore(OpType op, ArrayRef indexedValues,
-                            ArrayRef> indexing,
-                            ArrayRef outputBuffers) {
+template
+static void
+inlineRegionAndEmitStore(OpType op, ArrayRef indexedValues,
+                         ArrayRef> indexing,
+                         ArrayRef outputBuffers) {
   auto &b = ScopedContext::getBuilder();
   auto &block = op.region().front();
   BlockAndValueMapping map;
@@ -95,10 +94,10 @@
   Operation &terminator = block.back();
   assert(isa(terminator) &&
-         "expected an yield op in the end of the region");
+         "expected a yield op at the end of the region");
   for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
-    std_store(map.lookupOrDefault(terminator.getOperand(i)), outputBuffers[i],
-              ArrayRef{indexing[i].begin(), indexing[i].end()});
+    IndexedValueType O(outputBuffers[i]);
+    O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i));
   }
 }
 
@@ -123,9 +122,36 @@
 namespace {
-// Generic loop emitter, to be specialized on an op-per op basis.
-// TODO: Hook up to named ops interface and, later, retire when all named ops
-// are auto-generated.
+/// Emits the MLIR for the scalar part of the generic op by:
+///   1. Emitting std_load and std_store ops for each input and output
+///      view in order. This is achieved by applying the appropriate input or
+///      output map to the enclosing induction variables.
+///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
+///      from point 1. above.
+///   3. Emitting std_store to store the results of 2. to the output
+///      views.
+//
+/// An example output may resemble:
+//
+/// ```
+///    loop.for %i = %c0 to %0 step %c1 {
+///      loop.for %j = %c0 to %1 step %c1 {
+///        loop.for %k = %c0 to %4 step %c1 {
+///          %11 = load %arg0[%i, %j] :
+///            memref
+///          %12 = load %arg1[%i, %j, %k] :
+///            memref
+///          %13 = load %arg2[%i, %k, %j] :
+///            memref
+///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
+///          store %14#0, %arg1[%i, %j, %k] :
+///            memref
+///          store %14#1, %arg2[%i, %k, %j] :
+///            memref
+///        }
+///      }
+///    }
+/// ```
 template
 class LinalgScopedEmitter {
 public:
@@ -133,9 +159,43 @@
                                        LinalgOpType linalgOp) {
     assert(linalgOp.hasBufferSemantics() &&
            "expected linalg op with buffer semantics");
-    llvm_unreachable("NYI");
-    linalgOp.emitScalarImplementation()(ScopedContext::getBuilder(),
-                                        ScopedContext::getLocation(), allIvs);
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    unsigned nInputs = linalgOp.getNumInputs();
+    unsigned nOutputs = linalgOp.getNumOutputs();
+    SmallVector indexedValues(nInputs + nOutputs);
+
+    // 1.a. Emit std_load from input views.
+    for (unsigned i = 0; i < nInputs; ++i) {
+      auto indexing = makeCanonicalAffineApplies(
+          b, loc, linalgOp.getInputIndexingMap(i), allIvs);
+      // Passing input[i] through IndexedValueType emits the proper load
+      // operation.
+      indexedValues[i] = IndexedValueType(linalgOp.getInput(i))(indexing);
+    }
+
+    // 1.b. Emit std_load from output views.
+    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
+    // region has no uses.
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      auto indexing = makeCanonicalAffineApplies(
+          b, loc, linalgOp.getOutputIndexingMap(i), allIvs);
+      // Passing output[i] through IndexedValueType emits the proper load
+      // operation.
+      indexedValues[nInputs + i] =
+          IndexedValueType(linalgOp.getOutputBuffer(i))(indexing);
+    }
+
+    // TODO(ntv): When a region inliner exists, use it.
+    // 2. Inline region, currently only works for a single basic block.
+    // 3. Emit std_store.
+    SmallVector, 8> indexing;
+    SmallVector outputBuffers;
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      indexing.push_back(makeCanonicalAffineApplies(
+          b, loc, linalgOp.getOutputIndexingMap(i), allIvs));
+      outputBuffers.push_back(linalgOp.getOutputBuffer(i));
+    }
+    inlineRegionAndEmitStore(linalgOp, indexedValues,
+                             indexing, outputBuffers);
   }
 };
 
@@ -231,7 +291,7 @@
 public:
   /// Returns the input value of convOp. If the indices in `imIdx` is out of
   /// boundary, returns 0 instead.
-  static Value getConvOpInput(ConvOp convOp, IndexedValueType im,
+  static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
                               MutableArrayRef imIdx) {
     // TODO(ntv): add a level of indirection to linalg.generic.
     if (!convOp.padding())
@@ -293,7 +353,10 @@
         makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
     SmallVector oIdx(
         makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
-    IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
+
+    // Padded conv involves an affine.max in the memory access which is not
+    // allowed by affine.load. Override to always use a StdIndexedValue.
+    StdIndexedValue F(convOp.filter()), I(convOp.input()), O(convOp.output());
 
     // Emit scalar form.
     Value paddedInput = getConvOpInput(convOp, I, imIdx);
@@ -344,111 +407,36 @@
   }
 };
 
-// Emits the MLIR for the scalar part of the generic op by:
-//   1. Emitting std_load and std_store ops for each input and output
-//      view in order. This is achieved by applying the appropriate input or
-//      output map to the enclosing induction variables.
-//   2. Emitting a call to `op.fun()` that takes as arguments the scalars
-//      from point 1. above.
-//   3. Emitting std_store to store the results of 2. to the output
-//      views.
-//
-// An example output may resemble:
-//
-// ```
-//    loop.for %i = %c0 to %0 step %c1 {
-//      loop.for %j = %c0 to %1 step %c1 {
-//        loop.for %k = %c0 to %4 step %c1 {
-//          %11 = load %arg0[%i, %j] :
-//            memref
-//          %12 = load %arg1[%i, %j, %k] :
-//            memref
-//          %13 = load %arg2[%i, %k, %j] :
-//            memref
-//          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
-//          store %14#0, %arg1[%i, %j, %k] :
-//            memref
-//          store %14#1, %arg2[%i, %k, %j] :
-//            memref
-//        }
-//      }
-//    }
-// ```
-template
-class LinalgScopedEmitter {
-public:
-  static void emitScalarImplementation(ArrayRef allIvs,
-                                       GenericOp genericOp) {
-    assert(genericOp.hasBufferSemantics() &&
-           "expected linalg op with buffer semantics");
-    auto b = ScopedContext::getBuilder();
-    auto loc = ScopedContext::getLocation();
-    unsigned nInputs = genericOp.getNumInputs();
-    unsigned nOutputs = genericOp.getNumOutputs();
-    SmallVector indexedValues(nInputs + nOutputs);
-
-    // 1.a. Emit std_load from input views.
-    for (unsigned i = 0; i < nInputs; ++i) {
-      auto indexing = makeCanonicalAffineApplies(
-          b, loc, genericOp.getInputIndexingMap(i), allIvs);
-      indexedValues[i] = std_load(genericOp.getInput(i), indexing);
-    }
-
-    // 1.b. Emit std_load from output views.
-    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
-    // region has no uses.
-    for (unsigned i = 0; i < nOutputs; ++i) {
-      Value output = genericOp.getOutputBuffer(i);
-      auto indexing = makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs);
-      indexedValues[nInputs + i] = std_load(output, indexing);
-    }
-
-    // TODO(ntv): When a region inliner exists, use it.
-    // 2. Inline region, currently only works for a single basic block.
-    // 3. Emit std_store.
-    SmallVector, 8> indexing;
-    SmallVector outputBuffers;
-    for (unsigned i = 0; i < nOutputs; ++i) {
-      indexing.push_back(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      outputBuffers.push_back(genericOp.getOutputBuffer(i));
-    }
-    inlineRegionAndEmitStdStore(genericOp, indexedValues, indexing,
-                                outputBuffers);
-  }
-};
-
-// Emits the MLIR for the scalar part of the indexed generic op by:
-//   1. Emitting std_load and std_store ops for each input and output view in
-//      order. This is achieved by applying the appropriate input or output map
-//      to the enclosing induction variables.
-//   2. Emitting a call to `op.fun()` that takes as arguments the induction
-//      variables and the scalars from point 1. above.
-//   3. Emitting std_store to store the results of 2. to the output views.
+/// Emits the MLIR for the scalar part of the indexed generic op by:
+///   1. Emitting std_load and std_store ops for each input and output view in
+///      order. This is achieved by applying the appropriate input or output map
+///      to the enclosing induction variables.
+///   2. Emitting a call to `op.fun()` that takes as arguments the induction
+///      variables and the scalars from point 1. above.
+///   3. Emitting std_store to store the results of 2. to the output views.
 //
-// An example output may resemble:
+/// An example output may resemble:
 //
-// ```
-//    loop.for %i = %c0 to %0 step %c1 {
-//      loop.for %j = %c0 to %1 step %c1 {
-//        loop.for %k = %c0 to %4 step %c1 {
-//          %11 = load %arg0[%i, %j] :
-//            memref
-//          %12 = load %arg1[%i, %j, %k] :
-//            memref
-//          %13 = load %arg2[%i, %k, %j] :
-//            memref
-//          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
-//            (index, index, index, f32, f32, f32) -> (f32, f32)
-//          store %14#0, %arg1[%i, %j, %k] :
-//            memref
-//          store %14#1, %arg2[%i, %k, %j] :
-//            memref
-//        }
-//      }
-//    }
-// ```
+/// ```
+///    loop.for %i = %c0 to %0 step %c1 {
+///      loop.for %j = %c0 to %1 step %c1 {
+///        loop.for %k = %c0 to %4 step %c1 {
+///          %11 = load %arg0[%i, %j] :
+///            memref
+///          %12 = load %arg1[%i, %j, %k] :
+///            memref
+///          %13 = load %arg2[%i, %k, %j] :
+///            memref
+///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
+///            (index, index, index, f32, f32, f32) -> (f32, f32)
+///          store %14#0, %arg1[%i, %j, %k] :
+///            memref
+///          store %14#1, %arg2[%i, %k, %j] :
+///            memref
+///        }
+///      }
+///    }
+/// ```
 template
 class LinalgScopedEmitter {
 public:
@@ -493,19 +481,19 @@
           b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
       outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
     }
-    inlineRegionAndEmitStdStore(indexedGenericOp, indexedValues, indexing,
-                                outputBuffers);
+    inlineRegionAndEmitStore(indexedGenericOp, indexedValues,
+                             indexing, outputBuffers);
   }
 };
 
-// This struct is for factoring out the implementation and support template
-// instantiations in the following 2 cases:
-//   1. Appending to a list of patterns via RewritePatternList.
-//   2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
-// The implementation must work both in DRR and inside a RewritePattern. As a
-// consequence, (1) it is only allowed to emit new ops if the match is
-// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an
-// encompassing pattern must take care of the erasure logic.
+/// This struct is for factoring out the implementation and support template
+/// instantiations in the following 2 cases:
+///   1. Appending to a list of patterns via RewritePatternList.
+///   2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
+/// The implementation must work both in DRR and inside a RewritePattern. As a
+/// consequence, (1) it is only allowed to emit new ops if the match is
+/// guaranteed to be a success, (2) it is not allowed to erase/replace, and (3)
+/// an encompassing pattern must take care of the erasure logic.
 template
 class LinalgOpToLoopsImpl {
 public:
@@ -664,7 +652,7 @@
   }
 };
 
-// Helper classes for type list expansion.
+/// Helper classes for type list expansion.
 template
 class RewritePatternList;
@@ -692,16 +680,16 @@
       >::build(patterns, ctx);
 }
 
-// Local folding pattern for AffineApplyOp that we can apply greedily.
-// This replaces AffineApplyOp by the proper value in cases where the associated
-// map is trivial. A trivial map here is defined as a map with a single result
-// and either:
-//   1. Zero operand + returns a single AffineConstantExpr
-//   2. One operand + returns a single AffineDimExpr
-//   3. One operands + returns a single AffineSymbolExpr
+/// Local folding pattern for AffineApplyOp that we can apply greedily.
+/// This replaces AffineApplyOp by the proper value in cases where the
+/// associated map is trivial.
+/// A trivial map here is defined as a map with a single result and either:
+///   1. Zero operand + returns a single AffineConstantExpr
+///   2. One operand + returns a single AffineDimExpr
+///   3. One operand + returns a single AffineSymbolExpr
 //
-// In the first case, the AffineApplyOp is replaced by a new constant. In the
-// other cases, it is replaced by its unique operand.
+/// In the first case, the AffineApplyOp is replaced by a new constant. In the
+/// other cases, it is replaced by its unique operand.
 struct FoldAffineOp : public RewritePattern {
   FoldAffineOp(MLIRContext *context)
       : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,8 +1,9 @@
 // RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck --check-prefix=CHECKLOOP %s
 // RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-affine-loops --disable-pass-threading | FileCheck --check-prefix=CHECKAFFINE %s
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
-// RUN: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
+// RUN: mlir-opt %s -convert-linalg-to-loops --convert-linalg-to-llvm -o=/dev/null 2>&1
 // CHECKLOOP-DAG: #[[strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 // CHECKLOOP-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
@@ -868,3 +869,65 @@
 // CHECKPARALLEL-DAG: load %[[ARG1]][]
 // CHECKPARALLEL: addf
 // CHECKPARALLEL: store %{{.*}}, %[[ARG2]][]
+
+//----------------------------------------------------------------------------//
+// Named ops to loops.
+//----------------------------------------------------------------------------//
+func @named_batch_matmul(%A: memref, %B: memref, %C: memref) {
+  linalg.batch_matmul %A, %B, %C : (memref, memref, memref) -> ()
+  return
+}
+// CHECKLOOP-LABEL: @named_batch_matmul
+// CHECKLOOP-SAME: %[[mA:[a-zA-Z0-9]+]]: memref
+// CHECKLOOP-SAME: %[[mB:[a-zA-Z0-9]+]]: memref
+// CHECKLOOP-SAME: %[[mC:[a-zA-Z0-9]+]]: memref
+// CHECKLOOP: %[[B:.*]] = dim %[[mA]], 0 : memref
+// CHECKLOOP: %[[M:.*]] = dim %[[mA]], 1 : memref
+// CHECKLOOP: %[[K:.*]] = dim %[[mA]], 2 : memref
+// CHECKLOOP: %[[N:.*]] = dim %[[mB]], 2 : memref
+// CHECKLOOP: loop.for %[[b:.*]] = %{{.*}} to %[[B]] step %{{.*}} {
+// CHECKLOOP: loop.for %[[m:.*]] = %{{.*}} to %[[M]] step %{{.*}} {
+// CHECKLOOP: loop.for %[[n:.*]] = %{{.*}} to %[[N]] step %{{.*}} {
+// CHECKLOOP: loop.for %[[k:.*]] = %{{.*}} to %[[K]] step %{{.*}} {
+// CHECKLOOP-DAG: %[[va:.*]] = load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref
+// CHECKLOOP-DAG: %[[vb:.*]] = load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref
+// CHECKLOOP-DAG: %[[vc:.*]] = load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref
+// CHECKLOOP-DAG: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKLOOP-DAG: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKLOOP: store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref
+
+// CHECKAFFINE-LABEL: @named_batch_matmul
+// CHECKAFFINE-SAME: %[[mA:[a-zA-Z0-9]+]]: memref
+// CHECKAFFINE-SAME: %[[mB:[a-zA-Z0-9]+]]: memref
+// CHECKAFFINE-SAME: %[[mC:[a-zA-Z0-9]+]]: memref
+// CHECKAFFINE: %[[B:.*]] = dim %[[mA]], 0 : memref
+// CHECKAFFINE: %[[M:.*]] = dim %[[mA]], 1 : memref
+// CHECKAFFINE: %[[K:.*]] = dim %[[mA]], 2 : memref
+// CHECKAFFINE: %[[N:.*]] = dim %[[mB]], 2 : memref
+// CHECKAFFINE: affine.for %[[b:.*]] = 0 to %[[B]] {
+// CHECKAFFINE: affine.for %[[m:.*]] = 0 to %[[M]] {
+// CHECKAFFINE: affine.for %[[n:.*]] = 0 to %[[N]] {
+// CHECKAFFINE: affine.for %[[k:.*]] = 0 to %[[K]] {
+// CHECKAFFINE-DAG: %[[va:.*]] = affine.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref
+// CHECKAFFINE-DAG: %[[vb:.*]] = affine.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref
+// CHECKAFFINE-DAG: %[[vc:.*]] = affine.load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref
+// CHECKAFFINE-DAG: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKAFFINE-DAG: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKAFFINE: affine.store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref
+
+// CHECKPARALLEL-LABEL: @named_batch_matmul
+// CHECKPARALLEL-SAME: %[[mA:[a-zA-Z0-9]+]]: memref
+// CHECKPARALLEL-SAME: %[[mB:[a-zA-Z0-9]+]]: memref
+// CHECKPARALLEL-SAME: %[[mC:[a-zA-Z0-9]+]]: memref
+// CHECKPARALLEL: %[[B:.*]] = dim %[[mA]], 0 : memref
+// CHECKPARALLEL: %[[M:.*]] = dim %[[mA]], 1 : memref
+// CHECKPARALLEL: %[[K:.*]] = dim %[[mA]], 2 : memref
+// CHECKPARALLEL: %[[N:.*]] = dim %[[mB]], 2 : memref
+// CHECKPARALLEL: loop.parallel (%[[b:.*]], %[[m:.*]], %[[n:.*]]) = ({{.*}}) to (%[[B]], %[[M]], %[[N]]) step ({{.*}}) {
+// CHECKPARALLEL: loop.for %[[k:.*]] = %{{.*}} to %[[K]] step %{{.*}} {
+// CHECKPARALLEL-DAG: %[[va:.*]] = load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref
+// CHECKPARALLEL-DAG: %[[vb:.*]] = load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref
+// CHECKPARALLEL-DAG: %[[vc:.*]] = load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref
+// CHECKPARALLEL-DAG: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+// CHECKPARALLEL-DAG: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref
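The CHECKLOOP and CHECKAFFINE expectations above differ only in the flavor of the memory ops (`load`/`store` vs. `affine.load`/`affine.store`); the loop structure and the scalar arithmetic are identical. That is exactly the degree of freedom the `IndexedValueType` template parameter of `LinalgScopedEmitter` and `inlineRegionAndEmitStore` provides. The standalone C++ sketch below is only a minimal model of that pattern, with hypothetical stand-ins for the EDSC indexed-value helpers; it is not the MLIR implementation.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-ins: operator() plays the role of emitting a load and
// store() the role of emitting a store; the two types differ only in which
// dialect of ops they pretend to emit.
struct StdIndexedValue {
  std::string buffer;
  std::string operator()(const std::vector<std::string> &ivs) const {
    std::cout << "  load " << buffer << "\n";
    return "%v_from_" + buffer;
  }
  void store(const std::string &value) const {
    std::cout << "  store " << value << ", " << buffer << "\n";
  }
};

struct AffineIndexedValue {
  std::string buffer;
  std::string operator()(const std::vector<std::string> &ivs) const {
    std::cout << "  affine.load " << buffer << "\n";
    return "%v_from_" + buffer;
  }
  void store(const std::string &value) const {
    std::cout << "  affine.store " << value << ", " << buffer << "\n";
  }
};

// Written once, parameterized on the indexed-value type: the same emitter
// produces either the std or the affine form of the scalar implementation.
template <typename IndexedValueType>
void emitScalarImplementation(const std::vector<std::string> &buffers) {
  for (const std::string &buf : buffers) {
    IndexedValueType v{buf};
    // Load from the buffer, then store the value back, to show both paths.
    v.store(v({"%i", "%j"}));
  }
}

int main() {
  std::cout << "convert-linalg-to-loops:\n";
  emitScalarImplementation<StdIndexedValue>({"%A", "%C"});
  std::cout << "convert-linalg-to-affine-loops:\n";
  emitScalarImplementation<AffineIndexedValue>({"%A", "%C"});
}
```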
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -7,15 +7,15 @@
 // ODS-NEXT: NamedStructuredOpTraits
 // ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL: Test1Op::referenceIterators() {
-// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+// IMPL-LABEL: SmallVector Test1Op::referenceIterators
+// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-// IMPL: Test1Op::referenceIndexingMaps() {
+// IMPL: SmallVector Test1Op::referenceIndexingMaps
 // IMPL: AffineMap::get(2, 0, {d0, d1}, context),
 // IMPL-NEXT: AffineMap::get(2, 0, {d1}, context),
 // IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) };
 //
-// IMPL: Test1Op::regionBuilder(Block &block) {
+// IMPL: void Test1Op::regionBuilder(Block &block) {
 // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
 // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
@@ -32,10 +32,10 @@
 // ODS-NEXT: NamedStructuredOpTraits
 // ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL: Test2Op::referenceIterators() {
-// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+// IMPL-LABEL: SmallVector Test2Op::referenceIterators
+// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-// IMPL: Test2Op::referenceIndexingMaps() {
+// IMPL: SmallVector Test2Op::referenceIndexingMaps
 // IMPL: AffineMap::get(3, 0, {d0, d2}, context),
 // IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}, context),
 // IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) };
@@ -57,10 +57,10 @@
 // ODS-NEXT: NamedStructuredOpTraits
 // ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL: Test3Op::referenceIterators() {
-// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+// IMPL-LABEL: SmallVector Test3Op::referenceIterators
+// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-// IMPL: Test3Op::referenceIndexingMaps() {
+// IMPL: SmallVector Test3Op::referenceIndexingMaps
 // IMPL: AffineMap::get(4, 0, {d0, d1, d3}, context),
 // IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}, context),
 // IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) };
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1472,7 +1472,7 @@
       [{{
         result.addOperands(views);
         result.addTypes(outputTypes);
-        buildNamedStructuredOpRegion<{0}>(
+        buildNamedStructuredOpRegionAndAttributes<{0}>(
           b, result, TypeRange(views), outputTypes);
       }]>
   ];
@@ -1481,7 +1481,13 @@
   }];
   let extraClassDeclaration = [{{
     llvm::Optional> referenceIterators();
+    static SmallVector referenceIterators(
+        TypeRange inputTypes, TypeRange outputTypes);
+
     llvm::Optional> referenceIndexingMaps();
+    static SmallVector referenceIndexingMaps(
+        TypeRange inputTypes, TypeRange outputTypes);
+
     static void regionBuilder(Block &block);
   }];
 })FMT";
@@ -1503,7 +1509,13 @@
                            ComprehensionParsingState &state) {
   const char *referenceReferenceIteratorsFmt =
       R"FMT(
-  llvm::Optional> {0}::referenceIterators() {
+  // This is temporary until we transition out of manually specified ops
+  // that should be auto-generated with linalg-ods-gen.
+  llvm::Optional> {0}::referenceIterators() {{
+    llvm_unreachable("Unexpected missing `iterator_types` attribute.");
+  }
+  SmallVector {0}::referenceIterators(
+    TypeRange inputTypes, TypeRange outputTypes) {
    return SmallVector{{ {1} };
  })FMT";
@@ -1538,8 +1550,15 @@
                            ComprehensionParsingState &state) {
   const char *referenceIndexingMapsFmt =
       R"FMT(
-  llvm::Optional> {0}::referenceIndexingMaps() {
-    MLIRContext *context = getContext();
+  // This is temporary until we transition out of manually specified ops that
+  // should be auto-generated with linalg-ods-gen.
+  llvm::Optional> {0}::referenceIndexingMaps() {{
+    llvm_unreachable("Unexpected missing `indexing_maps` attribute.");
+  }
+  SmallVector {0}::referenceIndexingMaps(
+    TypeRange inputTypes, TypeRange outputTypes) {
+    assert(!inputTypes.empty() && "At least one input expected");
+    MLIRContext *context = (*inputTypes.begin()).getContext();
     AffineExpr {1};
     bindDims(context, {1});
     return SmallVector{{ {2} };
@@ -1555,7 +1574,7 @@
   std::string mapsStr;
   llvm::raw_string_ostream mapsStringStream(mapsStr);
   SmallVector orderedUses(state.orderedTensorArgs.size());
-  for (auto it : state.orderedTensorArgs)
+  for (const auto &it : state.orderedTensorArgs)
     orderedUses[it.second] = it.first;
   llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
     assert(u.indexingMap);
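For concreteness, here is a sketch of what the generated static `referenceIndexingMaps` may look like once `referenceIndexingMapsFmt` is expanded for the matmul-like `Test2Op` checked above (`{0}` is the op name, `{1}` the bound dimension expressions, `{2}` the map list). The element types in the `SmallVector` template arguments are assumptions on my part; the shape of the body follows the format string and the IMPL check lines.

```cpp
// Hypothetical expansion of referenceIndexingMapsFmt for Test2Op, a
// matmul-like contraction with maps (d0, d2), (d2, d1) -> (d0, d1),
// mirroring the IMPL checks above.
SmallVector<AffineMap, 8>
Test2Op::referenceIndexingMaps(TypeRange inputTypes, TypeRange outputTypes) {
  assert(!inputTypes.empty() && "At least one input expected");
  // The maps are built in the context of the first input's type.
  MLIRContext *context = (*inputTypes.begin()).getContext();
  AffineExpr d0, d1, d2;
  bindDims(context, d0, d1, d2);
  return SmallVector<AffineMap, 8>{AffineMap::get(3, 0, {d0, d2}, context),
                                   AffineMap::get(3, 0, {d2, d1}, context),
                                   AffineMap::get(3, 0, {d0, d1}, context)};
}
```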