diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSet.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_LINALGNAMEDOPCONVERSION
@@ -72,28 +73,34 @@
   auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
       loc, newInitTy, init, collapsedInitDims);
 
-  Value newConv;
-  if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
-    newConv = rewriter
-                  .create<DepthwiseConv2DNhwcHwcOp>(
-                      loc, newInitTy, ValueRange{input, collapsedKernel},
-                      ValueRange{collapsedInit}, stride, dilation)
-                  .getResult(0);
-  } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
-    newConv =
-        rewriter
-            .create<DepthwiseConv2DNhwcHwcQOp>(
-                loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
-                ValueRange{collapsedInit}, stride, dilation)
-            .getResult(0);
+  Operation *newConv;
+  llvm::StringSet<> elidedAttrs;
+  if (auto op = dyn_cast<DepthwiseConv2DNhwcHwcmOp>(operation)) {
+    newConv = rewriter.create<DepthwiseConv2DNhwcHwcOp>(
+        loc, newInitTy, ValueRange{input, collapsedKernel},
+        ValueRange{collapsedInit}, stride, dilation);
+    elidedAttrs.insert(op.getAttributeNames().begin(),
+                       op.getAttributeNames().end());
+  } else if (auto op = dyn_cast<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
+    newConv = rewriter.create<DepthwiseConv2DNhwcHwcQOp>(
+        loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
+        ValueRange{collapsedInit}, stride, dilation);
+    elidedAttrs.insert(op.getAttributeNames().begin(),
+                       op.getAttributeNames().end());
+  } else {
+    return failure();
   }
 
-  if (!newConv)
-    return failure();
+  elidedAttrs.insert(LinalgDialect::kMemoizedIndexingMapsAttrName);
+  for (auto attr : operation->getAttrs()) {
+    if (elidedAttrs.count(attr.getName()))
+      continue;
+    newConv->setAttr(attr.getName(), attr.getValue());
+  }
 
   // Expand dimensions back out to
   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      operation, resultTy, newConv, collapsedInitDims);
+      operation, resultTy, newConv->getResult(0), collapsedInitDims);
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/namedop_conversion.mlir b/mlir/test/Dialect/Linalg/namedop_conversion.mlir
--- a/mlir/test/Dialect/Linalg/namedop_conversion.mlir
+++ b/mlir/test/Dialect/Linalg/namedop_conversion.mlir
@@ -4,9 +4,9 @@
 func.func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
   // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
   // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
-  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
   // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
-  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
   return %0 : tensor<?x?x?x?x1xf32>
 }
 
@@ -17,8 +17,8 @@
 func.func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
   // CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
   // CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
-  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
   // CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
-  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {_someattr, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
   return %0 : tensor<?x?x?x?x1xi32>
 }