diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -22,14 +22,19 @@ // The Linalg `NInputs` trait provides the API for ops that are known // to have a specified number of inputs, all passed as operands. // See Linalg/LinalgTraits.h for implementation details and usage. -class NInputs : - NativeOpTrait<"linalg::NInputs<" # !cast(args_in) # ">::Impl"> {} +class NInputs : + NativeOpTrait<"linalg::NInputs<" # !cast(n) # ">::Impl"> {} + +// The Linalg `ZeroInitTensors` trait provides the API for ops that are known +// to not have input tensor operands. +// See Linalg/LinalgTraits.h for implementation details and usage. +def ZeroInitTensors : NativeOpTrait<"linalg::ZeroInitTensors"> {} // The Linalg `NOutputs` trait provides the API for ops that are known // to have a specified number of outputs, all passed as operands. // See Linalg/LinalgTraits.h for implementation details and usage. 
-class NOutputs : - NativeOpTrait<"linalg::NOutputs<" # !cast(args_out) # ">::Impl"> {} +class NOutputs : + NativeOpTrait<"linalg::NOutputs<" # !cast(n) # ">::Impl"> {} def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">; def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">; @@ -62,6 +67,7 @@ def CopyOp : LinalgStructured_Op<"copy", [ CopyOpInterface, NInputs<1>, + ZeroInitTensors, NOutputs<1> ]> { let description = [{ @@ -159,7 +165,10 @@ let hasCanonicalizer = 1; } -def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> { +def FillOp : LinalgStructured_Op<"fill", [ + NInputs<0>, + ZeroInitTensors, + NOutputs<1>]> { let arguments = (ins AnyStridedMemRef:$output, AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value); @@ -254,7 +263,12 @@ }]; } -def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> { +def ConvOp : PoolingBase_Op<"conv", [ + NInputs<2>, + // Despite having reductions, this manually defined ConvOp may only take + // memref operands and can never have init tensors. + ZeroInitTensors, + NOutputs<1>]> { let description = [{ Generic n-D convolution as described in the TF documentation: @@ -371,7 +385,12 @@ } class SingleInputPoolingBase_Op - : PoolingBase_Op, NOutputs<1>]> { + : PoolingBase_Op, + // Despite having reductions, this manually defined pooling op may only + // take memref operands and can never have init tensors. + ZeroInitTensors, + NOutputs<1>]> { let description = [{ A base class for single input pooling function.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -125,13 +125,12 @@ getNumIterators(getReductionIteratorTypeName(), iters) == 1; }]>, //===------------------------------------------------------------------===// - // Num input/output arguments handling. + // Num input/output/initTensors arguments handling. //===------------------------------------------------------------------===// // These special methods must be defined by each op that wants to implement // the LinalgStructuredInterface. For now, this is either: - // - inherited statically by using the NInputs or - // NOutputs traits. - // - derived from args_in/args_out attributes (for linalg.generic and + // - Explicitly specified in the op definition. + // - Derived from variadic attributes (for "named" ops, linalg.generic and // linalg.indexed_generic ops). InterfaceMethod< /*desc=*/[{ @@ -140,6 +139,13 @@ /*retTy=*/"unsigned", /*methodName=*/"getNumInputs" >, + InterfaceMethod< + /*desc=*/[{ + Return the number of init tensors. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumInitTensors" + >, InterfaceMethod< /*desc=*/[{ Return the number of outputs. @@ -371,6 +377,46 @@ return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()}; }] >, + InterfaceMethod< + /*desc=*/[{ + Return the range over init tensors. + }], + /*retTy=*/"Operation::operand_range", + /*methodName=*/"getInitTensors", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto range = this->getOperation()->getOperands(); + return {range.begin() + getNumInputsAndOutputBuffers(), + range.begin() + getNumInputsAndOutputs()}; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return one single init tensor at position `$i`. 
+ }], + /*retTy=*/"Value", + /*methodName=*/"getInitTensor", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i < $_op.getNumInitTensors() && "overflowing init tensor index"); + return getInitTensors()[i]; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the range over inputs, output buffers and init tensors. + }], + /*retTy=*/"Operation::operand_range", + /*methodName=*/"getShapedOperands", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto range = this->getOperation()->getOperands(); + return {range.begin(), range.begin() + getNumInputsAndOutputs()}; + }] + >, InterfaceMethod< /*desc=*/[{ Return the `i`-th shaped type, there are 3 cases: @@ -445,7 +491,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::to_vector<4>($_op.indexing_maps().template getAsValueRange()); + return llvm::to_vector<4>( + $_op.indexing_maps().template getAsValueRange()); }] >, InterfaceMethod< @@ -528,11 +575,11 @@ }], /*retTy=*/"Operation *", /*methodName=*/"create", - (ins "OpBuilder &":$builder, "Location":$loc, + (ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes, "ValueRange":$operands, "ArrayRef":$attributes), [{ - return builder.create(loc, TypeRange{}, operands, - attributes); + return builder.create( + loc, resultTypes, operands, attributes); }] >, InterfaceMethod< @@ -542,10 +589,12 @@ }], /*retTy=*/"Operation *", /*methodName=*/"clone", - (ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{ + (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, + "ValueRange":$operands), + [{ BlockAndValueMapping map; unsigned numRegions = $_op.getOperation()->getNumRegions(); - Operation *res = create(b, loc, operands, $_op.getAttrs()); + Operation *res = create(b, loc, resultTypes, operands, $_op.getAttrs()); assert(res->getNumRegions() == numRegions && "inconsistent # regions"); for (unsigned ridx = 0; ridx < numRegions; ++ridx) 
$_op.getOperation()->getRegion(ridx).cloneInto( diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -35,6 +35,17 @@ }; }; +/// This class provides the API for ops that are known to not have init tensor +/// operands. Use as a trait as follows: +/// +/// class CopyOp : public Op { +/// +template +class ZeroInitTensors : public TraitBase { +public: + static unsigned getNumInitTensors() { return 0; } +}; + /// This class provides the API for ops that are known to have a specified /// number of outputs, all passed as operands. Use as a trait as follows: /// @@ -87,6 +98,9 @@ unsigned getNumInputs() { return cast(this->getOperation()).inputs().size(); } + unsigned getNumInitTensors() { + return cast(this->getOperation()).init_tensors().size(); + } unsigned getNumOutputs() { ConcreteType concreteOp = cast(this->getOperation()); return concreteOp.output_buffers().size() + diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -99,7 +99,7 @@ auto operands = getAssumedNonViewOperands(op); clonedViews.append(operands.begin(), operands.end()); - Operation *clonedOp = op.clone(b, loc, clonedViews); + Operation *clonedOp = op.clone(b, loc, /*resultTypes=*/{}, clonedViews); // When the producer is an IndexedGenercOp, we have to transform its block // IV arguments according to the tiling of the consumer, i.e. offset them by // the values computed in `loopRanges`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -405,7 +405,7 @@ tileSizes, allViewSizes); auto operands = getAssumedNonViewOperands(op); views.append(operands.begin(), operands.end()); - res = op.clone(b, loc, views); + res = op.clone(b, loc, /*resultTypes=*/{}, views); return scf::ValueVector{}; }, options.distribution);