diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringSet.h"
 
 namespace mlir {
 class AffineExpr;
@@ -495,6 +496,23 @@
     ArrayRef<linalg::ProcInfo> procInfo = {});
 };
 
+/// Returns an attribute list that excludes pre-defined attributes.
+template <typename OpTy>
+SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
+  llvm::StringSet<> elidedAttrs;
+  elidedAttrs.insert(op.getAttributeNames().begin(),
+                     op.getAttributeNames().end());
+  if (isa<linalg::LinalgOp>(op.getOperation()))
+    elidedAttrs.insert(LinalgDialect::kMemoizedIndexingMapsAttrName);
+  SmallVector<NamedAttribute> attrs;
+  for (auto attr : op->getAttrs()) {
+    if (elidedAttrs.count(attr.getName()))
+      continue;
+    attrs.push_back(attr);
+  }
+  return attrs;
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_LINALGNAMEDOPCONVERSION
@@ -72,28 +73,30 @@
   auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
       loc, newInitTy, init, collapsedInitDims);
 
-  Value newConv;
-  if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
-    newConv = rewriter
-                  .create<DepthwiseConv2DNhwcHwcOp>(
-                      loc, newInitTy, ValueRange{input, collapsedKernel},
-                      ValueRange{collapsedInit}, stride, dilation)
-                  .getResult(0);
-  } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
-    newConv =
-        rewriter
-            .create<DepthwiseConv2DNhwcHwcQOp>(
+  SmallVector<NamedAttribute> preservedAttrs;
+  Operation *newConv =
+      TypeSwitch<Operation *, Operation *>(operation)
+          .Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) {
+            preservedAttrs = getPrunedAttributeList(op);
+            return rewriter.create<DepthwiseConv2DNhwcHwcOp>(
+                loc, newInitTy, ValueRange{input, collapsedKernel},
+                ValueRange{collapsedInit}, stride, dilation);
+          })
+          .Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) {
+            preservedAttrs = getPrunedAttributeList(op);
+            return rewriter.create<DepthwiseConv2DNhwcHwcQOp>(
                 loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
-                ValueRange{collapsedInit}, stride, dilation)
-            .getResult(0);
-  }
-
+                ValueRange{collapsedInit}, stride, dilation);
+          })
+          .Default([](Operation *op) { return nullptr; });
   if (!newConv)
     return failure();
+  for (auto attr : preservedAttrs)
+    newConv->setAttr(attr.getName(), attr.getValue());
 
   // Expand dimensions back out to
   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      operation, resultTy, newConv, collapsedInitDims);
+      operation, resultTy, newConv->getResult(0), collapsedInitDims);
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/namedop_conversion.mlir b/mlir/test/Dialect/Linalg/namedop_conversion.mlir
--- a/mlir/test/Dialect/Linalg/namedop_conversion.mlir
+++ b/mlir/test/Dialect/Linalg/namedop_conversion.mlir
@@ -4,9 +4,9 @@
 func.func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
   // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
   // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
-  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
   // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
-  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
   return %0 : tensor<?x?x?x?x1xf32>
 }
@@ -17,8 +17,8 @@
 func.func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
   // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
   // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
-  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
   // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
-  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
   return %0 : tensor<?x?x?x?x1xi32>
 }