diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -174,7 +174,8 @@
 
 def FillOp : LinalgStructured_Op<"fill", []> {
   let arguments = (ins AnyShaped:$output,
-                       AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
+                       AnyTypeOf<[AnyComplex, AnyFloat, AnySignlessInteger,
+                                  AnyVector]>:$value);
   let results = (outs Optional<AnyRankedTensor>:$result);
   let regions = (region AnyRegion:$region);
   let extraClassDeclaration = structuredOpsDecls # [{
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -12,6 +12,7 @@
 
 #include "PassDetail.h"
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
@@ -52,7 +53,16 @@
                          OperationFolder *folder,
                          Optional<unsigned> alignment = None) {
   auto *ctx = size.getContext();
-  auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+  // A complex element is stored as two int-or-float parts, so size the buffer
+  // from the part type and double it.
+  Type partType = elementType;
+  uint64_t numParts = 1;
+  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+    partType = complexType.getElementType();
+    numParts = 2;
+  }
+  uint64_t width =
+      numParts * llvm::divideCeil(partType.getIntOrFloatBitWidth(), 8);
   IntegerAttr alignment_attr;
   if (alignment.hasValue())
     alignment_attr =
@@ -274,6 +284,19 @@
     else if (auto t =
                  subView.getType().getElementType().dyn_cast<IntegerType>())
       fillVal = folded_std_constant_int(folder, 0, t);
+    else if (auto t =
+                 subView.getType().getElementType().dyn_cast<ComplexType>()) {
+      // complex.create only takes floating-point operands, so complex types
+      // with integer parts cannot be zero-filled here; bail out rather than
+      // build an op that fails verification.
+      auto et = t.getElementType().dyn_cast<FloatType>();
+      if (!et)
+        return {};
+      Value zero = folded_std_constant(folder, FloatAttr::get(et, 0.0));
+      fillVal = b.create<complex::CreateOp>(loc, t, zero, zero);
+    } else {
+      return {};
+    }
     linalg_fill(promotionInfo->fullLocalView, fillVal);
   }
 
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -345,6 +345,29 @@
 // CHECK:         linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK:         linalg.fill(%[[v0]], %[[cf]]) : memref<?x?xf32>, f32
 
+func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, offset: ?, strides: [?, 1]>) {
+  %c2000 = constant 2000 : index
+  %c4000 = constant 4000 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cf = constant 1.0 : f32
+  %cc = complex.create %cf, %cf : complex<f32>
+  %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] :
+    memref<?x?xcomplex<f32>, offset: ?, strides: [?, 1]> to memref<?x?xcomplex<f32>, offset: ?, strides: [?, ?]>
+  linalg.fill(%3, %cc) { __internal_linalg_transform__ = "_promote_views_aligned_"}
+    : memref<?x?xcomplex<f32>, offset: ?, strides: [?, ?]>, complex<f32>
+  return
+}
+// CHECK-LABEL: func @aligned_promote_fill_complex
+// CHECK:         %[[cc:.*]] = complex.create {{.*}} : complex<f32>
+// CHECK:         %[[s0:.*]] = memref.subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xcomplex<f32>, #map{{.*}}> to memref<?x?xcomplex<f32>, #map{{.*}}>
+// CHECK:         %[[a0:.*]] = memref.alloc({{%.*}}) {alignment = 32 : i64} : memref<?xi8>
+// CHECK:         %[[v0:.*]] = memref.view %[[a0]][{{.*}}][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xcomplex<f32>>
+// CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xcomplex<f32>> to memref<?x?xcomplex<f32>, #[[$STRIDED_2D_u_1]]>
+// CHECK:         linalg.fill(%[[v0]], {{%.*}}) : memref<?x?xcomplex<f32>>, complex<f32>
+// CHECK:         linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xcomplex<f32>, #map{{.*}}>, memref<?x?xcomplex<f32>, #map{{.*}}>
+// CHECK:         linalg.fill(%[[v0]], %[[cc]]) : memref<?x?xcomplex<f32>>, complex<f32>
+
 func @tile_permute_parallel_loop(%arg0: memref<?x?xf32>,
                                  %arg1: memref<?x?xf32>,
                                  %arg2: memref<?x?xf32>) {