diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -338,6 +338,24 @@
         return 0;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the input block arguments of the region (the leading
+        `getNumInputs()` arguments of its block).
+      }],
+      /*retTy=*/"Block::BlockArgListType",
+      /*methodName=*/"getRegionInputArgs",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // MLIR currently does not support dependent interfaces or interface
+        // inheritance. By construction all ops with StructuredOpInterface must
+        // implement DestinationStyleOpInterface.
+        // TODO: reevaluate the need for a cast when a better mechanism exists.
+        return getBlock()->getArguments().take_front(
+            cast<DestinationStyleOpInterface>(*this->getOperation())
+                .getNumInputs());
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the output block arguments of the region.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -19,6 +19,7 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
 
 // Base Tablegen class for Linalg ops.
 // Linalg ops that correspond to library calls operate on ShapedType as their
@@ -229,8 +230,10 @@
   AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
 
 def ReduceOp : LinalgStructuredBase_Op<"reduce", [
-    SameVariadicOperandSize, SingleBlockImplicitTerminator<"YieldOp">
-  ]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
+    SameVariadicOperandSize,
+    SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "Reduce operator";
   let description = [{
     Executes `combiner` on the `dimensions` of `inputs` and returns the
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1187,6 +1187,25 @@
 // ReduceOp
 //===----------------------------------------------------------------------===//
 
+/// Sets the printed names of the region's block arguments: input arguments
+/// are named `in` and init (output) arguments are named `init`.
+void ReduceOp::getAsmBlockArgumentNames(Region &region,
+                                        OpAsmSetValueNameFn setNameFn) {
+  for (Value v : getRegionInputArgs())
+    setNameFn(v, "in");
+  for (Value v : getRegionOutputArgs())
+    setNameFn(v, "init");
+}
+
+/// Sets the printed name of the first result to `reduced`. The result list
+/// can be empty (e.g. when the op operates on memrefs); calling front() on
+/// an empty range is invalid, so in that case nothing is named.
+void ReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "reduced");
+}
+
 ArrayAttr ReduceOp::getIteratorTypes() {
   int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
   SmallVector<StringRef> iteratorTypes(inputRank,