diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -178,6 +178,16 @@
                          .getResult(0)
                          .dyn_cast<AffineConstantExpr>())
       boundingConst = cExpr.getValue();
+  } else if (auto dimOp = size.getDefiningOp<DimOp>()) {
+    auto shape = dimOp.memrefOrTensor().getType().dyn_cast<ShapedType>();
+    if (auto constOp = dimOp.index().getDefiningOp<ConstantOp>()) {
+      if (auto indexAttr = constOp.value().dyn_cast<IntegerAttr>()) {
+        auto dimIndex = indexAttr.getInt();
+        if (!shape.isDynamicDim(dimIndex)) {
+          boundingConst = shape.getShape()[dimIndex];
+        }
+      }
+    }
   }
   if (boundingConst && *boundingConst >= 0)
     return Builder(size.getContext()).getIndexAttr(*boundingConst);
diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -47,3 +47,36 @@
 // CHECK-1DIM-TILE:      %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
 // CHECK-1DIM-TILE-NOT: scf.for
 // CHECK-1DIM-TILE: linalg.matmul_i8_i8_i32 ins(%[[TA]], %[[TB]] : tensor<?x?xi8>, tensor<?x?xi8>) outs(%[[TC]] : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+func @matmul_partially_padded_tensors(
+  %arg0: tensor<?x8xi8>, %arg1: tensor<8x?xi8>, %arg2: tensor<?x?xi32>)
+    -> tensor<?x?xi32> {
+  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile-and-pad"}
+      ins(%arg0, %arg1: tensor<?x8xi8>, tensor<8x?xi8>)
+     outs(%arg2: tensor<?x?xi32>)
+    -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @matmul_partially_padded_tensors(
+// CHECK:       linalg.matmul_i8_i8_i32 ins({{.*}}, {{.*}} : tensor<2x4xi8>, tensor<4x3xi8>) outs({{.*}} : tensor<2x3xi32>) -> tensor<2x3xi32>
+
+
+// CHECK-1DIM-TILE:      func @matmul_partially_padded_tensors(
+// CHECK-1DIM-TILE-SAME:   %[[TA:[0-9a-z]+]]: tensor<?x8xi8>
+// CHECK-1DIM-TILE-SAME:   %[[TB:[0-9a-z]+]]: tensor<8x?xi8>
+// CHECK-1DIM-TILE-SAME:   %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
+// CHECK-1DIM-TILE:        %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
+// CHECK-1DIM-TILE:          %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
+// CHECK-1DIM-TILE:            %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x8xi8>
+// CHECK-1DIM-TILE:            %[[sTAc:.*]] = tensor.cast %[[sTA]] : tensor<?x8xi8> to tensor<?x?xi8>
+// CHECK-1DIM-TILE:            %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8>
+// CHECK-1DIM-TILE:            %[[sTBc:.*]] = tensor.cast %[[sTB]] : tensor<8x?xi8> to tensor<?x?xi8>
+// CHECK-1DIM-TILE:            %[[sTC:.*]] = subtensor %[[TC1]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
+// CHECK-1DIM-TILE:            %[[pA:.*]] = linalg.pad_tensor %[[sTAc]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE:              : tensor<?x?xi8> to tensor<2x8xi8>
+// CHECK-1DIM-TILE:            %[[pB:.*]] = linalg.pad_tensor %[[sTBc]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE:              : tensor<?x?xi8> to tensor<8x3xi8>
+// CHECK-1DIM-TILE:            %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
+// CHECK-1DIM-TILE:              : tensor<?x?xi32> to tensor<2x3xi32>
+// CHECK-1DIM-TILE:            %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
+// CHECK-1DIM-TILE:                                          outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>