diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -641,6 +641,7 @@ scf::ForOp newForOp = rewriter.create( forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(), newIterArgs); + newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.region().front(); // Replace the null placeholders with newly constructed values. @@ -770,6 +771,7 @@ scf::ForOp newForOp = rewriter.create( forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(), newIterOperands); + newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.region().front(); SmallVector newBlockTransferArgs(newBlock.getArguments().begin(), newBlock.getArguments().end()); diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -330,7 +330,7 @@ func private @make_i32() -> i32 -func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 { +func @for_yields_1(%lb : index, %ub : index, %step : index) -> i32 { %a = call @make_i32() : () -> (i32) %b = scf.for %i = %lb to %ub step %step iter_args(%0 = %a) -> i32 { scf.yield %0 : i32 @@ -338,10 +338,26 @@ return %b : i32 } -// CHECK-LABEL: func @for_yields_2 +// CHECK-LABEL: func @for_yields_1 // CHECK-NEXT: %[[R:.*]] = call @make_i32() : () -> i32 // CHECK-NEXT: return %[[R]] : i32 +func @for_yields_1_attr(%lb : index, %ub : index, %step : index) -> i32 { + %a = call @make_i32() : () -> (i32) + %b = scf.for %i = %lb to %ub step %step iter_args(%0 = %a) -> i32 { + %c = call @make_i32() : () -> (i32) + scf.yield %0 : i32 + } {someAttr = "someVal"} + return %b : i32 +} + +// CHECK-LABEL: func @for_yields_1_attr +// CHECK-NEXT: %[[a:.*]] = call @make_i32() : () -> i32 +// CHECK-NEXT: scf.for {{.*}} { +// CHECK-NEXT: %[[c:.*]] = call @make_i32() : () -> i32 +// CHECK-NEXT: } {someAttr = "someVal"} +// CHECK-NEXT: return %[[a]] : i32 + func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) { %a = call @make_i32() : () -> (i32) %b = call @make_i32() : () -> (i32) @@ -653,11 +669,12 @@ // CHECK: %[[DONE:.*]] = call @do(%[[CAST]]) : (tensor) -> tensor // CHECK: %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor to tensor<32x1024xf32> // CHECK: scf.yield %[[UNCAST]] : tensor<32x1024xf32> +// CHECK: {someAttr = "someVal"} %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor) { %2 = call @do(%iter_t0) : (tensor) -> tensor scf.yield %2 : tensor - } + } {someAttr = "someVal"} // CHECK-NOT: tensor.cast // CHECK: %[[RES:.*]] = tensor.insert_slice %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32> // CHECK: return %[[RES]] : tensor<1024x1024xf32>