diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -948,6 +948,20 @@
 
   auto vecTy = getVectorType(scalarTy, state.strategy);
   auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue());
+
+  // Restore the caller's insertion point when this function returns.
+  OpBuilder::InsertionGuard guard(state.builder);
+  Operation *parentOp = state.builder.getInsertionBlock()->getParentOp();
+  // Find the innermost vectorized ancestor loop to insert the vector constant.
+  // Creating the constant at the start of that loop's body makes it dominate
+  // every use in the loop, even when the current insertion point is inside a
+  // nested loop.
+  while (parentOp && !state.vecLoopToVecDim.count(parentOp))
+    parentOp = parentOp->getParentOp();
+  assert(parentOp && state.vecLoopToVecDim.count(parentOp) &&
+         isa<AffineForOp>(parentOp) && "Expected a vectorized for op");
+  auto vecForOp = cast<AffineForOp>(parentOp);
+  state.builder.setInsertionPointToStart(vecForOp.getBody());
   auto newConstOp = state.builder.create<ConstantOp>(constOp.getLoc(), vecAttr);
 
   // Register vector replacement for future uses in the scope.
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir --- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir +++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir @@ -113,12 +113,12 @@ } affine.for %i4 = 0 to %M { affine.for %i5 = 0 to %N { + // CHECK: %[[SPLAT2:.*]] = constant dense<2.000000e+00> : vector<128xf32> + // CHECK: %[[SPLAT1:.*]] = constant dense<1.000000e+00> : vector<128xf32> // CHECK: %[[A5:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{[a-zA-Z0-9_]*}} : memref<?x?xf32>, vector<128xf32> // CHECK: %[[B5:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{[a-zA-Z0-9_]*}} : memref<?x?xf32>, vector<128xf32> // CHECK: %[[S5:.*]] = addf %[[A5]], %[[B5]] : vector<128xf32> - // CHECK: %[[SPLAT1:.*]] = constant dense<1.000000e+00> : vector<128xf32> // CHECK: %[[S6:.*]] = addf %[[S5]], %[[SPLAT1]] : vector<128xf32> - // CHECK: %[[SPLAT2:.*]] = constant dense<2.000000e+00> : vector<128xf32> // CHECK: %[[S7:.*]] = addf %[[S5]], %[[SPLAT2]] : vector<128xf32> // CHECK: %[[S8:.*]] = addf %[[S7]], %[[S6]] : vector<128xf32> // CHECK: vector.transfer_write %[[S8]], {{.*}} : vector<128xf32>, memref<?x?xf32> @@ -142,6 +142,32 @@ // ----- +// The vector constant created for %f1 has one user inside a nested loop and +// one directly in the vectorized loop body; it must be created ahead of the +// nested loop so that it dominates both users. +// CHECK-LABEL: func @vec_constant_with_two_users +func @vec_constant_with_two_users(%M : index, %N : index) -> (f32, f32) { + %A = memref.alloc (%M, %N) : memref<?x?xf32> + %B = memref.alloc (%M) : memref<?xf32> + %f1 = constant 1.0 : f32 + affine.for %i0 = 0 to %M { // vectorized + // CHECK: %[[C1:.*]] = constant dense<1.000000e+00> : vector<128xf32> + // CHECK-NEXT: affine.for + // CHECK-NEXT: vector.transfer_write %[[C1]], {{.*}} : vector<128xf32>, memref<?x?xf32> + affine.for %i1 = 0 to %N { + affine.store %f1, %A[%i1, %i0] : memref<?x?xf32> + } + // CHECK: vector.transfer_write %[[C1]], {{.*}} : vector<128xf32>, memref<?xf32> + affine.store %f1, %B[%i0] : memref<?xf32> + } + %c12 = constant 12 : index + %res1 = affine.load %A[%c12, %c12] : memref<?x?xf32> + %res2 = affine.load %B[%c12] : memref<?xf32> + return %res1, %res2 : f32, f32 +} + +// ----- + // CHECK-LABEL: func @vec_rejected_1 func @vec_rejected_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) { // CHECK-DAG: %[[C0:.*]] = constant 0 : index @@ -551,8 +577,8 @@ // CHECK-LABEL: @vec_non_vecdim_reductions // CHECK: affine.for %{{.*}} = 0 to 256 step 128 { -// CHECK: %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32> // CHECK: %[[vone:.*]] = constant dense<1> : vector<128xi32> +// CHECK: %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32> // CHECK: %[[reds:.*]]:2 = affine.for %{{.*}} = 0 to 128 // CHECK-SAME: iter_args(%[[red_iter0:.*]] = %[[vzero]], %[[red_iter1:.*]] = %[[vone]]) -> (vector<128xf32>, vector<128xi32>) { // CHECK: %[[ld0:.*]] = vector.transfer_read %{{.*}} : memref<128x256xf32>, vector<128xf32>
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_2d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_2d.mlir --- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_2d.mlir +++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_2d.mlir @@ -70,12 +70,12 @@ } affine.for %i4 = 0 to %M { affine.for %i5 = 0 to %N { + // CHECK: [[SPLAT2:%.*]] = constant dense<2.000000e+00> : vector<32x256xf32> + // CHECK: [[SPLAT1:%.*]] = constant dense<1.000000e+00> : vector<32x256xf32> // CHECK: [[A5:%.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} : memref<?x?xf32>, vector<32x256xf32> // CHECK: [[B5:%.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} : memref<?x?xf32>, vector<32x256xf32> // CHECK: [[S5:%.*]] = addf [[A5]], [[B5]] : vector<32x256xf32> - // CHECK: [[SPLAT1:%.*]] = constant dense<1.000000e+00> : vector<32x256xf32> // CHECK: [[S6:%.*]] = addf [[S5]], [[SPLAT1]] : vector<32x256xf32> - // CHECK: [[SPLAT2:%.*]] = constant dense<2.000000e+00> : vector<32x256xf32> // CHECK: [[S7:%.*]] = addf [[S5]], [[SPLAT2]] : vector<32x256xf32> // CHECK: [[S8:%.*]] = addf [[S7]], [[S6]] : vector<32x256xf32> // CHECK: vector.transfer_write [[S8]], {{.*}} : vector<32x256xf32>, memref<?x?xf32>