diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -864,7 +864,7 @@ packingMetadata.reassociations); Value paddingValue = packOp.getPaddingValue(); if (!paddingValue) { - rewriter.create<arith::ConstantOp>( + paddingValue = rewriter.create<arith::ConstantOp>( loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed))); } auto padOp = diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -27,6 +27,37 @@ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">) } +// ----- + + // CHECK-LABEL: func.func @pack( +func.func @pack(%arg0: tensor<128x8xf32>, %arg1: tensor<8x8x16x1xf32>) -> tensor<8x8x16x1xf32> { + %cst_0 = arith.constant 0.0 : f32 + + // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]]] + // CHECK: : tensor<128x8xf32> to tensor<128x8xf32> + // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3]] + // CHECK-SAME: : tensor<128x8xf32> into tensor<8x16x8x1xf32> + // CHECK: linalg.transpose + // CHECK-SAME: ins(%{{.*}} : tensor<8x16x8x1xf32>) + // CHECK-SAME: outs(%{{.*}} : tensor<8x8x16x1xf32>) + // CHECK-SAME: permutation = [0, 2, 1, 3] + + %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %arg1 + : tensor<128x8xf32> -> tensor<8x8x16x1xf32> + + return %pack : tensor<8x8x16x1xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %pack = transform.structured.match ops{["tensor.pack"]} in %module_op + : (!pdl.operation) -> !transform.op<"tensor.pack"> +
transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">) + -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">) +} + // ----- // CHECK-LABEL: func.func @unpack(